From 01525f17fa522282d94fa4cc2eb9555e74647a7f Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 4 Sep 2024 14:32:20 -0400 Subject: [PATCH] assistant: Add basic tool invocation (#17368) This PR adds the initial groundwork for invoking tools in response to tool uses from the model. Tool uses are run when the model responds with a `stop_reason` of `tool_use`. Currently the tool results are just inserted as text into the user message. We'll want to include these as `tool_result` content on the message, but Claude seems to understand it regardless. Release Notes: - N/A --- crates/assistant/src/assistant_panel.rs | 22 +++++++ crates/assistant/src/context.rs | 75 +++++++++++++++++++--- crates/assistant_tool/src/tool_registry.rs | 5 ++ 3 files changed, 92 insertions(+), 10 deletions(-) diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 6b3a230098..3c8eb1a6d9 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -20,6 +20,7 @@ use crate::{ }; use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; +use assistant_tool::ToolRegistry; use client::{proto, Client, Status}; use collections::{BTreeSet, HashMap, HashSet}; use editor::{ @@ -2091,6 +2092,27 @@ impl ContextEditor { } } } + ContextEvent::UsePendingTools => { + let pending_tool_uses = self + .context + .read(cx) + .pending_tool_uses() + .into_iter() + .filter(|tool_use| tool_use.status.is_idle()) + .cloned() + .collect::>(); + + for tool_use in pending_tool_uses { + let tool_registry = ToolRegistry::global(cx); + if let Some(tool) = tool_registry.tool(&tool_use.name) { + let task = tool.run(tool_use.input, self.workspace.clone(), cx); + + self.context.update(cx, |context, cx| { + context.insert_tool_output(tool_use.id.clone(), task, cx); + }); + } + } + } ContextEvent::Operation(_) => {} ContextEvent::ShowAssistError(error_message) => { self.error_message = Some(error_message.clone()); diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index ec2248f19f..b4842ccc8f 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -29,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P use language_model::{ LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelRequestTool, MessageContent, Role, + LanguageModelRequestTool, MessageContent, Role, StopReason, }; use open_ai::Model as OpenAiModel; use paths::{context_images_dir, contexts_dir}; @@ -306,6 +306,7 @@ pub enum ContextEvent { run_commands_in_output: bool, expand_result: bool, }, + UsePendingTools, Operation(ContextOperation), } @@ -416,6 +417,7 @@ impl Message { range_start = *image_offset; } + if range_start != self.offset_range.end { if let Some(text) = Self::collect_text_content(buffer, range_start..self.offset_range.end) @@ -492,7 +494,7 @@ pub struct Context { edits_since_last_parse: language::Subscription, finished_slash_commands: HashSet, slash_command_output_sections: Vec>, - pending_tool_uses_by_id: HashMap, + pending_tool_uses_by_id: HashMap, PendingToolUse>, message_anchors: Vec, images: HashMap, Shared>>)>, image_anchors: Vec, @@ -1012,7 +1014,7 @@ impl Context { self.pending_tool_uses_by_id.values().collect() } - pub fn get_tool_use_by_id(&self, id: &String) -> Option<&PendingToolUse> { + pub fn get_tool_use_by_id(&self, id: &Arc) -> Option<&PendingToolUse> { self.pending_tool_uses_by_id.get(id) } @@ -1919,6 +1921,45 @@ impl Context { } } + pub fn insert_tool_output( + &mut self, + tool_id: Arc, + output: Task>, + cx: &mut ModelContext, + ) { + let insert_output_task = cx.spawn(|this, mut cx| { + let tool_id = tool_id.clone(); + async move { + let output = output.await; + this.update(&mut cx, |this, cx| match output { + Ok(mut output) => { + if !output.ends_with('\n') { + output.push('\n'); + } + + this.buffer.update(cx, |buffer, cx| { + let buffer_end = buffer.len().to_offset(buffer); + + buffer.edit([(buffer_end..buffer_end, output)], None, cx); + }); + } + Err(err) => { + if let Some(tool_use) = this.pending_tool_uses_by_id.get_mut(&tool_id) { + tool_use.status = PendingToolUseStatus::Error(err.to_string()); + } + } + }) + .ok(); + } + }); + + if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_id) { + tool_use.status = PendingToolUseStatus::Running { + _task: insert_output_task.shared(), + }; + } + } + pub fn completion_provider_changed(&mut self, cx: &mut ModelContext) { self.count_remaining_tokens(cx); } @@ -1990,7 +2031,7 @@ impl Context { .message_anchors .iter() .position(|message| message.id == assistant_message_id)?; - this.buffer.update(cx, |buffer, cx| { + let event_to_emit = this.buffer.update(cx, |buffer, cx| { let message_old_end_offset = this.message_anchors[message_ix + 1..] .iter() .find(|message| message.start.is_valid(buffer)) @@ -2000,9 +2041,11 @@ impl Context { match event { LanguageModelCompletionEvent::Stop(reason) => match reason { - language_model::StopReason::ToolUse => {} - language_model::StopReason::EndTurn => {} - language_model::StopReason::MaxTokens => {} + StopReason::ToolUse => { + return Some(ContextEvent::UsePendingTools); + } + StopReason::EndTurn => {} + StopReason::MaxTokens => {} }, LanguageModelCompletionEvent::Text(chunk) => { buffer.edit( @@ -2041,10 +2084,11 @@ impl Context { let source_range = buffer.anchor_after(start_ix) ..buffer.anchor_after(end_ix); + let tool_use_id: Arc = tool_use.id.into(); this.pending_tool_uses_by_id.insert( - tool_use.id.clone(), + tool_use_id.clone(), PendingToolUse { - id: tool_use.id, + id: tool_use_id, name: tool_use.name, input: tool_use.input, status: PendingToolUseStatus::Idle, @@ -2053,9 +2097,14 @@ impl Context { ); } } + + None }); cx.emit(ContextEvent::StreamedCompletion); + if let Some(event) = event_to_emit { + cx.emit(event); + } Some(()) })?; @@ -2821,7 +2870,7 @@ impl FeatureFlag for ToolUseFeatureFlag { #[derive(Debug, Clone)] pub struct PendingToolUse { - pub id: String, + pub id: Arc, pub name: String, pub input: serde_json::Value, pub status: PendingToolUseStatus, @@ -2835,6 +2884,12 @@ pub enum PendingToolUseStatus { Error(String), } +impl PendingToolUseStatus { + pub fn is_idle(&self) -> bool { + matches!(self, PendingToolUseStatus::Idle) + } +} + #[derive(Serialize, Deserialize)] pub struct SavedMessage { pub id: MessageId, diff --git a/crates/assistant_tool/src/tool_registry.rs b/crates/assistant_tool/src/tool_registry.rs index d7e1fde4c3..d225b6a9a3 100644 --- a/crates/assistant_tool/src/tool_registry.rs +++ b/crates/assistant_tool/src/tool_registry.rs @@ -66,4 +66,9 @@ impl ToolRegistry { pub fn tools(&self) -> Vec> { self.state.read().tools.values().cloned().collect() } + + /// Returns the [`Tool`] with the given name. + pub fn tool(&self, name: &str) -> Option> { + self.state.read().tools.get(name).cloned() + } }