From f059b6a24bac5a7bd65ea54a8f48b919d928d75e Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 25 Nov 2024 19:44:34 -0500 Subject: [PATCH] assistant2: Add support for using tools (#21190) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds rudimentary support for using tools to `assistant2`. There are currently no visual affordances for tool use. This is gated behind the `assistant-tool-use` feature flag. Screenshot 2024-11-25 at 7 21 31 PM Release Notes: - N/A --- Cargo.lock | 3 + crates/assistant/src/context.rs | 12 +- crates/assistant2/Cargo.toml | 3 + crates/assistant2/src/assistant_panel.rs | 61 ++++++- crates/assistant2/src/message_editor.rs | 19 ++- crates/assistant2/src/thread.rs | 190 ++++++++++++++++++++-- crates/assistant_tools/src/now_tool.rs | 2 +- crates/feature_flags/src/feature_flags.rs | 10 ++ 8 files changed, 263 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7152bf8d08..5a18caa3d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -455,6 +455,8 @@ name = "assistant2" version = "0.1.0" dependencies = [ "anyhow", + "assistant_tool", + "collections", "command_palette_hooks", "editor", "feature_flags", @@ -463,6 +465,7 @@ dependencies = [ "language_model", "language_model_selector", "proto", + "serde_json", "settings", "smol", "theme", diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 2a7985a8c7..ac032accc3 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -15,7 +15,7 @@ use assistant_tool::ToolWorkingSet; use client::{self, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::{HashMap, HashSet}; -use feature_flags::{FeatureFlag, FeatureFlagAppExt}; +use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag}; use fs::{Fs, RemoveOptions}; use futures::{future::Shared, FutureExt, StreamExt}; use gpui::{ @@ -3201,16 +3201,6 @@ pub enum PendingSlashCommandStatus { Error(String), } -pub(crate) struct ToolUseFeatureFlag; - -impl FeatureFlag for ToolUseFeatureFlag { - const NAME: &'static str = "assistant-tool-use"; - - fn enabled_for_staff() -> bool { - false - } -} - #[derive(Debug, Clone)] pub struct PendingToolUse { pub id: Arc, diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 02cbdadb62..60c168079d 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -14,6 +14,8 @@ doctest = false [dependencies] anyhow.workspace = true +assistant_tool.workspace = true +collections.workspace = true command_palette_hooks.workspace = true editor.workspace = true feature_flags.workspace = true @@ -23,6 +25,7 @@ language_model.workspace = true language_model_selector.workspace = true proto.workspace = true settings.workspace = true +serde_json.workspace = true smol.workspace = true theme.workspace = true ui.workspace = true diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index c33e9d520d..b05a39a1cd 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -1,4 +1,7 @@ +use std::sync::Arc; + use anyhow::Result; +use assistant_tool::ToolWorkingSet; use gpui::{ prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext, @@ -10,7 +13,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::Workspace; use crate::message_editor::MessageEditor; -use crate::thread::Thread; +use crate::thread::{Thread, ThreadEvent}; use crate::{NewThread, ToggleFocus, ToggleModelSelector}; pub fn init(cx: &mut AppContext) { @@ -25,8 +28,10 @@ pub fn init(cx: &mut AppContext) { } pub struct AssistantPanel { + workspace: WeakView, thread: Model, message_editor: View, + tools: Arc, _subscriptions: Vec, } @@ -36,26 +41,36 @@ impl AssistantPanel { cx: AsyncWindowContext, ) -> Task>> { cx.spawn(|mut cx| async move { + let tools = Arc::new(ToolWorkingSet::default()); workspace.update(&mut cx, |workspace, cx| { - cx.new_view(|cx| Self::new(workspace, cx)) + cx.new_view(|cx| Self::new(workspace, tools, cx)) }) }) } - fn new(_workspace: &Workspace, cx: &mut ViewContext) -> Self { - let thread = cx.new_model(Thread::new); - let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())]; + fn new(workspace: &Workspace, tools: Arc, cx: &mut ViewContext) -> Self { + let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx)); + let subscriptions = vec![ + cx.observe(&thread, |_, _, cx| cx.notify()), + cx.subscribe(&thread, Self::handle_thread_event), + ]; Self { + workspace: workspace.weak_handle(), thread: thread.clone(), message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)), + tools, _subscriptions: subscriptions, } } fn new_thread(&mut self, cx: &mut ViewContext) { - let thread = cx.new_model(Thread::new); - let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())]; + let tools = self.thread.read(cx).tools().clone(); + let thread = cx.new_model(|cx| Thread::new(tools, cx)); + let subscriptions = vec![ + cx.observe(&thread, |_, _, cx| cx.notify()), + cx.subscribe(&thread, Self::handle_thread_event), + ]; self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx)); self.thread = thread; @@ -63,6 +78,38 @@ impl AssistantPanel { self.message_editor.focus_handle(cx).focus(cx); } + + fn handle_thread_event( + &mut self, + _: Model, + event: &ThreadEvent, + cx: &mut ViewContext, + ) { + match event { + ThreadEvent::StreamedCompletion => {} + ThreadEvent::UsePendingTools => { + let pending_tool_uses = self + .thread + .read(cx) + .pending_tool_uses() + .into_iter() + .filter(|tool_use| tool_use.status.is_idle()) + .cloned() + .collect::>(); + + for tool_use in pending_tool_uses { + if let Some(tool) = self.tools.tool(&tool_use.name, cx) { + let task = tool.run(tool_use.input, self.workspace.clone(), cx); + + self.thread.update(cx, |thread, cx| { + thread.insert_tool_output(tool_use.id.clone(), task, cx); + }); + } + } + } + ThreadEvent::ToolFinished { .. } => {} + } + } } impl FocusableView for AssistantPanel { diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index f0a8e260bc..c42d66a4d7 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -1,6 +1,7 @@ use editor::{Editor, EditorElement, EditorStyle}; +use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag}; use gpui::{AppContext, FocusableView, Model, TextStyle, View}; -use language_model::LanguageModelRegistry; +use language_model::{LanguageModelRegistry, LanguageModelRequestTool}; use settings::Settings; use theme::ThemeSettings; use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding}; @@ -55,7 +56,21 @@ impl MessageEditor { self.thread.update(cx, |thread, cx| { thread.insert_user_message(user_message); - let request = thread.to_completion_request(request_kind, cx); + let mut request = thread.to_completion_request(request_kind, cx); + + if cx.has_flag::() { + request.tools = thread + .tools() + .tools(cx) + .into_iter() + .map(|tool| LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema: tool.input_schema(), + }) + .collect(); + } + thread.stream_completion(request, model, cx) }); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index c1df6c76d3..067e82a602 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,12 +1,16 @@ use std::sync::Arc; -use futures::StreamExt as _; +use anyhow::Result; +use assistant_tool::ToolWorkingSet; +use collections::HashMap; +use futures::future::Shared; +use futures::{FutureExt as _, StreamExt as _}; use gpui::{AppContext, EventEmitter, ModelContext, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - MessageContent, Role, StopReason, + LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason, }; -use util::{post_inc, ResultExt as _}; +use util::post_inc; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -14,14 +18,12 @@ pub enum RequestKind { } /// A message in a [`Thread`]. +#[derive(Debug)] pub struct Message { pub role: Role, pub text: String, -} - -struct PendingCompletion { - id: usize, - _task: Task<()>, + pub tool_uses: Vec, + pub tool_results: Vec, } /// A thread of conversation with the LLM. @@ -29,14 +31,20 @@ pub struct Thread { messages: Vec, completion_count: usize, pending_completions: Vec, + tools: Arc, + pending_tool_uses_by_id: HashMap, PendingToolUse>, + completed_tool_uses_by_id: HashMap, String>, } impl Thread { - pub fn new(_cx: &mut ModelContext) -> Self { + pub fn new(tools: Arc, _cx: &mut ModelContext) -> Self { Self { + tools, messages: Vec::new(), completion_count: 0, pending_completions: Vec::new(), + pending_tool_uses_by_id: HashMap::default(), + completed_tool_uses_by_id: HashMap::default(), } } @@ -44,11 +52,31 @@ impl Thread { self.messages.iter() } + pub fn tools(&self) -> &Arc { + &self.tools + } + + pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { + self.pending_tool_uses_by_id.values().collect() + } + pub fn insert_user_message(&mut self, text: impl Into) { - self.messages.push(Message { + let mut message = Message { role: Role::User, text: text.into(), - }); + tool_uses: Vec::new(), + tool_results: Vec::new(), + }; + + for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() { + message.tool_results.push(LanguageModelToolResult { + tool_use_id: tool_use_id.to_string(), + content: tool_output, + is_error: false, + }); + } + + self.messages.push(message); } pub fn to_completion_request( @@ -70,9 +98,23 @@ impl Thread { cache: false, }; - request_message - .content - .push(MessageContent::Text(message.text.clone())); + for tool_result in &message.tool_results { + request_message + .content + .push(MessageContent::ToolResult(tool_result.clone())); + } + + if !message.text.is_empty() { + request_message + .content + .push(MessageContent::Text(message.text.clone())); + } + + for tool_use in &message.tool_uses { + request_message + .content + .push(MessageContent::ToolUse(tool_use.clone())); + } request.messages.push(request_message); } @@ -103,6 +145,8 @@ impl Thread { thread.messages.push(Message { role: Role::Assistant, text: String::new(), + tool_uses: Vec::new(), + tool_results: Vec::new(), }); } LanguageModelCompletionEvent::Stop(reason) => { @@ -115,7 +159,24 @@ impl Thread { } } } - LanguageModelCompletionEvent::ToolUse(_tool_use) => {} + LanguageModelCompletionEvent::ToolUse(tool_use) => { + if let Some(last_message) = thread.messages.last_mut() { + if last_message.role == Role::Assistant { + last_message.tool_uses.push(tool_use.clone()); + } + } + + let tool_use_id: Arc = tool_use.id.into(); + thread.pending_tool_uses_by_id.insert( + tool_use_id.clone(), + PendingToolUse { + id: tool_use_id, + name: tool_use.name, + input: tool_use.input, + status: PendingToolUseStatus::Idle, + }, + ); + } } cx.emit(ThreadEvent::StreamedCompletion); @@ -135,7 +196,35 @@ impl Thread { }; let result = stream_completion.await; - let _ = result.log_err(); + + thread + .update(&mut cx, |_thread, cx| { + let error_message = if let Some(error) = result.as_ref().err() { + let error_message = error + .chain() + .map(|err| err.to_string()) + .collect::>() + .join("\n"); + Some(error_message) + } else { + None + }; + + if let Some(error_message) = error_message { + eprintln!("Completion failed: {error_message:?}"); + } + + if let Ok(stop_reason) = result { + match stop_reason { + StopReason::ToolUse => { + cx.emit(ThreadEvent::UsePendingTools); + } + StopReason::EndTurn => {} + StopReason::MaxTokens => {} + } + } + }) + .ok(); }); self.pending_completions.push(PendingCompletion { @@ -143,11 +232,80 @@ impl Thread { _task: task, }); } + + pub fn insert_tool_output( + &mut self, + tool_use_id: Arc, + output: Task>, + cx: &mut ModelContext, + ) { + let insert_output_task = cx.spawn(|thread, mut cx| { + let tool_use_id = tool_use_id.clone(); + async move { + let output = output.await; + thread + .update(&mut cx, |thread, cx| match output { + Ok(output) => { + thread + .completed_tool_uses_by_id + .insert(tool_use_id.clone(), output); + + cx.emit(ThreadEvent::ToolFinished { tool_use_id }); + } + Err(err) => { + if let Some(tool_use) = + thread.pending_tool_uses_by_id.get_mut(&tool_use_id) + { + tool_use.status = PendingToolUseStatus::Error(err.to_string()); + } + } + }) + .ok(); + } + }); + + if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { + tool_use.status = PendingToolUseStatus::Running { + _task: insert_output_task.shared(), + }; + } + } } #[derive(Debug, Clone)] pub enum ThreadEvent { StreamedCompletion, + UsePendingTools, + ToolFinished { + #[allow(unused)] + tool_use_id: Arc, + }, } impl EventEmitter for Thread {} + +struct PendingCompletion { + id: usize, + _task: Task<()>, +} + +#[derive(Debug, Clone)] +pub struct PendingToolUse { + pub id: Arc, + pub name: String, + pub input: serde_json::Value, + pub status: PendingToolUseStatus, +} + +#[derive(Debug, Clone)] +pub enum PendingToolUseStatus { + Idle, + Running { _task: Shared> }, + Error(#[allow(unused)] String), +} + +impl PendingToolUseStatus { + pub fn is_idle(&self) -> bool { + matches!(self, PendingToolUseStatus::Idle) + } +} diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index 99034321b1..707f2be2bd 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -30,7 +30,7 @@ impl Tool for NowTool { } fn description(&self) -> String { - "Returns the current datetime in RFC 3339 format.".into() + "Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into() } fn input_schema(&self) -> serde_json::Value { diff --git a/crates/feature_flags/src/feature_flags.rs b/crates/feature_flags/src/feature_flags.rs index 416971b36e..48e3cc95b2 100644 --- a/crates/feature_flags/src/feature_flags.rs +++ b/crates/feature_flags/src/feature_flags.rs @@ -49,6 +49,16 @@ impl FeatureFlag for Assistant2FeatureFlag { } } +pub struct ToolUseFeatureFlag; + +impl FeatureFlag for ToolUseFeatureFlag { + const NAME: &'static str = "assistant-tool-use"; + + fn enabled_for_staff() -> bool { + false + } +} + pub struct Remoting {} impl FeatureFlag for Remoting { const NAME: &'static str = "remoting";