From fc52b43159435f02b6f3108fa54e0cd0e483bd61 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 28 Feb 2025 12:04:20 -0500 Subject: [PATCH] assistant2: Factor out tool use into its own module (#25819) This PR factors out the concerns related to tool use out of `Thread` and into their own module. Release Notes: - N/A --- crates/assistant2/src/active_thread.rs | 5 +- crates/assistant2/src/assistant.rs | 1 + crates/assistant2/src/thread.rs | 228 ++++--------------------- crates/assistant2/src/tool_use.rs | 221 ++++++++++++++++++++++++ 4 files changed, 258 insertions(+), 197 deletions(-) create mode 100644 crates/assistant2/src/tool_use.rs diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index ffa650d230..9963f8913b 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -15,10 +15,9 @@ use theme::ThemeSettings; use ui::{prelude::*, Disclosure}; use workspace::Workspace; -use crate::thread::{ - MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus, -}; +use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent}; use crate::thread_store::ThreadStore; +use crate::tool_use::{ToolUse, ToolUseStatus}; use crate::ui::ContextPill; pub struct ActiveThread { diff --git a/crates/assistant2/src/assistant.rs b/crates/assistant2/src/assistant.rs index 7a3a82c9f7..2e3b3a3be1 100644 --- a/crates/assistant2/src/assistant.rs +++ b/crates/assistant2/src/assistant.rs @@ -16,6 +16,7 @@ mod terminal_inline_assistant; mod thread; mod thread_history; mod thread_store; +mod tool_use; mod ui; use std::sync::Arc; diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 50ea13e670..238f101401 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -4,14 +4,12 @@ use anyhow::Result; use assistant_tool::ToolWorkingSet; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap, HashSet}; -use futures::future::Shared; -use futures::{FutureExt as _, StreamExt as _}; +use futures::StreamExt as _; use gpui::{App, Context, EventEmitter, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolUse, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, - PaymentRequiredError, Role, StopReason, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId, + MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason, }; use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; @@ -19,6 +17,7 @@ use uuid::Uuid; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; use crate::thread_store::SavedThread; +use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -43,7 +42,7 @@ impl std::fmt::Display for ThreadId { } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -pub struct MessageId(usize); +pub struct MessageId(pub(crate) usize); impl MessageId { fn post_inc(&mut self) -> Self { @@ -59,22 +58,6 @@ pub struct Message { pub text: String, } -#[derive(Debug)] -pub struct ToolUse { - pub id: LanguageModelToolUseId, - pub name: SharedString, - pub status: ToolUseStatus, - pub input: serde_json::Value, -} - -#[derive(Debug, Clone)] -pub enum ToolUseStatus { - Pending, - Running, - Finished(SharedString), - Error(SharedString), -} - /// A thread of conversation with the LLM. pub struct Thread { id: ThreadId, @@ -88,10 +71,7 @@ pub struct Thread { completion_count: usize, pending_completions: Vec, tools: Arc, - tool_uses_by_assistant_message: HashMap>, - tool_uses_by_user_message: HashMap>, - tool_results: HashMap, - pending_tool_uses_by_id: HashMap, + tool_use: ToolUseState, } impl Thread { @@ -108,10 +88,7 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), tools, - tool_uses_by_assistant_message: HashMap::default(), - tool_uses_by_user_message: HashMap::default(), - tool_results: HashMap::default(), - pending_tool_uses_by_id: HashMap::default(), + tool_use: ToolUseState::default(), } } @@ -143,10 +120,7 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), tools, - tool_uses_by_assistant_message: HashMap::default(), - tool_uses_by_user_message: HashMap::default(), - tool_results: HashMap::default(), - pending_tool_uses_by_id: HashMap::default(), + tool_use: ToolUseState::default(), } } @@ -208,56 +182,15 @@ impl Thread { } pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { - self.pending_tool_uses_by_id.values().collect() + self.tool_use.pending_tool_uses() } pub fn tool_uses_for_message(&self, id: MessageId) -> Vec { - let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { - return Vec::new(); - }; - - let mut tool_uses = Vec::new(); - - for tool_use in tool_uses_for_message.iter() { - let tool_result = self.tool_results.get(&tool_use.id); - - let status = (|| { - if let Some(tool_result) = tool_result { - return if tool_result.is_error { - ToolUseStatus::Error(tool_result.content.clone().into()) - } else { - ToolUseStatus::Finished(tool_result.content.clone().into()) - }; - } - - if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) { - return match pending_tool_use.status { - PendingToolUseStatus::Idle => ToolUseStatus::Pending, - PendingToolUseStatus::Running { .. } => ToolUseStatus::Running, - PendingToolUseStatus::Error(ref err) => { - ToolUseStatus::Error(err.clone().into()) - } - }; - } - - ToolUseStatus::Pending - })(); - - tool_uses.push(ToolUse { - id: tool_use.id.clone(), - name: tool_use.name.clone().into(), - input: tool_use.input.clone(), - status, - }) - } - - tool_uses + self.tool_use.tool_uses_for_message(id) } pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { - self.tool_uses_by_user_message - .get(&message_id) - .map_or(false, |results| !results.is_empty()) + self.tool_use.message_has_tool_results(message_id) } pub fn insert_user_message( @@ -360,20 +293,13 @@ impl Thread { content: Vec::new(), cache: false, }; - if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) { - match request_kind { - RequestKind::Chat => { - for tool_use_id in tool_uses { - if let Some(tool_result) = self.tool_results.get(tool_use_id) { - request_message - .content - .push(MessageContent::ToolResult(tool_result.clone())); - } - } - } - RequestKind::Summarize => { - // We don't care about tool use during summarization. - } + match request_kind { + RequestKind::Chat => { + self.tool_use + .attach_tool_results(message.id, &mut request_message); + } + RequestKind::Summarize => { + // We don't care about tool use during summarization. } } @@ -383,18 +309,13 @@ impl Thread { .push(MessageContent::Text(message.text.clone())); } - if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) { - match request_kind { - RequestKind::Chat => { - for tool_use in tool_uses { - request_message - .content - .push(MessageContent::ToolUse(tool_use.clone())); - } - } - RequestKind::Summarize => { - // We don't care about tool use during summarization. - } + match request_kind { + RequestKind::Chat => { + self.tool_use + .attach_tool_uses(message.id, &mut request_message); + } + RequestKind::Summarize => { + // We don't care about tool use during summarization. } } @@ -470,32 +391,8 @@ impl Thread { .rfind(|message| message.role == Role::Assistant) { thread - .tool_uses_by_assistant_message - .entry(last_assistant_message.id) - .or_default() - .push(tool_use.clone()); - - // The tool use is being requested by the - // Assistant, so we want to attach the tool - // results to the next user message. - let next_user_message_id = - MessageId(last_assistant_message.id.0 + 1); - thread - .tool_uses_by_user_message - .entry(next_user_message_id) - .or_default() - .push(tool_use.id.clone()); - - thread.pending_tool_uses_by_id.insert( - tool_use.id.clone(), - PendingToolUse { - assistant_message_id: last_assistant_message.id, - id: tool_use.id, - name: tool_use.name, - input: tool_use.input, - status: PendingToolUseStatus::Idle, - }, - ); + .tool_use + .request_tool_use(last_assistant_message.id, tool_use); } } } @@ -624,49 +521,19 @@ impl Thread { async move { let output = output.await; thread - .update(&mut cx, |thread, cx| match output { - Ok(output) => { - thread.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - content: output.into(), - is_error: false, - }, - ); - thread.pending_tool_uses_by_id.remove(&tool_use_id); + .update(&mut cx, |thread, cx| { + thread + .tool_use + .insert_tool_output(tool_use_id.clone(), output); - cx.emit(ThreadEvent::ToolFinished { tool_use_id }); - } - Err(err) => { - thread.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - content: err.to_string().into(), - is_error: true, - }, - ); - - if let Some(tool_use) = - thread.pending_tool_uses_by_id.get_mut(&tool_use_id) - { - tool_use.status = - PendingToolUseStatus::Error(err.to_string().into()); - } - - cx.emit(ThreadEvent::ToolFinished { tool_use_id }); - } + cx.emit(ThreadEvent::ToolFinished { tool_use_id }); }) .ok(); } }); - if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { - tool_use.status = PendingToolUseStatus::Running { - _task: insert_output_task.shared(), - }; - } + self.tool_use + .run_pending_tool(tool_use_id, insert_output_task); } /// Cancels the last pending completion, if there are any pending. @@ -708,30 +575,3 @@ struct PendingCompletion { id: usize, _task: Task<()>, } - -#[derive(Debug, Clone)] -pub struct PendingToolUse { - pub id: LanguageModelToolUseId, - /// The ID of the Assistant message in which the tool use was requested. - pub assistant_message_id: MessageId, - pub name: Arc, - pub input: serde_json::Value, - pub status: PendingToolUseStatus, -} - -#[derive(Debug, Clone)] -pub enum PendingToolUseStatus { - Idle, - Running { _task: Shared> }, - Error(#[allow(unused)] Arc), -} - -impl PendingToolUseStatus { - pub fn is_idle(&self) -> bool { - matches!(self, PendingToolUseStatus::Idle) - } - - pub fn is_error(&self) -> bool { - matches!(self, PendingToolUseStatus::Error(_)) - } -} diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs new file mode 100644 index 0000000000..12b73554f9 --- /dev/null +++ b/crates/assistant2/src/tool_use.rs @@ -0,0 +1,221 @@ +use std::sync::Arc; + +use anyhow::Result; +use collections::HashMap; +use futures::future::Shared; +use futures::FutureExt as _; +use gpui::{SharedString, Task}; +use language_model::{ + LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, + LanguageModelToolUseId, MessageContent, +}; + +use crate::thread::MessageId; + +#[derive(Debug)] +pub struct ToolUse { + pub id: LanguageModelToolUseId, + pub name: SharedString, + pub status: ToolUseStatus, + pub input: serde_json::Value, +} + +#[derive(Debug, Clone)] +pub enum ToolUseStatus { + Pending, + Running, + Finished(SharedString), + Error(SharedString), +} + +#[derive(Default)] +pub struct ToolUseState { + tool_uses_by_assistant_message: HashMap>, + tool_uses_by_user_message: HashMap>, + tool_results: HashMap, + pending_tool_uses_by_id: HashMap, +} + +impl ToolUseState { + pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { + self.pending_tool_uses_by_id.values().collect() + } + + pub fn tool_uses_for_message(&self, id: MessageId) -> Vec { + let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { + return Vec::new(); + }; + + let mut tool_uses = Vec::new(); + + for tool_use in tool_uses_for_message.iter() { + let tool_result = self.tool_results.get(&tool_use.id); + + let status = (|| { + if let Some(tool_result) = tool_result { + return if tool_result.is_error { + ToolUseStatus::Error(tool_result.content.clone().into()) + } else { + ToolUseStatus::Finished(tool_result.content.clone().into()) + }; + } + + if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) { + return match pending_tool_use.status { + PendingToolUseStatus::Idle => ToolUseStatus::Pending, + PendingToolUseStatus::Running { .. } => ToolUseStatus::Running, + PendingToolUseStatus::Error(ref err) => { + ToolUseStatus::Error(err.clone().into()) + } + }; + } + + ToolUseStatus::Pending + })(); + + tool_uses.push(ToolUse { + id: tool_use.id.clone(), + name: tool_use.name.clone().into(), + input: tool_use.input.clone(), + status, + }) + } + + tool_uses + } + + pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { + self.tool_uses_by_user_message + .get(&message_id) + .map_or(false, |results| !results.is_empty()) + } + + pub fn request_tool_use( + &mut self, + assistant_message_id: MessageId, + tool_use: LanguageModelToolUse, + ) { + self.tool_uses_by_assistant_message + .entry(assistant_message_id) + .or_default() + .push(tool_use.clone()); + + // The tool use is being requested by the Assistant, so we want to + // attach the tool results to the next user message. + let next_user_message_id = MessageId(assistant_message_id.0 + 1); + self.tool_uses_by_user_message + .entry(next_user_message_id) + .or_default() + .push(tool_use.id.clone()); + + self.pending_tool_uses_by_id.insert( + tool_use.id.clone(), + PendingToolUse { + assistant_message_id, + id: tool_use.id, + name: tool_use.name, + input: tool_use.input, + status: PendingToolUseStatus::Idle, + }, + ); + } + + pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) { + if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { + tool_use.status = PendingToolUseStatus::Running { + _task: task.shared(), + }; + } + } + + pub fn insert_tool_output( + &mut self, + tool_use_id: LanguageModelToolUseId, + output: Result, + ) { + match output { + Ok(output) => { + self.tool_results.insert( + tool_use_id.clone(), + LanguageModelToolResult { + tool_use_id: tool_use_id.clone(), + content: output.into(), + is_error: false, + }, + ); + self.pending_tool_uses_by_id.remove(&tool_use_id); + } + Err(err) => { + self.tool_results.insert( + tool_use_id.clone(), + LanguageModelToolResult { + tool_use_id: tool_use_id.clone(), + content: err.to_string().into(), + is_error: true, + }, + ); + + if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { + tool_use.status = PendingToolUseStatus::Error(err.to_string().into()); + } + } + } + } + + pub fn attach_tool_uses( + &self, + message_id: MessageId, + request_message: &mut LanguageModelRequestMessage, + ) { + if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) { + for tool_use in tool_uses { + request_message + .content + .push(MessageContent::ToolUse(tool_use.clone())); + } + } + } + + pub fn attach_tool_results( + &self, + message_id: MessageId, + request_message: &mut LanguageModelRequestMessage, + ) { + if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) { + for tool_use_id in tool_uses { + if let Some(tool_result) = self.tool_results.get(tool_use_id) { + request_message + .content + .push(MessageContent::ToolResult(tool_result.clone())); + } + } + } + } +} + +#[derive(Debug, Clone)] +pub struct PendingToolUse { + pub id: LanguageModelToolUseId, + /// The ID of the Assistant message in which the tool use was requested. + pub assistant_message_id: MessageId, + pub name: Arc, + pub input: serde_json::Value, + pub status: PendingToolUseStatus, +} + +#[derive(Debug, Clone)] +pub enum PendingToolUseStatus { + Idle, + Running { _task: Shared> }, + Error(#[allow(unused)] Arc), +} + +impl PendingToolUseStatus { + pub fn is_idle(&self) -> bool { + matches!(self, PendingToolUseStatus::Idle) + } + + pub fn is_error(&self) -> bool { + matches!(self, PendingToolUseStatus::Error(_)) + } +}