assistant: Make scripting a first-class concept instead of a tool (#26338)

This PR makes refactors the scripting functionality to be a first-class
concept of the assistant instead of a generic tool, which will allow us
to build a more customized experience.

- The tool prompt has been slightly tweaked and is now included as a
system message in all conversations. I'm getting decent results, but now
that it isn't in the tools framework, it will probably require more
refining.

- The model will now include an `<eval ...>` tag at the end of the
message with the script. We parse this tag incrementally as it streams
in so that we can indicate that we are generating a script before we see
the closing `</eval>` tag. Later, this will help us interpret the script
as it arrives also.

- Threads now hold a `ScriptSession` entity which manages the state of
all scripts (from parsing to exited) in a centralized way, and will
later collect all script operations so they can be displayed in the UI.

- `script_tool` has been renamed to `assistant_scripting` 

- Script source now opens in a regular read-only buffer  

Note: We still need to handle persistence properly

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Agus Zubiaga 2025-03-09 06:01:49 -03:00 committed by GitHub
parent ed6bf7f161
commit e298301b40
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 811 additions and 197 deletions

View file

@ -1,11 +1,14 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_scripting::{
Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT,
};
use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use futures::StreamExt as _;
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
@ -75,14 +78,21 @@ pub struct Thread {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tool_use: ToolUseState,
scripts_by_assistant_message: HashMap<MessageId, ScriptId>,
script_output_messages: HashSet<MessageId>,
script_session: Entity<ScriptSession>,
_script_session_subscription: Subscription,
}
impl Thread {
pub fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>,
cx: &mut Context<Self>,
) -> Self {
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
@ -97,6 +107,10 @@ impl Thread {
project,
tools,
tool_use: ToolUseState::new(),
scripts_by_assistant_message: HashMap::default(),
script_output_messages: HashSet::default(),
script_session,
_script_session_subscription: script_session_subscription,
}
}
@ -105,7 +119,7 @@ impl Thread {
saved: SavedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
_cx: &mut Context<Self>,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
saved
@ -115,6 +129,8 @@ impl Thread {
.unwrap_or(0),
);
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
Self {
id,
@ -138,6 +154,10 @@ impl Thread {
project,
tools,
tool_use,
scripts_by_assistant_message: HashMap::default(),
script_output_messages: HashSet::default(),
script_session,
_script_session_subscription: script_session_subscription,
}
}
@ -223,17 +243,22 @@ impl Thread {
self.tool_use.message_has_tool_results(message_id)
}
pub fn message_has_script_output(&self, message_id: MessageId) -> bool {
self.script_output_messages.contains(&message_id)
}
pub fn insert_user_message(
&mut self,
text: impl Into<String>,
context: Vec<ContextSnapshot>,
cx: &mut Context<Self>,
) {
) -> MessageId {
let message_id = self.insert_message(Role::User, text, cx);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
self.context_by_message.insert(message_id, context_ids);
message_id
}
pub fn insert_message(
@ -302,6 +327,39 @@ impl Thread {
text
}
pub fn script_for_message<'a>(
&'a self,
message_id: MessageId,
cx: &'a App,
) -> Option<&'a Script> {
self.scripts_by_assistant_message
.get(&message_id)
.map(|script_id| self.script_session.read(cx).get(*script_id))
}
fn handle_script_event(
&mut self,
_script_session: Entity<ScriptSession>,
event: &ScriptEvent,
cx: &mut Context<Self>,
) {
match event {
ScriptEvent::Spawned(_) => {}
ScriptEvent::Exited(script_id) => {
if let Some(output_message) = self
.script_session
.read(cx)
.get(*script_id)
.output_message_for_llm()
{
let message_id = self.insert_user_message(output_message, vec![], cx);
self.script_output_messages.insert(message_id);
cx.emit(ThreadEvent::ScriptFinished)
}
}
}
}
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@ -330,7 +388,7 @@ impl Thread {
pub fn to_completion_request(
&self,
request_kind: RequestKind,
_cx: &App,
cx: &App,
) -> LanguageModelRequest {
let mut request = LanguageModelRequest {
messages: vec![],
@ -339,6 +397,12 @@ impl Thread {
temperature: None,
};
request.messages.push(LanguageModelRequestMessage {
role: Role::System,
content: vec![SCRIPTING_PROMPT.to_string().into()],
cache: true,
});
let mut referenced_context_ids = HashSet::default();
for message in &self.messages {
@ -351,6 +415,7 @@ impl Thread {
content: Vec::new(),
cache: false,
};
match request_kind {
RequestKind::Chat => {
self.tool_use
@ -371,11 +436,20 @@ impl Thread {
RequestKind::Chat => {
self.tool_use
.attach_tool_uses(message.id, &mut request_message);
if matches!(message.role, Role::Assistant) {
if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id)
{
let script = self.script_session.read(cx).get(*script_id);
request_message.content.push(script.source_tag().into());
}
}
}
RequestKind::Summarize => {
// We don't care about tool use during summarization.
}
}
};
request.messages.push(request_message);
}
@ -412,6 +486,8 @@ impl Thread {
let stream_completion = async {
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
let mut script_tag_parser = ScriptTagParser::new();
let mut script_id = None;
while let Some(event) = events.next().await {
let event = event?;
@ -426,19 +502,43 @@ impl Thread {
}
LanguageModelCompletionEvent::Text(chunk) => {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk);
let chunk = script_tag_parser.parse_chunk(&chunk);
let message_id = if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk.content);
cx.emit(ThreadEvent::StreamedAssistantText(
last_message.id,
chunk,
chunk.content,
));
last_message.id
} else {
// If we won't have an Assistant message yet, assume this chunk marks the beginning
// of a new Assistant response.
//
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
// will result in duplicating the text of the chunk in the rendered Markdown.
thread.insert_message(Role::Assistant, chunk, cx);
thread.insert_message(Role::Assistant, chunk.content, cx)
};
if script_id.is_none() && script_tag_parser.found_script() {
let id = thread
.script_session
.update(cx, |session, _cx| session.new_script());
thread.scripts_by_assistant_message.insert(message_id, id);
script_id = Some(id);
}
if let (Some(script_source), Some(script_id)) =
(chunk.script_source, script_id)
{
// TODO: move buffer to script and run as it streams
thread
.script_session
.update(cx, |this, cx| {
this.run_script(script_id, script_source, cx)
})
.detach_and_log_err(cx);
}
}
}
@ -661,6 +761,7 @@ pub enum ThreadEvent {
#[allow(unused)]
tool_use_id: LanguageModelToolUseId,
},
ScriptFinished,
}
impl EventEmitter<ThreadEvent> for Thread {}