use crate::{
    thread::{MessageId, PromptId, ThreadId},
    thread_store::SerializedMessage,
};
use agent_settings::CompletionMode;
use anyhow::Result;
use assistant_tool::{
    AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
use gpui::{App, Entity, SharedString, Task, Window};
use icons::IconName;
use language_model::{
    ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest,
    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
    LanguageModelToolUseId, Role,
};
use project::Project;
use std::sync::Arc;
use util::truncate_lines_to_byte_limit;

/// A tool use as presented to the UI, with its resolved status, icon, and label.
#[derive(Debug)]
pub struct ToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub ui_text: SharedString,
    pub status: ToolUseStatus,
    pub input: serde_json::Value,
    pub icon: icons::IconName,
    pub needs_confirmation: bool,
}

/// Tracks tool uses, their pending state, and their results across a thread's assistant messages.
pub struct ToolUseState {
    tools: Entity<ToolWorkingSet>,
    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
    tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
    tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}

impl ToolUseState {
    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
        Self {
            tools,
            tool_uses_by_assistant_message: HashMap::default(),
            tool_results: HashMap::default(),
            pending_tool_uses_by_id: HashMap::default(),
            tool_result_cards: HashMap::default(),
            tool_use_metadata_by_id: HashMap::default(),
        }
    }

    /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
    ///
    /// If `window` is `None` (e.g., when in headless mode or when running evals),
    /// tool cards won't be deserialized.
    pub fn from_serialized_messages(
        tools: Entity<ToolWorkingSet>,
        messages: &[SerializedMessage],
        project: Entity<Project>,
        window: Option<&mut Window>, // None in headless mode
        cx: &mut App,
    ) -> Self {
        let mut this = Self::new(tools);
        let mut tool_names_by_id = HashMap::default();
        let mut window = window;

        for message in messages {
            match message.role {
                Role::Assistant => {
                    if !message.tool_uses.is_empty() {
                        let tool_uses = message
                            .tool_uses
                            .iter()
                            .map(|tool_use| LanguageModelToolUse {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone().into(),
                                raw_input: tool_use.input.to_string(),
                                input: tool_use.input.clone(),
                                is_input_complete: true,
                            })
                            .collect::<Vec<_>>();

                        tool_names_by_id.extend(
                            tool_uses
                                .iter()
                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
                        );

                        this.tool_uses_by_assistant_message
                            .insert(message.id, tool_uses);

                        for tool_result in &message.tool_results {
                            let tool_use_id = tool_result.tool_use_id.clone();
                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
                                continue;
                            };

                            this.tool_results.insert(
                                tool_use_id.clone(),
                                LanguageModelToolResult {
                                    tool_use_id: tool_use_id.clone(),
                                    tool_name: tool_use.clone(),
                                    is_error: tool_result.is_error,
                                    content: tool_result.content.clone(),
                                    output: tool_result.output.clone(),
                                },
                            );

                            if let Some(window) = &mut window
                                && let Some(tool) = this.tools.read(cx).tool(tool_use, cx)
                                && let Some(output) = tool_result.output.clone()
                                && let Some(card) =
                                    tool.deserialize_card(output, project.clone(), window, cx)
                            {
                                this.tool_result_cards.insert(tool_use_id, card);
                            }
                        }
                    }
                }
                Role::System | Role::User => {}
            }
        }

        this
    }

    /// Cancels every pending tool use that has not already errored, recording an error result
    /// for each one and returning the canceled uses.
    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
        let mut canceled_tool_uses = Vec::new();
        self.pending_tool_uses_by_id
            .retain(|tool_use_id, tool_use| {
                // Keep errored tool uses so their failure remains visible.
                if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
                    return true;
                }

                let content = "Tool canceled by user".into();
                self.tool_results.insert(
                    tool_use_id.clone(),
                    LanguageModelToolResult {
                        tool_use_id: tool_use_id.clone(),
                        tool_name: tool_use.name.clone(),
                        content,
                        output: None,
                        is_error: true,
                    },
                );
                canceled_tool_uses.push(tool_use.clone());
                false
            });
        canceled_tool_uses
    }

    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
        self.pending_tool_uses_by_id.values().collect()
    }

    /// Returns the tool uses attached to the given assistant message, with their
    /// status, icon, and label resolved for display.
    pub fn tool_uses_for_message(
        &self,
        id: MessageId,
        project: &Entity<Project>,
        cx: &App,
    ) -> Vec<ToolUse> {
        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
            return Vec::new();
        };

        let mut tool_uses = Vec::new();

        for tool_use in tool_uses_for_message.iter() {
            let tool_result = self.tool_results.get(&tool_use.id);

            let status = (|| {
                if let Some(tool_result) = tool_result {
                    let content = tool_result
                        .content
                        .to_str()
                        .map(|str| str.to_owned().into())
                        .unwrap_or_default();

                    return if tool_result.is_error {
                        ToolUseStatus::Error(content)
                    } else {
                        ToolUseStatus::Finished(content)
                    };
                }

                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
                    match pending_tool_use.status {
                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
                        PendingToolUseStatus::NeedsConfirmation { .. } => {
                            ToolUseStatus::NeedsConfirmation
                        }
                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
                        PendingToolUseStatus::Error(ref err) => {
                            ToolUseStatus::Error(err.clone().into())
                        }
                        PendingToolUseStatus::InputStillStreaming => {
                            ToolUseStatus::InputStillStreaming
                        }
                    }
                } else {
                    ToolUseStatus::Pending
                }
            })();

            let (icon, needs_confirmation) =
                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                    (
                        tool.icon(),
                        tool.needs_confirmation(&tool_use.input, project, cx),
                    )
                } else {
                    (IconName::Cog, false)
                };

            tool_uses.push(ToolUse {
                id: tool_use.id.clone(),
                name: tool_use.name.clone().into(),
                ui_text: self.tool_ui_label(
                    &tool_use.name,
                    &tool_use.input,
                    tool_use.is_input_complete,
                    cx,
                ),
                input: tool_use.input.clone(),
                status,
                icon,
                needs_confirmation,
            })
        }

        tool_uses
    }

    pub fn tool_ui_label(
        &self,
        tool_name: &str,
        input: &serde_json::Value,
        is_input_complete: bool,
        cx: &App,
    ) -> SharedString {
        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
            if is_input_complete {
                tool.ui_text(input).into()
            } else {
                tool.still_streaming_ui_text(input).into()
            }
        } else {
            format!("Unknown tool {tool_name:?}").into()
        }
    }

    pub fn tool_results_for_message(
        &self,
        assistant_message_id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        let Some(tool_uses) = self
            .tool_uses_by_assistant_message
            .get(&assistant_message_id)
        else {
            return Vec::new();
        };

        tool_uses
            .iter()
            .filter_map(|tool_use| self.tool_results.get(&tool_use.id))
            .collect()
    }

    pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
        self.tool_uses_by_assistant_message
            .get(&assistant_message_id)
            .is_some_and(|results| !results.is_empty())
    }

    pub fn tool_result(
        &self,
        tool_use_id: &LanguageModelToolUseId,
    ) -> Option<&LanguageModelToolResult> {
        self.tool_results.get(tool_use_id)
    }

    pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
        self.tool_result_cards.get(tool_use_id)
    }

    pub fn insert_tool_result_card(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        card: AnyToolCard,
    ) {
        self.tool_result_cards.insert(tool_use_id, card);
    }

    /// Records a tool use requested by the given assistant message and returns its UI label.
    pub fn request_tool_use(
        &mut self,
        assistant_message_id: MessageId,
        tool_use: LanguageModelToolUse,
        metadata: ToolUseMetadata,
        cx: &App,
    ) -> Arc<str> {
        let tool_uses = self
            .tool_uses_by_assistant_message
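            // The same tool use ID can arrive repeatedly while its input streams in, so an
            // existing entry is updated in place below rather than duplicated.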
            .entry(assistant_message_id)
            .or_default();

        let mut existing_tool_use_found = false;

        for existing_tool_use in tool_uses.iter_mut() {
            if existing_tool_use.id == tool_use.id {
                *existing_tool_use = tool_use.clone();
                existing_tool_use_found = true;
            }
        }

        if !existing_tool_use_found {
            tool_uses.push(tool_use.clone());
        }

        let status = if tool_use.is_input_complete {
            self.tool_use_metadata_by_id
                .insert(tool_use.id.clone(), metadata);

            PendingToolUseStatus::Idle
        } else {
            PendingToolUseStatus::InputStillStreaming
        };

        let ui_text: Arc<str> = self
            .tool_ui_label(
                &tool_use.name,
                &tool_use.input,
                tool_use.is_input_complete,
                cx,
            )
            .into();

        let may_perform_edits = self
            .tools
            .read(cx)
            .tool(&tool_use.name, cx)
            .is_some_and(|tool| tool.may_perform_edits());

        self.pending_tool_uses_by_id.insert(
            tool_use.id.clone(),
            PendingToolUse {
                assistant_message_id,
                id: tool_use.id,
                name: tool_use.name.clone(),
                ui_text: ui_text.clone(),
                input: tool_use.input,
                may_perform_edits,
                status,
            },
        );

        ui_text
    }

    /// Marks a pending tool use as running and keeps its task alive.
    pub fn run_pending_tool(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: SharedString,
        task: Task<()>,
    ) {
        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
            tool_use.ui_text = ui_text.into();
            tool_use.status = PendingToolUseStatus::Running {
                _task: task.shared(),
            };
        }
    }

    /// Marks a pending tool use as awaiting user confirmation before it can run.
    pub fn confirm_tool_use(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        ui_text: impl Into<Arc<str>>,
        input: serde_json::Value,
        request: Arc<LanguageModelRequest>,
        tool: Arc<dyn Tool>,
    ) {
        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
            let ui_text = ui_text.into();
            tool_use.ui_text = ui_text.clone();
            let confirmation = Confirmation {
                tool_use_id,
                input,
                request,
                tool,
                ui_text,
            };
            tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
        }
    }

    /// Records a tool's output (or error), truncating oversized results to fit the model's
    /// token budget, and returns the pending tool use it resolves, if any.
    pub fn insert_tool_output(
        &mut self,
        tool_use_id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        output: Result<ToolResultOutput>,
        configured_model: Option<&ConfiguredModel>,
        completion_mode: CompletionMode,
    ) -> Option<PendingToolUse> {
        let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);

        telemetry::event!(
            "Agent Tool Finished",
            model = metadata
                .as_ref()
                .map(|metadata| metadata.model.telemetry_id()),
            model_provider = metadata
                .as_ref()
                .map(|metadata| metadata.model.provider_id().to_string()),
            thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
            prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
            tool_name,
            success = output.is_ok()
        );

        match output {
            Ok(output) => {
                let tool_result = output.content;
                const BYTES_PER_TOKEN_ESTIMATE: usize = 3;

                let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);

                // Protect from overly large output.
                let tool_output_limit = configured_model
                    .map(|model| {
                        model.model.max_token_count_for_mode(completion_mode.into()) as usize
                            * BYTES_PER_TOKEN_ESTIMATE
                    })
                    .unwrap_or(usize::MAX);

                let content = match tool_result {
                    ToolResultContent::Text(text) => {
                        let text = if text.len() < tool_output_limit {
                            text
                        } else {
                            let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
                            format!(
                                "Tool result too long. The first {} bytes:\n\n{}",
                                truncated.len(),
                                truncated
                            )
                        };

                        LanguageModelToolResultContent::Text(text.into())
                    }
                    ToolResultContent::Image(language_model_image) => {
                        if language_model_image.estimate_tokens() < tool_output_limit {
                            LanguageModelToolResultContent::Image(language_model_image)
                        } else {
                            self.tool_results.insert(
                                tool_use_id.clone(),
                                LanguageModelToolResult {
                                    tool_use_id: tool_use_id.clone(),
                                    tool_name,
                                    content:
                                        "Tool responded with an image that would exceed the remaining tokens"
                                            .into(),
                                    is_error: true,
                                    output: None,
                                },
                            );
                            return old_use;
                        }
                    }
                };

                self.tool_results.insert(
                    tool_use_id.clone(),
                    LanguageModelToolResult {
                        tool_use_id: tool_use_id.clone(),
                        tool_name,
                        content,
                        is_error: false,
                        output: output.output,
                    },
                );

                old_use
            }
            Err(err) => {
                self.tool_results.insert(
                    tool_use_id.clone(),
                    LanguageModelToolResult {
                        tool_use_id: tool_use_id.clone(),
                        tool_name,
                        content: LanguageModelToolResultContent::Text(err.to_string().into()),
                        is_error: true,
                        output: None,
                    },
                );

                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
                }

                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
            }
        }
    }

    pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
        self.tool_uses_by_assistant_message
            .contains_key(&assistant_message_id)
    }

    pub fn tool_results(
        &self,
        assistant_message_id: MessageId,
    ) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
        self.tool_uses_by_assistant_message
            .get(&assistant_message_id)
            .into_iter()
            .flatten()
            .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
    }
}

#[derive(Debug, Clone)]
pub struct PendingToolUse {
    pub id: LanguageModelToolUseId,
    /// The ID of the Assistant message in which the tool use was requested.
    #[allow(unused)]
    pub assistant_message_id: MessageId,
    pub name: Arc<str>,
    pub ui_text: Arc<str>,
    pub input: serde_json::Value,
    pub status: PendingToolUseStatus,
    pub may_perform_edits: bool,
}

#[derive(Debug, Clone)]
pub struct Confirmation {
    pub tool_use_id: LanguageModelToolUseId,
    pub input: serde_json::Value,
    pub ui_text: Arc<str>,
    pub request: Arc<LanguageModelRequest>,
    pub tool: Arc<dyn Tool>,
}

#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
    InputStillStreaming,
    Idle,
    NeedsConfirmation(Arc<Confirmation>),
    Running { _task: Shared<Task<()>> },
    Error(#[allow(unused)] Arc<str>),
}

impl PendingToolUseStatus {
    pub fn is_idle(&self) -> bool {
        matches!(self, PendingToolUseStatus::Idle)
    }

    pub fn is_error(&self) -> bool {
        matches!(self, PendingToolUseStatus::Error(_))
    }

    pub fn needs_confirmation(&self) -> bool {
        matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
    }
}

#[derive(Clone)]
pub struct ToolUseMetadata {
    pub model: Arc<dyn LanguageModel>,
    pub thread_id: ThreadId,
    pub prompt_id: PromptId,
}
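
// Illustrative sketch (not part of the original module): one plausible lifecycle of a single
// tool call as seen from a hypothetical caller, such as the thread that owns this state. It
// assumes `state: &mut ToolUseState`, a streamed `tool_use: LanguageModelToolUse`, its
// `metadata: ToolUseMetadata`, a `task: Task<()>` that runs the tool, and the remaining
// arguments already in scope with the types used above.
//
//     // 1. Record the tool use for its assistant message. While `is_input_complete` is
//     //    false, the pending entry stays in `InputStillStreaming`.
//     let ui_text = state.request_tool_use(message_id, tool_use.clone(), metadata, cx);
//
//     // 2. Either ask for confirmation first (`confirm_tool_use`) or start the tool
//     //    directly; `Running` keeps the task alive through its shared handle.
//     state.run_pending_tool(tool_use.id.clone(), ui_text.to_string().into(), task);
//
//     // 3. Store the tool's output or error. Oversized text is truncated to the model's
//     //    byte budget, and the resolved `PendingToolUse`, if any, is returned.
//     state.insert_tool_output(
//         tool_use.id,
//         tool_use.name,
//         output,
//         configured_model.as_ref(),
//         completion_mode,
//     );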