From 90fa9217563b2ca79cbfd1b1c2deb11aff2fc551 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Fri, 8 Aug 2025 00:21:26 +0200 Subject: [PATCH] Wire up find_path tool in agent2 (#35799) Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/acp_thread/src/acp_thread.rs | 9 + crates/agent2/Cargo.toml | 1 - crates/agent2/src/agent.rs | 10 +- crates/agent2/src/agent2.rs | 1 + crates/agent2/src/templates.rs | 13 - crates/agent2/src/templates/glob.hbs | 8 - crates/agent2/src/tests/mod.rs | 132 +++++++++- crates/agent2/src/tests/test_tools.rs | 82 +++++-- crates/agent2/src/thread.rs | 282 +++++++++++++--------- crates/agent2/src/tools.rs | 6 +- crates/agent2/src/tools/find_path_tool.rs | 231 ++++++++++++++++++ crates/agent2/src/tools/glob.rs | 84 ------- crates/agent2/src/tools/thinking_tool.rs | 48 ++++ crates/agent_servers/src/acp/v0.rs | 1 + crates/agent_servers/src/claude/tools.rs | 1 + crates/agent_ui/src/acp/thread_view.rs | 1 + 18 files changed, 669 insertions(+), 247 deletions(-) delete mode 100644 crates/agent2/src/templates/glob.hbs create mode 100644 crates/agent2/src/tools/find_path_tool.rs delete mode 100644 crates/agent2/src/tools/glob.rs create mode 100644 crates/agent2/src/tools/thinking_tool.rs diff --git a/Cargo.lock b/Cargo.lock index fe0c7a1b23..8c1f1d00ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.21" +version = "0.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7ae3c22c23b64a5c3b7fc8a86fcc7c494e989bd2cd66fdce14a58cfc8078381" +checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8" dependencies = [ "anyhow", "futures 0.3.31", diff --git a/Cargo.toml b/Cargo.toml index 6bff713aaa..d547110bb4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = "0.0.21" +agent-client-protocol = { version = "0.0.23" } aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 1671003023..71827d6948 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -167,6 +167,7 @@ pub struct ToolCall { pub status: ToolCallStatus, pub locations: Vec, pub raw_input: Option, + pub raw_output: Option, } impl ToolCall { @@ -195,6 +196,7 @@ impl ToolCall { locations: tool_call.locations, status, raw_input: tool_call.raw_input, + raw_output: tool_call.raw_output, } } @@ -211,6 +213,7 @@ impl ToolCall { content, locations, raw_input, + raw_output, } = fields; if let Some(kind) = kind { @@ -241,6 +244,10 @@ impl ToolCall { if let Some(raw_input) = raw_input { self.raw_input = Some(raw_input); } + + if let Some(raw_output) = raw_output { + self.raw_output = Some(raw_output); + } } pub fn diffs(&self) -> impl Iterator { @@ -1547,6 +1554,7 @@ mod tests { content: vec![], locations: vec![], raw_input: None, + raw_output: None, }), cx, ) @@ -1659,6 +1667,7 @@ mod tests { }], locations: vec![], raw_input: None, + raw_output: None, }), cx, ) diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 21a043fd98..884378fbcc 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -39,7 +39,6 @@ ui.workspace = true util.workspace = true uuid.workspace = true watch.workspace = true -worktree.workspace = true workspace-hack.workspace = true [dev-dependencies] diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 5c0acb3fb1..cb568f04c2 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,5 +1,5 @@ -use crate::ToolCallAuthorization; use crate::{templates::Templates, AgentResponseEvent, Thread}; +use crate::{FindPathTool, ThinkingTool, ToolCallAuthorization}; use acp_thread::ModelSelector; use agent_client_protocol as acp; use anyhow::{anyhow, Context as _, Result}; @@ -412,7 +412,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { anyhow!("No default model configured. Please configure a default model in settings.") })?; - let thread = cx.new(|_| Thread::new(project, agent.project_context.clone(), action_log, agent.templates.clone(), default_model)); + let thread = cx.new(|_| { + let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log, agent.templates.clone(), default_model); + thread.add_tool(ThinkingTool); + thread.add_tool(FindPathTool::new(project.clone())); + thread + }); + Ok(thread) }, )??; diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index d759f63d89..db743c8429 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -10,3 +10,4 @@ mod tests; pub use agent::*; pub use native_agent_server::NativeAgentServer; pub use thread::*; +pub use tools::*; diff --git a/crates/agent2/src/templates.rs b/crates/agent2/src/templates.rs index e634d414d6..a63f0ad206 100644 --- a/crates/agent2/src/templates.rs +++ b/crates/agent2/src/templates.rs @@ -33,19 +33,6 @@ pub trait Template: Sized { } } -#[expect( - dead_code, - reason = "Marked as unused by Rust 1.89 and left as is as of 07 Aug 2025 to let AI team address it." -)] -#[derive(Serialize)] -pub struct GlobTemplate { - pub project_roots: String, -} - -impl Template for GlobTemplate { - const TEMPLATE_NAME: &'static str = "glob.hbs"; -} - #[derive(Serialize)] pub struct SystemPromptTemplate<'a> { #[serde(flatten)] diff --git a/crates/agent2/src/templates/glob.hbs b/crates/agent2/src/templates/glob.hbs deleted file mode 100644 index 3bf992b093..0000000000 --- a/crates/agent2/src/templates/glob.hbs +++ /dev/null @@ -1,8 +0,0 @@ -Find paths on disk with glob patterns. - -Assume that all glob patterns are matched in a project directory with the following entries. - -{{project_roots}} - -When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be -sure to anchor them with one of the directories listed above. diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index b13b1cbe1a..7913f9a24c 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -270,14 +270,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { vec![ MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(), - tool_name: tool_call_auth_1.tool_call.title.into(), + tool_name: ToolRequiringPermission.name().into(), is_error: false, content: "Allowed".into(), output: None }), MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), - tool_name: tool_call_auth_2.tool_call.title.into(), + tool_name: ToolRequiringPermission.name().into(), is_error: true, content: "Permission to run tool denied by user".into(), output: None @@ -286,6 +286,63 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn test_tool_hallucination(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx)); + cx.run_until_parked(); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_id_1".into(), + name: "nonexistent_tool".into(), + raw_input: "{}".into(), + input: json!({}), + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + + let tool_call = expect_tool_call(&mut events).await; + assert_eq!(tool_call.title, "nonexistent_tool"); + assert_eq!(tool_call.status, acp::ToolCallStatus::Pending); + let update = expect_tool_call_update(&mut events).await; + assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed)); +} + +async fn expect_tool_call( + events: &mut UnboundedReceiver>, +) -> acp::ToolCall { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + match event { + AgentResponseEvent::ToolCall(tool_call) => return tool_call, + event => { + panic!("Unexpected event {event:?}"); + } + } +} + +async fn expect_tool_call_update( + events: &mut UnboundedReceiver>, +) -> acp::ToolCallUpdate { + let event = events + .next() + .await + .expect("no tool call authorization event received") + .unwrap(); + match event { + AgentResponseEvent::ToolCallUpdate(tool_call_update) => return tool_call_update, + event => { + panic!("Unexpected event {event:?}"); + } + } +} + async fn next_tool_call_authorization( events: &mut UnboundedReceiver>, ) -> ToolCallAuthorization { @@ -582,6 +639,77 @@ async fn test_agent_connection(cx: &mut TestAppContext) { ); } +#[gpui::test] +async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { + let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await; + thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool)); + let fake_model = model.as_fake(); + + let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx)); + cx.run_until_parked(); + + let input = json!({ "content": "Thinking hard!" }); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "1".into(), + name: ThinkingTool.name().into(), + raw_input: input.to_string(), + input, + is_input_complete: true, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + let tool_call = expect_tool_call(&mut events).await; + assert_eq!( + tool_call, + acp::ToolCall { + id: acp::ToolCallId("1".into()), + title: "Thinking".into(), + kind: acp::ToolKind::Think, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(json!({ "content": "Thinking hard!" })), + raw_output: None, + } + ); + let update = expect_tool_call_update(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress,), + ..Default::default() + }, + } + ); + let update = expect_tool_call_update(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + content: Some(vec!["Thinking hard!".into()]), + ..Default::default() + }, + } + ); + let update = expect_tool_call_update(&mut events).await; + assert_eq!( + update, + acp::ToolCallUpdate { + id: acp::ToolCallId("1".into()), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + } + ); +} + /// Filters out the stop events for asserting against in tests fn stop_events( result_events: Vec>, diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs index a066bb982e..fd6e7e941f 100644 --- a/crates/agent2/src/tests/test_tools.rs +++ b/crates/agent2/src/tests/test_tools.rs @@ -19,11 +19,20 @@ impl AgentTool for EchoTool { "echo".into() } - fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { - false + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other } - fn run(self: Arc, input: Self::Input, _cx: &mut App) -> Task> { + fn initial_title(&self, _: Self::Input) -> SharedString { + "Echo".into() + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { Task::ready(Ok(input.text)) } } @@ -44,11 +53,20 @@ impl AgentTool for DelayTool { "delay".into() } - fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { - false + fn initial_title(&self, input: Self::Input) -> SharedString { + format!("Delay {}ms", input.ms).into() } - fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn run( + self: Arc, + input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> where Self: Sized, { @@ -71,16 +89,28 @@ impl AgentTool for ToolRequiringPermission { "tool_requiring_permission".into() } - fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { - true + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other } - fn run(self: Arc, _input: Self::Input, cx: &mut App) -> Task> + fn initial_title(&self, _input: Self::Input) -> SharedString { + "This tool requires permission".into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> where Self: Sized, { - cx.foreground_executor() - .spawn(async move { Ok("Allowed".to_string()) }) + let auth_check = self.authorize(input, event_stream); + cx.foreground_executor().spawn(async move { + auth_check.await?; + Ok("Allowed".to_string()) + }) } } @@ -96,11 +126,20 @@ impl AgentTool for InfiniteTool { "infinite".into() } - fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { - false + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other } - fn run(self: Arc, _input: Self::Input, cx: &mut App) -> Task> { + fn initial_title(&self, _input: Self::Input) -> SharedString { + "This is the tool that never ends... it just goes on and on my friends!".into() + } + + fn run( + self: Arc, + _input: Self::Input, + _event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { cx.foreground_executor().spawn(async move { future::pending::<()>().await; unreachable!() @@ -137,11 +176,20 @@ impl AgentTool for WordListTool { "word_list".into() } - fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { - false + fn initial_title(&self, _input: Self::Input) -> SharedString { + "List of random words".into() } - fn run(self: Arc, _input: Self::Input, _cx: &mut App) -> Task> { + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Other + } + + fn run( + self: Arc, + _input: Self::Input, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { Task::ready(Ok("ok".to_string())) } } diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 9b17d7e37e..805ffff1c0 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,16 +1,16 @@ use crate::templates::{SystemPromptTemplate, Template, Templates}; use agent_client_protocol as acp; use anyhow::{anyhow, Context as _, Result}; -use assistant_tool::ActionLog; +use assistant_tool::{adapt_schema_to_format, ActionLog}; use cloud_llm_client::{CompletionIntent, CompletionMode}; use collections::HashMap; use futures::{ channel::{mpsc, oneshot}, stream::FuturesUnordered, }; -use gpui::{App, Context, Entity, ImageFormat, SharedString, Task}; +use gpui::{App, Context, Entity, SharedString, Task}; use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, @@ -19,7 +19,7 @@ use log; use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use smol::stream::StreamExt; use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc}; use util::{markdown::MarkdownCodeBlock, ResultExt}; @@ -276,7 +276,17 @@ impl Thread { while let Some(tool_result) = tool_uses.next().await { log::info!("Tool finished {:?}", tool_result); - event_stream.send_tool_call_result(&tool_result); + event_stream.send_tool_call_update( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields { + status: Some(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }), + ..Default::default() + }, + ); thread .update(cx, |thread, _cx| { thread.pending_tool_uses.remove(&tool_result.tool_use_id); @@ -426,6 +436,8 @@ impl Thread { ) -> Option> { cx.notify(); + let tool = self.tools.get(tool_use.name.as_ref()).cloned(); + self.pending_tool_uses .insert(tool_use.id.clone(), tool_use.clone()); let last_message = self.last_assistant_message(); @@ -443,8 +455,9 @@ impl Thread { true } }); + if push_new_tool_use { - event_stream.send_tool_call(&tool_use); + event_stream.send_tool_call(tool.as_ref(), &tool_use); last_message .content .push(MessageContent::ToolUse(tool_use.clone())); @@ -462,37 +475,36 @@ impl Thread { return None; } - if let Some(tool) = self.tools.get(tool_use.name.as_ref()) { - let tool_result = - self.run_tool(tool.clone(), tool_use.clone(), event_stream.clone(), cx); - Some(cx.foreground_executor().spawn(async move { - match tool_result.await { - Ok(tool_output) => LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: false, - content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), - output: None, - }, - Err(error) => LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: true, - content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), - output: None, - }, - } - })) - } else { + let Some(tool) = tool else { let content = format!("No tool named {} exists", tool_use.name); - Some(Task::ready(LanguageModelToolResult { + return Some(Task::ready(LanguageModelToolResult { content: LanguageModelToolResultContent::Text(Arc::from(content)), tool_use_id: tool_use.id, tool_name: tool_use.name, is_error: true, output: None, - })) - } + })); + }; + + let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx); + Some(cx.foreground_executor().spawn(async move { + match tool_result.await { + Ok(tool_output) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: false, + content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), + output: None, + }, + Err(error) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), + output: None, + }, + } + })) } fn run_tool( @@ -502,20 +514,14 @@ impl Thread { event_stream: AgentResponseEventStream, cx: &mut Context, ) -> Task> { - let needs_authorization = tool.needs_authorization(tool_use.input.clone(), cx); cx.spawn(async move |_this, cx| { - if needs_authorization? { - event_stream.authorize_tool_call(&tool_use).await?; - } - - event_stream.send_tool_call_update( - &tool_use.id, - acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::InProgress), - ..Default::default() - }, - ); - cx.update(|cx| tool.run(tool_use.input, cx))?.await + let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream); + tool_event_stream.send_update(acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress), + ..Default::default() + }); + cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))? + .await }) } @@ -584,7 +590,7 @@ impl Thread { name: tool_name, description: tool.description(cx).to_string(), input_schema: tool - .input_schema(LanguageModelToolSchemaFormat::JsonSchema) + .input_schema(self.selected_model.tool_input_format()) .log_err()?, }) }) @@ -651,9 +657,10 @@ pub trait AgentTool where Self: 'static + Sized, { - type Input: for<'de> Deserialize<'de> + JsonSchema; + type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; fn name(&self) -> SharedString; + fn description(&self, _cx: &mut App) -> SharedString { let schema = schemars::schema_for!(Self::Input); SharedString::new( @@ -664,17 +671,33 @@ where ) } + fn kind(&self) -> acp::ToolKind; + + /// The initial tool title to display. Can be updated during the tool run. + fn initial_title(&self, input: Self::Input) -> SharedString; + /// Returns the JSON schema that describes the tool's input. - fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema { + fn input_schema(&self) -> Schema { schemars::schema_for!(Self::Input) } - /// Returns true if the tool needs the users's authorization - /// before running. - fn needs_authorization(&self, input: Self::Input, cx: &App) -> bool; + /// Allows the tool to authorize a given tool call with the user if necessary + fn authorize( + &self, + input: Self::Input, + event_stream: ToolCallEventStream, + ) -> impl use + Future> { + let json_input = serde_json::json!(&input); + event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input) + } /// Runs the tool with the provided input. - fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task>; + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task>; fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) @@ -686,9 +709,15 @@ pub struct Erased(T); pub trait AnyAgentTool { fn name(&self) -> SharedString; fn description(&self, cx: &mut App) -> SharedString; + fn kind(&self) -> acp::ToolKind; + fn initial_title(&self, input: serde_json::Value) -> Result; fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; - fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result; - fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task>; + fn run( + self: Arc, + input: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task>; } impl AnyAgentTool for Erased> @@ -703,22 +732,30 @@ where self.0.description(cx) } + fn kind(&self) -> agent_client_protocol::ToolKind { + self.0.kind() + } + + fn initial_title(&self, input: serde_json::Value) -> Result { + let parsed_input = serde_json::from_value(input)?; + Ok(self.0.initial_title(parsed_input)) + } + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - Ok(serde_json::to_value(self.0.input_schema(format))?) + let mut json = serde_json::to_value(self.0.input_schema())?; + adapt_schema_to_format(&mut json, format)?; + Ok(json) } - fn needs_authorization(&self, input: serde_json::Value, cx: &mut App) -> Result { + fn run( + self: Arc, + input: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); match parsed_input { - Ok(input) => Ok(self.0.needs_authorization(input, cx)), - Err(error) => Err(anyhow!(error)), - } - } - - fn run(self: Arc, input: serde_json::Value, cx: &mut App) -> Task> { - let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); - match parsed_input { - Ok(input) => self.0.clone().run(input, cx), + Ok(input) => self.0.clone().run(input, event_stream, cx), Err(error) => Task::ready(Err(anyhow!(error))), } } @@ -744,21 +781,16 @@ impl AgentResponseEventStream { fn authorize_tool_call( &self, - tool_use: &LanguageModelToolUse, + id: &LanguageModelToolUseId, + title: String, + kind: acp::ToolKind, + input: serde_json::Value, ) -> impl use<> + Future> { let (response_tx, response_rx) = oneshot::channel(); self.0 .unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization( ToolCallAuthorization { - tool_call: acp::ToolCall { - id: acp::ToolCallId(tool_use.id.to_string().into()), - title: tool_use.name.to_string(), - kind: acp::ToolKind::Other, - status: acp::ToolCallStatus::Pending, - content: vec![], - locations: vec![], - raw_input: Some(tool_use.input.clone()), - }, + tool_call: Self::initial_tool_call(id, title, kind, input), options: vec![ acp::PermissionOption { id: acp::PermissionOptionId("always_allow".into()), @@ -788,20 +820,41 @@ impl AgentResponseEventStream { } } - fn send_tool_call(&self, tool_use: &LanguageModelToolUse) { + fn send_tool_call( + &self, + tool: Option<&Arc>, + tool_use: &LanguageModelToolUse, + ) { self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCall(acp::ToolCall { - id: acp::ToolCallId(tool_use.id.to_string().into()), - title: tool_use.name.to_string(), - kind: acp::ToolKind::Other, - status: acp::ToolCallStatus::Pending, - content: vec![], - locations: vec![], - raw_input: Some(tool_use.input.clone()), - }))) + .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call( + &tool_use.id, + tool.and_then(|t| t.initial_title(tool_use.input.clone()).ok()) + .map(|i| i.into()) + .unwrap_or_else(|| tool_use.name.to_string()), + tool.map(|t| t.kind()).unwrap_or(acp::ToolKind::Other), + tool_use.input.clone(), + )))) .ok(); } + fn initial_tool_call( + id: &LanguageModelToolUseId, + title: String, + kind: acp::ToolKind, + input: serde_json::Value, + ) -> acp::ToolCall { + acp::ToolCall { + id: acp::ToolCallId(id.to_string().into()), + title, + kind, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(input), + raw_output: None, + } + } + fn send_tool_call_update( &self, tool_use_id: &LanguageModelToolUseId, @@ -817,38 +870,6 @@ impl AgentResponseEventStream { .ok(); } - fn send_tool_call_result(&self, tool_result: &LanguageModelToolResult) { - let status = if tool_result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - }; - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => text.to_string().into(), - LanguageModelToolResultContent::Image(LanguageModelImage { source, .. }) => { - acp::ToolCallContent::Content { - content: acp::ContentBlock::Image(acp::ImageContent { - annotations: None, - data: source.to_string(), - mime_type: ImageFormat::Png.mime_type().to_string(), - }), - } - } - }; - self.0 - .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate( - acp::ToolCallUpdate { - id: acp::ToolCallId(tool_result.tool_use_id.to_string().into()), - fields: acp::ToolCallUpdateFields { - status: Some(status), - content: Some(vec![content]), - ..Default::default() - }, - }, - ))) - .ok(); - } - fn send_stop(&self, reason: StopReason) { match reason { StopReason::EndTurn => { @@ -874,3 +895,32 @@ impl AgentResponseEventStream { self.0.unbounded_send(Err(error)).ok(); } } + +#[derive(Clone)] +pub struct ToolCallEventStream { + tool_use_id: LanguageModelToolUseId, + stream: AgentResponseEventStream, +} + +impl ToolCallEventStream { + fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self { + Self { + tool_use_id, + stream, + } + } + + pub fn send_update(&self, fields: acp::ToolCallUpdateFields) { + self.stream.send_tool_call_update(&self.tool_use_id, fields); + } + + pub fn authorize( + &self, + title: String, + kind: acp::ToolKind, + input: serde_json::Value, + ) -> impl use<> + Future> { + self.stream + .authorize_tool_call(&self.tool_use_id, title, kind, input) + } +} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs index cf3162abfa..848fe552ed 100644 --- a/crates/agent2/src/tools.rs +++ b/crates/agent2/src/tools.rs @@ -1 +1,5 @@ -mod glob; +mod find_path_tool; +mod thinking_tool; + +pub use find_path_tool::*; +pub use thinking_tool::*; diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent2/src/tools/find_path_tool.rs new file mode 100644 index 0000000000..e840fec78c --- /dev/null +++ b/crates/agent2/src/tools/find_path_tool.rs @@ -0,0 +1,231 @@ +use agent_client_protocol as acp; +use anyhow::{anyhow, Result}; +use gpui::{App, AppContext, Entity, SharedString, Task}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::fmt::Write; +use std::{cmp, path::PathBuf, sync::Arc}; +use util::paths::PathMatcher; + +use crate::{AgentTool, ToolCallEventStream}; + +/// Fast file path pattern matching tool that works with any codebase size +/// +/// - Supports glob patterns like "**/*.js" or "src/**/*.ts" +/// - Returns matching file paths sorted alphabetically +/// - Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths. +/// - Use this tool when you need to find files by name patterns +/// - Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FindPathToolInput { + /// The glob to match against every path in the project. + /// + /// + /// If the project has the following root directories: + /// + /// - directory1/a/something.txt + /// - directory2/a/things.txt + /// - directory3/a/other.txt + /// + /// You can get back the first two paths by providing a glob of "*thing*.txt" + /// + pub glob: String, + + /// Optional starting position for paginated results (0-based). + /// When not provided, starts from the beginning. + #[serde(default)] + pub offset: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +struct FindPathToolOutput { + paths: Vec, +} + +const RESULTS_PER_PAGE: usize = 50; + +pub struct FindPathTool { + project: Entity, +} + +impl FindPathTool { + pub fn new(project: Entity) -> Self { + Self { project } + } +} + +impl AgentTool for FindPathTool { + type Input = FindPathToolInput; + + fn name(&self) -> SharedString { + "find_path".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Search + } + + fn initial_title(&self, input: Self::Input) -> SharedString { + format!("Find paths matching “`{}`”", input.glob).into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); + + cx.background_spawn(async move { + let matches = search_paths_task.await?; + let paginated_matches: &[PathBuf] = &matches[cmp::min(input.offset, matches.len()) + ..cmp::min(input.offset + RESULTS_PER_PAGE, matches.len())]; + + event_stream.send_update(acp::ToolCallUpdateFields { + title: Some(if paginated_matches.len() == 0 { + "No matches".into() + } else if paginated_matches.len() == 1 { + "1 match".into() + } else { + format!("{} matches", paginated_matches.len()) + }), + content: Some( + paginated_matches + .iter() + .map(|path| acp::ToolCallContent::Content { + content: acp::ContentBlock::ResourceLink(acp::ResourceLink { + uri: format!("file://{}", path.display()), + name: path.to_string_lossy().into(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + }), + }) + .collect(), + ), + raw_output: Some(serde_json::json!({ + "paths": &matches, + })), + ..Default::default() + }); + + if matches.is_empty() { + Ok("No matches found".into()) + } else { + let mut message = format!("Found {} total matches.", matches.len()); + if matches.len() > RESULTS_PER_PAGE { + write!( + &mut message, + "\nShowing results {}-{} (provide 'offset' parameter for more results):", + input.offset + 1, + input.offset + paginated_matches.len() + ) + .unwrap(); + } + + for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) { + write!(&mut message, "\n{}", mat.display()).unwrap(); + } + + Ok(message) + } + }) + } +} + +fn search_paths(glob: &str, project: Entity, cx: &mut App) -> Task>> { + let path_matcher = match PathMatcher::new([ + // Sometimes models try to search for "". In this case, return all paths in the project. + if glob.is_empty() { "*" } else { glob }, + ]) { + Ok(matcher) => matcher, + Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))), + }; + let snapshots: Vec<_> = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect(); + + cx.background_spawn(async move { + Ok(snapshots + .iter() + .flat_map(|snapshot| { + let root_name = PathBuf::from(snapshot.root_name()); + snapshot + .entries(false, 0) + .map(move |entry| root_name.join(&entry.path)) + .filter(|path| path_matcher.is_match(&path)) + }) + .collect()) + }) +} + +#[cfg(test)] +mod test { + use super::*; + use gpui::TestAppContext; + use project::{FakeFs, Project}; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_find_path_tool(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + serde_json::json!({ + "apple": { + "banana": { + "carrot": "1", + }, + "bandana": { + "carbonara": "2", + }, + "endive": "3" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let matches = cx + .update(|cx| search_paths("root/**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + + let matches = cx + .update(|cx| search_paths("**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } +} diff --git a/crates/agent2/src/tools/glob.rs b/crates/agent2/src/tools/glob.rs deleted file mode 100644 index 4dace7c074..0000000000 --- a/crates/agent2/src/tools/glob.rs +++ /dev/null @@ -1,84 +0,0 @@ -use anyhow::{anyhow, Result}; -use gpui::{App, AppContext, Entity, SharedString, Task}; -use project::Project; -use schemars::JsonSchema; -use serde::Deserialize; -use std::{path::PathBuf, sync::Arc}; -use util::paths::PathMatcher; -use worktree::Snapshot as WorktreeSnapshot; - -use crate::{ - templates::{GlobTemplate, Template, Templates}, - thread::AgentTool, -}; - -// Description is dynamic, see `fn description` below -#[derive(Deserialize, JsonSchema)] -struct GlobInput { - /// A POSIX glob pattern - glob: SharedString, -} - -#[expect( - dead_code, - reason = "Marked as unused by Rust 1.89 and left as is as of 07 Aug 2025 to let AI team address it." -)] -struct GlobTool { - project: Entity, - templates: Arc, -} - -impl AgentTool for GlobTool { - type Input = GlobInput; - - fn name(&self) -> SharedString { - "glob".into() - } - - fn description(&self, cx: &mut App) -> SharedString { - let project_roots = self - .project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).root_name().into()) - .collect::>() - .join("\n"); - - GlobTemplate { project_roots } - .render(&self.templates) - .expect("template failed to render") - .into() - } - - fn needs_authorization(&self, _input: Self::Input, _cx: &App) -> bool { - false - } - - fn run(self: Arc, input: Self::Input, cx: &mut App) -> Task> { - let path_matcher = match PathMatcher::new([&input.glob]) { - Ok(matcher) => matcher, - Err(error) => return Task::ready(Err(anyhow!(error))), - }; - - let snapshots: Vec = self - .project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect(); - - cx.background_spawn(async move { - let paths = snapshots.iter().flat_map(|snapshot| { - let root_name = PathBuf::from(snapshot.root_name()); - snapshot - .entries(false, 0) - .map(move |entry| root_name.join(&entry.path)) - .filter(|path| path_matcher.is_match(&path)) - }); - let output = paths - .map(|path| format!("{}\n", path.display())) - .collect::(); - Ok(output) - }) - } -} diff --git a/crates/agent2/src/tools/thinking_tool.rs b/crates/agent2/src/tools/thinking_tool.rs new file mode 100644 index 0000000000..bb85d8eceb --- /dev/null +++ b/crates/agent2/src/tools/thinking_tool.rs @@ -0,0 +1,48 @@ +use agent_client_protocol as acp; +use anyhow::Result; +use gpui::{App, SharedString, Task}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +use crate::{AgentTool, ToolCallEventStream}; + +/// A tool for thinking through problems, brainstorming ideas, or planning without executing any actions. +/// Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct ThinkingToolInput { + /// Content to think about. This should be a description of what to think about or + /// a problem to solve. + content: String, +} + +pub struct ThinkingTool; + +impl AgentTool for ThinkingTool { + type Input = ThinkingToolInput; + + fn name(&self) -> SharedString { + "thinking".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Think + } + + fn initial_title(&self, _input: Self::Input) -> SharedString { + "Thinking".into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Task> { + event_stream.send_update(acp::ToolCallUpdateFields { + content: Some(vec![input.content.into()]), + ..Default::default() + }); + Task::ready(Ok("Finished thinking.".to_string())) + } +} diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index e676b7ee46..8d85435f92 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -280,6 +280,7 @@ fn into_new_tool_call(id: acp::ToolCallId, request: acp_old::PushToolCallParams) .map(into_new_tool_call_location) .collect(), raw_input: None, + raw_output: None, } } diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index 85b9a13642..7ca150c0bd 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -297,6 +297,7 @@ impl ClaudeTool { content: self.content(), locations: self.locations(), raw_input: None, + raw_output: None, } } } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index ff6da43299..3d1fbba45d 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -2988,6 +2988,7 @@ mod tests { content: vec!["hi".into()], locations: vec![], raw_input: None, + raw_output: None, }; let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)]) .with_permission_requests(HashMap::from_iter([(