diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index bb98f71949..e4fc470744 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -35,6 +35,7 @@ pub struct ActiveThread { list_state: ListState, rendered_messages_by_id: HashMap>, rendered_scripting_tool_uses: HashMap>, + rendered_tool_use_labels: HashMap>, editing_message: Option<(MessageId, EditMessageState)>, expanded_tool_uses: HashMap, last_error: Option, @@ -70,6 +71,7 @@ impl ActiveThread { messages: Vec::new(), rendered_messages_by_id: HashMap::default(), rendered_scripting_tool_uses: HashMap::default(), + rendered_tool_use_labels: HashMap::default(), expanded_tool_uses: HashMap::default(), list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { let this = cx.entity().downgrade(); @@ -86,10 +88,29 @@ impl ActiveThread { for message in thread.read(cx).messages().cloned().collect::>() { this.push_message(&message.id, message.text.clone(), window, cx); - for tool_use in thread.read(cx).scripting_tool_uses_for_message(message.id) { + for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) { + this.render_tool_use_label_markdown( + tool_use.id.clone(), + tool_use.ui_text.clone(), + window, + cx, + ); + } + + for tool_use in thread + .read(cx) + .scripting_tool_uses_for_message(message.id, cx) + { + this.render_tool_use_label_markdown( + tool_use.id.clone(), + tool_use.ui_text.clone(), + window, + cx, + ); + this.render_scripting_tool_use_markdown( tool_use.id.clone(), - tool_use.name.as_ref(), + tool_use.ui_text.as_ref(), tool_use.input.clone(), window, cx, @@ -287,6 +308,19 @@ impl ActiveThread { .insert(tool_use_id, lua_script); } + fn render_tool_use_label_markdown( + &mut self, + tool_use_id: LanguageModelToolUseId, + tool_label: impl Into, + window: &mut Window, + cx: &mut Context, + ) { + self.rendered_tool_use_labels.insert( + tool_use_id, + self.render_markdown(tool_label.into(), window, cx), + ); + } + fn handle_thread_event( &mut self, _thread: &Entity, @@ -341,9 +375,18 @@ impl ActiveThread { cx.notify(); } ThreadEvent::UsePendingTools => { - self.thread.update(cx, |thread, cx| { - thread.use_pending_tools(cx); - }); + let tool_uses = self + .thread + .update(cx, |thread, cx| thread.use_pending_tools(cx)); + + for tool_use in tool_uses { + self.render_tool_use_label_markdown( + tool_use.id, + tool_use.ui_text.clone(), + window, + cx, + ); + } } ThreadEvent::ToolFinished { pending_tool_use, @@ -352,6 +395,13 @@ impl ActiveThread { } => { let canceled = *canceled; if let Some(tool_use) = pending_tool_use { + self.render_tool_use_label_markdown( + tool_use.id.clone(), + SharedString::from(tool_use.ui_text.clone()), + window, + cx, + ); + self.render_scripting_tool_use_markdown( tool_use.id.clone(), tool_use.name.as_ref(), @@ -555,8 +605,8 @@ impl ActiveThread { // Get all the data we need from thread before we start using it in closures let checkpoint = thread.checkpoint_for_message(message_id); let context = thread.context_for_message(message_id); - let tool_uses = thread.tool_uses_for_message(message_id); - let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id); + let tool_uses = thread.tool_uses_for_message(message_id, cx); + let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id, cx); // Don't render user messages that are just there for returning tool results. if message.role == Role::User @@ -709,27 +759,25 @@ impl ActiveThread { ) .child(div().p_2().child(message_content)), ), - Role::Assistant => { - v_flex() - .id(("message-container", ix)) - .child(div().py_3().px_4().child(message_content)) - .when( - !tool_uses.is_empty() || !scripting_tool_uses.is_empty(), - |parent| { - parent.child( - v_flex() - .children( - tool_uses - .into_iter() - .map(|tool_use| self.render_tool_use(tool_use, cx)), - ) - .children(scripting_tool_uses.into_iter().map(|tool_use| { - self.render_scripting_tool_use(tool_use, cx) - })), - ) - }, - ) - } + Role::Assistant => v_flex() + .id(("message-container", ix)) + .child(div().py_3().px_4().child(message_content)) + .when( + !tool_uses.is_empty() || !scripting_tool_uses.is_empty(), + |parent| { + parent.child( + v_flex() + .children( + tool_uses + .into_iter() + .map(|tool_use| self.render_tool_use(tool_use, cx)), + ) + .children(scripting_tool_uses.into_iter().map(|tool_use| { + self.render_scripting_tool_use(tool_use, window, cx) + })), + ) + }, + ), Role::System => div().id(("message-container", ix)).py_1().px_2().child( v_flex() .bg(colors.editor_background) @@ -805,11 +853,10 @@ impl ActiveThread { } }), )) - .child( - Label::new(tool_use.name) - .size(LabelSize::Small) - .buffer_font(cx), - ), + .child(div().text_ui_sm(cx).children( + self.rendered_tool_use_labels.get(&tool_use.id).cloned(), + )) + .truncate(), ) .child({ let (icon_name, color, animated) = match &tool_use.status { @@ -937,6 +984,7 @@ impl ActiveThread { fn render_scripting_tool_use( &self, tool_use: ToolUse, + window: &Window, cx: &mut Context, ) -> impl IntoElement { let is_open = self @@ -982,7 +1030,12 @@ impl ActiveThread { } }), )) - .child(Label::new(tool_use.name)), + .child(div().text_ui_sm(cx).child(self.render_markdown( + tool_use.ui_text.clone(), + window, + cx, + ))) + .truncate(), ) .child( Label::new(match tool_use.status { diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index f91600256f..c38aaff44f 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -458,7 +458,7 @@ impl AssistantPanel { workspace.update_in(cx, |workspace, window, cx| { let thread = thread.read(cx); - let markdown = thread.to_markdown()?; + let markdown = thread.to_markdown(cx)?; let thread_summary = thread .summary() .map(|summary| summary.to_string()) diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 0257ff40ed..bbab5858ed 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -146,10 +146,10 @@ impl Thread { pending_completions: Vec::new(), project: project.clone(), prompt_builder, - tools, - tool_use: ToolUseState::new(), + tools: tools.clone(), + tool_use: ToolUseState::new(tools.clone()), scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)), - scripting_tool_use: ToolUseState::new(), + scripting_tool_use: ToolUseState::new(tools), action_log: cx.new(|_| ActionLog::new()), initial_project_snapshot: { let project_snapshot = Self::project_snapshot(project, cx); @@ -176,11 +176,12 @@ impl Thread { .map(|message| message.id.0 + 1) .unwrap_or(0), ); - let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| { - name != ScriptingTool::NAME - }); + let tool_use = + ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| { + name != ScriptingTool::NAME + }); let scripting_tool_use = - ToolUseState::from_serialized_messages(&serialized.messages, |name| { + ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |name| { name == ScriptingTool::NAME }); let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx)); @@ -328,12 +329,12 @@ impl Thread { all_pending_tool_uses.all(|tool_use| tool_use.status.is_error()) } - pub fn tool_uses_for_message(&self, id: MessageId) -> Vec { - self.tool_use.tool_uses_for_message(id) + pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { + self.tool_use.tool_uses_for_message(id, cx) } - pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec { - self.scripting_tool_use.tool_uses_for_message(id) + pub fn scripting_tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { + self.scripting_tool_use.tool_uses_for_message(id, cx) } pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> { @@ -448,7 +449,7 @@ impl Thread { let initial_project_snapshot = self.initial_project_snapshot.clone(); cx.spawn(async move |this, cx| { let initial_project_snapshot = initial_project_snapshot.await; - this.read_with(cx, |this, _| SerializedThread { + this.read_with(cx, |this, cx| SerializedThread { summary: this.summary_or_default(), updated_at: this.updated_at(), messages: this @@ -458,9 +459,9 @@ impl Thread { role: message.role, text: message.text.clone(), tool_uses: this - .tool_uses_for_message(message.id) + .tool_uses_for_message(message.id, cx) .into_iter() - .chain(this.scripting_tool_uses_for_message(message.id)) + .chain(this.scripting_tool_uses_for_message(message.id, cx)) .map(|tool_use| SerializedToolUse { id: tool_use.id, name: tool_use.name, @@ -809,13 +810,17 @@ impl Thread { .rfind(|message| message.role == Role::Assistant) { if tool_use.name.as_ref() == ScriptingTool::NAME { - thread - .scripting_tool_use - .request_tool_use(last_assistant_message.id, tool_use); + thread.scripting_tool_use.request_tool_use( + last_assistant_message.id, + tool_use, + cx, + ); } else { - thread - .tool_use - .request_tool_use(last_assistant_message.id, tool_use); + thread.tool_use.request_tool_use( + last_assistant_message.id, + tool_use, + cx, + ); } } } @@ -956,7 +961,10 @@ impl Thread { }); } - pub fn use_pending_tools(&mut self, cx: &mut Context) { + pub fn use_pending_tools( + &mut self, + cx: &mut Context, + ) -> impl IntoIterator { let request = self.to_completion_request(RequestKind::Chat, cx); let pending_tool_uses = self .tool_use @@ -966,17 +974,22 @@ impl Thread { .cloned() .collect::>(); - for tool_use in pending_tool_uses { + for tool_use in pending_tool_uses.iter() { if let Some(tool) = self.tools.tool(&tool_use.name, cx) { let task = tool.run( - tool_use.input, + tool_use.input.clone(), &request.messages, self.project.clone(), self.action_log.clone(), cx, ); - self.insert_tool_output(tool_use.id.clone(), task, cx); + self.insert_tool_output( + tool_use.id.clone(), + tool_use.ui_text.clone().into(), + task, + cx, + ); } } @@ -988,8 +1001,8 @@ impl Thread { .cloned() .collect::>(); - for scripting_tool_use in pending_scripting_tool_uses { - let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) { + for scripting_tool_use in pending_scripting_tool_uses.iter() { + let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) { Err(err) => Task::ready(Err(err.into())), Ok(input) => { let (script_id, script_task) = @@ -1016,13 +1029,20 @@ impl Thread { } }; - self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx); + let ui_text: SharedString = scripting_tool_use.name.clone().into(); + + self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx); } + + pending_tool_uses + .into_iter() + .chain(pending_scripting_tool_uses) } pub fn insert_tool_output( &mut self, tool_use_id: LanguageModelToolUseId, + ui_text: SharedString, output: Task>, cx: &mut Context, ) { @@ -1047,12 +1067,13 @@ impl Thread { }); self.tool_use - .run_pending_tool(tool_use_id, insert_output_task); + .run_pending_tool(tool_use_id, ui_text, insert_output_task); } pub fn insert_scripting_tool_output( &mut self, tool_use_id: LanguageModelToolUseId, + ui_text: SharedString, output: Task>, cx: &mut Context, ) { @@ -1077,7 +1098,7 @@ impl Thread { }); self.scripting_tool_use - .run_pending_tool(tool_use_id, insert_output_task); + .run_pending_tool(tool_use_id, ui_text, insert_output_task); } pub fn attach_tool_results( @@ -1250,7 +1271,7 @@ impl Thread { }) } - pub fn to_markdown(&self) -> Result { + pub fn to_markdown(&self, cx: &App) -> Result { let mut markdown = Vec::new(); if let Some(summary) = self.summary() { @@ -1269,7 +1290,7 @@ impl Thread { )?; writeln!(markdown, "{}\n", message.text)?; - for tool_use in self.tool_uses_for_message(message.id) { + for tool_use in self.tool_uses_for_message(message.id, cx) { writeln!( markdown, "**Use Tool: {} ({})**", diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs index 42020c4ce5..d020eab72b 100644 --- a/crates/assistant2/src/tool_use.rs +++ b/crates/assistant2/src/tool_use.rs @@ -1,10 +1,11 @@ use std::sync::Arc; use anyhow::Result; +use assistant_tool::ToolWorkingSet; use collections::HashMap; use futures::future::Shared; use futures::FutureExt as _; -use gpui::{SharedString, Task}; +use gpui::{App, SharedString, Task}; use language_model::{ LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, @@ -17,6 +18,7 @@ use crate::thread_store::SerializedMessage; pub struct ToolUse { pub id: LanguageModelToolUseId, pub name: SharedString, + pub ui_text: SharedString, pub status: ToolUseStatus, pub input: serde_json::Value, } @@ -30,6 +32,7 @@ pub enum ToolUseStatus { } pub struct ToolUseState { + tools: Arc, tool_uses_by_assistant_message: HashMap>, tool_uses_by_user_message: HashMap>, tool_results: HashMap, @@ -37,8 +40,9 @@ pub struct ToolUseState { } impl ToolUseState { - pub fn new() -> Self { + pub fn new(tools: Arc) -> Self { Self { + tools, tool_uses_by_assistant_message: HashMap::default(), tool_uses_by_user_message: HashMap::default(), tool_results: HashMap::default(), @@ -50,10 +54,11 @@ impl ToolUseState { /// /// Accepts a function to filter the tools that should be used to populate the state. pub fn from_serialized_messages( + tools: Arc, messages: &[SerializedMessage], mut filter_by_tool_name: impl FnMut(&str) -> bool, ) -> Self { - let mut this = Self::new(); + let mut this = Self::new(tools); let mut tool_names_by_id = HashMap::default(); for message in messages { @@ -138,7 +143,7 @@ impl ToolUseState { self.pending_tool_uses_by_id.values().collect() } - pub fn tool_uses_for_message(&self, id: MessageId) -> Vec { + pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { return Vec::new(); }; @@ -173,6 +178,7 @@ impl ToolUseState { tool_uses.push(ToolUse { id: tool_use.id.clone(), name: tool_use.name.clone().into(), + ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx), input: tool_use.input.clone(), status, }) @@ -181,6 +187,19 @@ impl ToolUseState { tool_uses } + pub fn tool_ui_label( + &self, + tool_name: &str, + input: &serde_json::Value, + cx: &App, + ) -> SharedString { + if let Some(tool) = self.tools.tool(tool_name, cx) { + tool.ui_text(input).into() + } else { + "Unknown tool".into() + } + } + pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> { let empty = Vec::new(); @@ -209,6 +228,7 @@ impl ToolUseState { &mut self, assistant_message_id: MessageId, tool_use: LanguageModelToolUse, + cx: &App, ) { self.tool_uses_by_assistant_message .entry(assistant_message_id) @@ -228,15 +248,24 @@ impl ToolUseState { PendingToolUse { assistant_message_id, id: tool_use.id, - name: tool_use.name, + name: tool_use.name.clone(), + ui_text: self + .tool_ui_label(&tool_use.name, &tool_use.input, cx) + .into(), input: tool_use.input, status: PendingToolUseStatus::Idle, }, ); } - pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) { + pub fn run_pending_tool( + &mut self, + tool_use_id: LanguageModelToolUseId, + ui_text: SharedString, + task: Task<()>, + ) { if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { + tool_use.ui_text = ui_text.into(); tool_use.status = PendingToolUseStatus::Running { _task: task.shared(), }; @@ -335,6 +364,7 @@ pub struct PendingToolUse { #[allow(unused)] pub assistant_message_id: MessageId, pub name: Arc, + pub ui_text: Arc, pub input: serde_json::Value, pub status: PendingToolUseStatus, } diff --git a/crates/assistant_eval/src/headless_assistant.rs b/crates/assistant_eval/src/headless_assistant.rs index d26b03bee2..0eb63d84ac 100644 --- a/crates/assistant_eval/src/headless_assistant.rs +++ b/crates/assistant_eval/src/headless_assistant.rs @@ -128,12 +128,7 @@ impl HeadlessAssistant { } } } - ThreadEvent::StreamedCompletion - | ThreadEvent::SummaryChanged - | ThreadEvent::StreamedAssistantText(_, _) - | ThreadEvent::MessageAdded(_) - | ThreadEvent::MessageEdited(_) - | ThreadEvent::MessageDeleted(_) => {} + _ => {} } } } diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 22564bc37f..4db31d79d4 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -5,8 +5,7 @@ use std::sync::Arc; use anyhow::Result; use collections::{HashMap, HashSet}; -use gpui::Context; -use gpui::{App, Entity, SharedString, Task}; +use gpui::{App, Context, Entity, SharedString, Task}; use language::Buffer; use language_model::LanguageModelRequestMessage; use project::Project; @@ -44,6 +43,9 @@ pub trait Tool: 'static + Send + Sync { serde_json::Value::Object(serde_json::Map::default()) } + /// Returns markdown to be displayed in the UI for this tool. + fn ui_text(&self, input: &serde_json::Value) -> String; + /// Runs the tool with the provided input. fn run( self: Arc, diff --git a/crates/assistant_tools/src/bash_tool.rs b/crates/assistant_tools/src/bash_tool.rs index 6befc4dbae..4648da2881 100644 --- a/crates/assistant_tools/src/bash_tool.rs +++ b/crates/assistant_tools/src/bash_tool.rs @@ -32,6 +32,13 @@ impl Tool for BashTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => format!("`$ {}`", input.command), + Err(_) => "Run bash command".to_string(), + } + } + fn run( self: Arc, input: serde_json::Value, diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index 4381c26b08..12b1a6d33f 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -39,6 +39,13 @@ impl Tool for DeletePathTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => format!("Delete “`{}`”", input.path), + Err(_) => "Delete path".to_string(), + } + } + fn run( self: Arc, input: serde_json::Value, @@ -59,13 +66,12 @@ impl Tool for DeletePathTool { { Some(deletion_task) => cx.background_spawn(async move { match deletion_task.await { - Ok(()) => Ok(format!("Deleted {}", &path_str)), - Err(err) => Err(anyhow!("Failed to delete {}: {}", &path_str, err)), + Ok(()) => Ok(format!("Deleted {path_str}")), + Err(err) => Err(anyhow!("Failed to delete {path_str}: {err}")), } }), None => Task::ready(Err(anyhow!( - "Couldn't delete {} because that path isn't in this project.", - path_str + "Couldn't delete {path_str} because that path isn't in this project." ))), } } diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 62d77b7273..c7036ff9cb 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -46,6 +46,17 @@ impl Tool for DiagnosticsTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + if let Some(path) = serde_json::from_value::(input.clone()) + .ok() + .and_then(|input| input.path) + { + format!("Check diagnostics for “`{}`”", path.display()) + } else { + "Check project diagnostics".to_string() + } + } + fn run( self: Arc, input: serde_json::Value, @@ -54,14 +65,15 @@ impl Tool for DiagnosticsTool { _action_log: Entity, cx: &mut App, ) -> Task> { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))), - }; - - if let Some(path) = input.path { + if let Some(path) = serde_json::from_value::(input) + .ok() + .and_then(|input| input.path) + { let Some(project_path) = project.read(cx).find_project_path(&path, cx) else { - return Task::ready(Err(anyhow!("Could not find path in project"))); + return Task::ready(Err(anyhow!( + "Could not find path {} in project", + path.display() + ))); }; let buffer = project.update(cx, |project, cx| project.open_buffer(project_path, cx)); diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index be0beb2d52..d0a36b4379 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -24,10 +24,7 @@ use util::ResultExt; pub struct EditFilesToolInput { /// High-level edit instructions. These will be interpreted by a smaller /// model, so explain the changes you want that model to make and which - /// file paths need changing. - /// - /// The description should be concise and clear. We will show this - /// description to the user as well. + /// file paths need changing. The description should be concise and clear. /// /// WARNING: When specifying which file paths need changing, you MUST /// start each path with one of the project's root directories. @@ -58,6 +55,21 @@ pub struct EditFilesToolInput { /// Notice how we never specify code snippets in the instructions! /// pub edit_instructions: String, + + /// A user-friendly description of what changes are being made. + /// This will be shown to the user in the UI to describe the edit operation. The screen real estate for this UI will be extremely + /// constrained, so make the description extremely terse. + /// + /// + /// For fixing a broken authentication system: + /// "Fix auth bug in login flow" + /// + /// + /// + /// For adding unit tests to a module: + /// "Add tests for user profile logic" + /// + pub display_description: String, } pub struct EditFilesTool; @@ -76,6 +88,13 @@ impl Tool for EditFilesTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => input.display_description, + Err(_) => "Edit files".to_string(), + } + } + fn run( self: Arc, input: serde_json::Value, diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index ecdc9dddd5..1626d76ed2 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -122,6 +122,13 @@ impl Tool for FetchTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => format!("Fetch `{}`", input.url), + Err(_) => "Fetch URL".to_string(), + } + } + fn run( self: Arc, input: serde_json::Value, diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index b43010b6ef..d7501ddce7 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -50,6 +50,13 @@ impl Tool for ListDirectoryTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => format!("List the `{}` directory's contents", input.path.display()), + Err(_) => "List directory".to_string(), + } + } + fn run( self: Arc, input: serde_json::Value, @@ -64,7 +71,10 @@ impl Tool for ListDirectoryTool { }; let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { - return Task::ready(Err(anyhow!("Path not found in project"))); + return Task::ready(Err(anyhow!( + "Path {} not found in project", + input.path.display() + ))); }; let Some(worktree) = project .read(cx) @@ -79,7 +89,7 @@ impl Tool for ListDirectoryTool { }; if !entry.is_dir() { - return Task::ready(Err(anyhow!("{} is a file.", input.path.display()))); + return Task::ready(Err(anyhow!("{} is not a directory.", input.path.display()))); } let mut output = String::new(); diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index e10bddb669..dce45fe69b 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -40,6 +40,10 @@ impl Tool for NowTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, _input: &serde_json::Value) -> String { + "Get current time".to_string() + } + fn run( self: Arc, input: serde_json::Value, diff --git a/crates/assistant_tools/src/path_search_tool.rs b/crates/assistant_tools/src/path_search_tool.rs index c4e9250892..4507bbadbc 100644 --- a/crates/assistant_tools/src/path_search_tool.rs +++ b/crates/assistant_tools/src/path_search_tool.rs @@ -48,6 +48,13 @@ impl Tool for PathSearchTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => format!("Find paths matching “`{}`”", input.glob), + Err(_) => "Search paths".to_string(), + } + } + fn run( self: Arc, input: serde_json::Value, @@ -62,7 +69,7 @@ impl Tool for PathSearchTool { }; let path_matcher = match PathMatcher::new(&[glob.clone()]) { Ok(matcher) => matcher, - Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {}", err))), + Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))), }; let snapshots: Vec = project .read(cx) diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index ae24687482..c25ab6606a 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -53,6 +53,13 @@ impl Tool for ReadFileTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => format!("Read file `{}`", input.path.display()), + Err(_) => "Read file".to_string(), + } + } + fn run( self: Arc, input: serde_json::Value, @@ -67,7 +74,10 @@ impl Tool for ReadFileTool { }; let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { - return Task::ready(Err(anyhow!("Path not found in project"))); + return Task::ready(Err(anyhow!( + "Path {} not found in project", + &input.path.display() + ))); }; cx.spawn(async move |cx| { diff --git a/crates/assistant_tools/src/regex_search_tool.rs b/crates/assistant_tools/src/regex_search_tool.rs index 2849870846..20630fddbc 100644 --- a/crates/assistant_tools/src/regex_search_tool.rs +++ b/crates/assistant_tools/src/regex_search_tool.rs @@ -22,10 +22,17 @@ pub struct RegexSearchToolInput { /// Optional starting position for paginated results (0-based). /// When not provided, starts from the beginning. #[serde(default)] - pub offset: Option, + pub offset: Option, } -const RESULTS_PER_PAGE: usize = 20; +impl RegexSearchToolInput { + /// Which page of search results this is. + pub fn page(&self) -> u32 { + 1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE) + } +} + +const RESULTS_PER_PAGE: u32 = 20; pub struct RegexSearchTool; @@ -43,6 +50,24 @@ impl Tool for RegexSearchTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => { + let page = input.page(); + + if page > 1 { + format!( + "Get page {page} of search results for regex “`{}`”", + input.regex + ) + } else { + format!("Search files for regex “`{}`”", input.regex) + } + } + Err(_) => "Search with regex".to_string(), + } + } + fn run( self: Arc, input: serde_json::Value, @@ -154,7 +179,7 @@ impl Tool for RegexSearchTool { offset + matches_found, offset + RESULTS_PER_PAGE, )) - } else { + } else { Ok(format!("Found {matches_found} matches:\n{output}")) } }) diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index e11faf86b6..d645c0fb5d 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -31,6 +31,10 @@ impl Tool for ThinkingTool { serde_json::to_value(&schema).unwrap() } + fn ui_text(&self, _input: &serde_json::Value) -> String { + "Thinking".to_string() + } + fn run( self: Arc, input: serde_json::Value, diff --git a/crates/context_server/src/context_server_tool.rs b/crates/context_server/src/context_server_tool.rs index 315bb6bce3..f0bb36537f 100644 --- a/crates/context_server/src/context_server_tool.rs +++ b/crates/context_server/src/context_server_tool.rs @@ -56,6 +56,10 @@ impl Tool for ContextServerTool { } } + fn ui_text(&self, _input: &serde_json::Value) -> String { + format!("Run MCP tool `{}`", self.tool.name) + } + fn run( self: Arc, input: serde_json::Value, @@ -65,42 +69,43 @@ impl Tool for ContextServerTool { cx: &mut App, ) -> Task> { if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) { - cx.foreground_executor().spawn({ - let tool_name = self.tool.name.clone(); - async move { - let Some(protocol) = server.client() else { - bail!("Context server not initialized"); - }; + let tool_name = self.tool.name.clone(); + let server_clone = server.clone(); + let input_clone = input.clone(); - let arguments = if let serde_json::Value::Object(map) = input { - Some(map.into_iter().collect()) - } else { - None - }; + cx.spawn(async move |_cx| { + let Some(protocol) = server_clone.client() else { + bail!("Context server not initialized"); + }; - log::trace!( - "Running tool: {} with arguments: {:?}", - tool_name, - arguments - ); - let response = protocol.run_tool(tool_name, arguments).await?; + let arguments = if let serde_json::Value::Object(map) = input_clone { + Some(map.into_iter().collect()) + } else { + None + }; - let mut result = String::new(); - for content in response.content { - match content { - types::ToolResponseContent::Text { text } => { - result.push_str(&text); - } - types::ToolResponseContent::Image { .. } => { - log::warn!("Ignoring image content from tool response"); - } - types::ToolResponseContent::Resource { .. } => { - log::warn!("Ignoring resource content from tool response"); - } + log::trace!( + "Running tool: {} with arguments: {:?}", + tool_name, + arguments + ); + let response = protocol.run_tool(tool_name, arguments).await?; + + let mut result = String::new(); + for content in response.content { + match content { + types::ToolResponseContent::Text { text } => { + result.push_str(&text); + } + types::ToolResponseContent::Image { .. } => { + log::warn!("Ignoring image content from tool response"); + } + types::ToolResponseContent::Resource { .. } => { + log::warn!("Ignoring resource content from tool response"); } } - Ok(result) } + Ok(result) }) } else { Task::ready(Err(anyhow!("Context server not found")))