assistant: Make scripting a first-class concept instead of a tool (#26338)
This PR refactors the scripting functionality to be a first-class concept of the assistant instead of a generic tool, which will allow us to build a more customized experience. - The tool prompt has been slightly tweaked and is now included as a system message in all conversations. I'm getting decent results, but now that it isn't in the tools framework, it will probably require more refining. - The model will now include an `<eval ...>` tag at the end of the message with the script. We parse this tag incrementally as it streams in so that we can indicate that we are generating a script before we see the closing `</eval>` tag. Later, this will also help us interpret the script as it arrives. - Threads now hold a `ScriptSession` entity which manages the state of all scripts (from parsing to exited) in a centralized way, and will later collect all script operations so they can be displayed in the UI. - `script_tool` has been renamed to `assistant_scripting` - Script source now opens in a regular read-only buffer Note: We still need to handle persistence properly Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
ed6bf7f161
commit
e298301b40
16 changed files with 811 additions and 197 deletions
|
@ -1,11 +1,14 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_scripting::{
|
||||
Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT,
|
||||
};
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
|
@ -75,14 +78,21 @@ pub struct Thread {
|
|||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
tool_use: ToolUseState,
|
||||
scripts_by_assistant_message: HashMap<MessageId, ScriptId>,
|
||||
script_output_messages: HashSet<MessageId>,
|
||||
script_session: Entity<ScriptSession>,
|
||||
_script_session_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
_cx: &mut Context<Self>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
|
||||
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
|
||||
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
updated_at: Utc::now(),
|
||||
|
@ -97,6 +107,10 @@ impl Thread {
|
|||
project,
|
||||
tools,
|
||||
tool_use: ToolUseState::new(),
|
||||
scripts_by_assistant_message: HashMap::default(),
|
||||
script_output_messages: HashSet::default(),
|
||||
script_session,
|
||||
_script_session_subscription: script_session_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,7 +119,7 @@ impl Thread {
|
|||
saved: SavedThread,
|
||||
project: Entity<Project>,
|
||||
tools: Arc<ToolWorkingSet>,
|
||||
_cx: &mut Context<Self>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let next_message_id = MessageId(
|
||||
saved
|
||||
|
@ -115,6 +129,8 @@ impl Thread {
|
|||
.unwrap_or(0),
|
||||
);
|
||||
let tool_use = ToolUseState::from_saved_messages(&saved.messages);
|
||||
let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx));
|
||||
let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event);
|
||||
|
||||
Self {
|
||||
id,
|
||||
|
@ -138,6 +154,10 @@ impl Thread {
|
|||
project,
|
||||
tools,
|
||||
tool_use,
|
||||
scripts_by_assistant_message: HashMap::default(),
|
||||
script_output_messages: HashSet::default(),
|
||||
script_session,
|
||||
_script_session_subscription: script_session_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -223,17 +243,22 @@ impl Thread {
|
|||
self.tool_use.message_has_tool_results(message_id)
|
||||
}
|
||||
|
||||
/// Returns `true` if `message_id` is present in `script_output_messages`,
/// the set that `handle_script_event` fills with the ids of the user
/// messages it inserts to carry a finished script's output.
pub fn message_has_script_output(&self, message_id: MessageId) -> bool {
|
||||
self.script_output_messages.contains(&message_id)
|
||||
}
|
||||
|
||||
pub fn insert_user_message(
|
||||
&mut self,
|
||||
text: impl Into<String>,
|
||||
context: Vec<ContextSnapshot>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
) -> MessageId {
|
||||
let message_id = self.insert_message(Role::User, text, cx);
|
||||
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
|
||||
self.context
|
||||
.extend(context.into_iter().map(|context| (context.id, context)));
|
||||
self.context_by_message.insert(message_id, context_ids);
|
||||
message_id
|
||||
}
|
||||
|
||||
pub fn insert_message(
|
||||
|
@ -302,6 +327,39 @@ impl Thread {
|
|||
text
|
||||
}
|
||||
|
||||
/// Returns the script associated with the given message, if one was recorded.
///
/// The message id is looked up in `scripts_by_assistant_message`; when an
/// entry exists, the stored `ScriptId` is resolved through the thread's
/// `ScriptSession`. Returns `None` when no script is mapped to `message_id`.
pub fn script_for_message<'a>(
|
||||
&'a self,
|
||||
message_id: MessageId,
|
||||
cx: &'a App,
|
||||
) -> Option<&'a Script> {
|
||||
self.scripts_by_assistant_message
|
||||
.get(&message_id)
|
||||
.map(|script_id| self.script_session.read(cx).get(*script_id))
|
||||
}
|
||||
|
||||
/// Callback for events emitted by the thread's `ScriptSession`; it is
/// registered via `cx.subscribe` when the thread is constructed.
///
/// `ScriptEvent::Spawned` is ignored. On `ScriptEvent::Exited`, if the script
/// yields an output message for the LLM, that message is inserted as a user
/// message, its id is recorded in `script_output_messages`, and
/// `ThreadEvent::ScriptFinished` is emitted. When the script yields no output
/// message, nothing is inserted and no event is emitted.
fn handle_script_event(
|
||||
&mut self,
|
||||
_script_session: Entity<ScriptSession>,
|
||||
event: &ScriptEvent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
// Nothing to record when a script starts.
ScriptEvent::Spawned(_) => {}
|
||||
ScriptEvent::Exited(script_id) => {
|
||||
if let Some(output_message) = self
|
||||
.script_session
|
||||
.read(cx)
|
||||
.get(*script_id)
|
||||
.output_message_for_llm()
|
||||
{
|
||||
// Feed the script's output back into the conversation as a user
// message with no attached context, and remember which message
// carries it so `message_has_script_output` can identify it.
let message_id = self.insert_user_message(output_message, vec![], cx);
|
||||
self.script_output_messages.insert(message_id);
|
||||
cx.emit(ThreadEvent::ScriptFinished)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_to_model(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
|
@ -330,7 +388,7 @@ impl Thread {
|
|||
pub fn to_completion_request(
|
||||
&self,
|
||||
request_kind: RequestKind,
|
||||
_cx: &App,
|
||||
cx: &App,
|
||||
) -> LanguageModelRequest {
|
||||
let mut request = LanguageModelRequest {
|
||||
messages: vec![],
|
||||
|
@ -339,6 +397,12 @@ impl Thread {
|
|||
temperature: None,
|
||||
};
|
||||
|
||||
request.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![SCRIPTING_PROMPT.to_string().into()],
|
||||
cache: true,
|
||||
});
|
||||
|
||||
let mut referenced_context_ids = HashSet::default();
|
||||
|
||||
for message in &self.messages {
|
||||
|
@ -351,6 +415,7 @@ impl Thread {
|
|||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
match request_kind {
|
||||
RequestKind::Chat => {
|
||||
self.tool_use
|
||||
|
@ -371,11 +436,20 @@ impl Thread {
|
|||
RequestKind::Chat => {
|
||||
self.tool_use
|
||||
.attach_tool_uses(message.id, &mut request_message);
|
||||
|
||||
if matches!(message.role, Role::Assistant) {
|
||||
if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id)
|
||||
{
|
||||
let script = self.script_session.read(cx).get(*script_id);
|
||||
|
||||
request_message.content.push(script.source_tag().into());
|
||||
}
|
||||
}
|
||||
}
|
||||
RequestKind::Summarize => {
|
||||
// We don't care about tool use during summarization.
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
request.messages.push(request_message);
|
||||
}
|
||||
|
@ -412,6 +486,8 @@ impl Thread {
|
|||
let stream_completion = async {
|
||||
let mut events = stream.await?;
|
||||
let mut stop_reason = StopReason::EndTurn;
|
||||
let mut script_tag_parser = ScriptTagParser::new();
|
||||
let mut script_id = None;
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
let event = event?;
|
||||
|
@ -426,19 +502,43 @@ impl Thread {
|
|||
}
|
||||
LanguageModelCompletionEvent::Text(chunk) => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
last_message.text.push_str(&chunk);
|
||||
let chunk = script_tag_parser.parse_chunk(&chunk);
|
||||
|
||||
let message_id = if last_message.role == Role::Assistant {
|
||||
last_message.text.push_str(&chunk.content);
|
||||
cx.emit(ThreadEvent::StreamedAssistantText(
|
||||
last_message.id,
|
||||
chunk,
|
||||
chunk.content,
|
||||
));
|
||||
last_message.id
|
||||
} else {
|
||||
// If we don't have an Assistant message yet, assume this chunk marks the beginning
|
||||
// of a new Assistant response.
|
||||
//
|
||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(Role::Assistant, chunk, cx);
|
||||
thread.insert_message(Role::Assistant, chunk.content, cx)
|
||||
};
|
||||
|
||||
if script_id.is_none() && script_tag_parser.found_script() {
|
||||
let id = thread
|
||||
.script_session
|
||||
.update(cx, |session, _cx| session.new_script());
|
||||
thread.scripts_by_assistant_message.insert(message_id, id);
|
||||
|
||||
script_id = Some(id);
|
||||
}
|
||||
|
||||
if let (Some(script_source), Some(script_id)) =
|
||||
(chunk.script_source, script_id)
|
||||
{
|
||||
// TODO: move buffer to script and run as it streams
|
||||
thread
|
||||
.script_session
|
||||
.update(cx, |this, cx| {
|
||||
this.run_script(script_id, script_source, cx)
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -661,6 +761,7 @@ pub enum ThreadEvent {
|
|||
#[allow(unused)]
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
},
|
||||
ScriptFinished,
|
||||
}
|
||||
|
||||
impl EventEmitter<ThreadEvent> for Thread {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue