diff --git a/Cargo.lock b/Cargo.lock index 0f0e78bb48..6f434e8685 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -158,8 +158,10 @@ dependencies = [ "acp_thread", "agent-client-protocol", "agent_servers", + "agent_settings", "anyhow", "assistant_tool", + "assistant_tools", "client", "clock", "cloud_llm_client", @@ -177,6 +179,8 @@ dependencies = [ "language_model", "language_models", "log", + "lsp", + "paths", "pretty_assertions", "project", "prompt_store", diff --git a/Cargo.toml b/Cargo.toml index d547110bb4..998e727602 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -425,7 +425,7 @@ zlog_settings = { path = "crates/zlog_settings" } # agentic-coding-protocol = "0.0.10" -agent-client-protocol = { version = "0.0.23" } +agent-client-protocol = "0.0.23" aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 443375a51b..54bfe56a15 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1,18 +1,17 @@ mod connection; +mod diff; + pub use connection::*; +pub use diff::*; use agent_client_protocol as acp; use anyhow::{Context as _, Result}; use assistant_tool::ActionLog; -use buffer_diff::BufferDiff; -use editor::{Bias, MultiBuffer, PathKey}; +use editor::Bias; use futures::{FutureExt, channel::oneshot, future::BoxFuture}; use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task}; use itertools::Itertools; -use language::{ - Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point, - text_diff, -}; +use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff}; use markdown::Markdown; use project::{AgentLocation, Project}; use std::collections::HashMap; @@ -140,7 +139,7 @@ impl AgentThreadEntry { } } - pub fn diffs(&self) -> impl Iterator { + pub fn diffs(&self) -> impl Iterator> { if let AgentThreadEntry::ToolCall(call) = self { itertools::Either::Left(call.diffs()) } else { @@ -249,7 +248,7 @@ impl ToolCall { } } - pub fn diffs(&self) -> impl Iterator { + pub fn diffs(&self) -> impl Iterator> { self.content.iter().filter_map(|content| match content { ToolCallContent::ContentBlock { .. } => None, ToolCallContent::Diff { diff } => Some(diff), @@ -389,7 +388,7 @@ impl ContentBlock { #[derive(Debug)] pub enum ToolCallContent { ContentBlock { content: ContentBlock }, - Diff { diff: Diff }, + Diff { diff: Entity }, } impl ToolCallContent { @@ -403,7 +402,7 @@ impl ToolCallContent { content: ContentBlock::new(content, &language_registry, cx), }, acp::ToolCallContent::Diff { diff } => Self::Diff { - diff: Diff::from_acp(diff, language_registry, cx), + diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)), }, } } @@ -411,108 +410,11 @@ impl ToolCallContent { pub fn to_markdown(&self, cx: &App) -> String { match self { Self::ContentBlock { content } => content.to_markdown(cx).to_string(), - Self::Diff { diff } => diff.to_markdown(cx), + Self::Diff { diff } => diff.read(cx).to_markdown(cx), } } } -#[derive(Debug)] -pub struct Diff { - pub multibuffer: Entity, - pub path: PathBuf, - _task: Task>, -} - -impl Diff { - pub fn from_acp( - diff: acp::Diff, - language_registry: Arc, - cx: &mut App, - ) -> Self { - let acp::Diff { - path, - old_text, - new_text, - } = diff; - - let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); - - let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); - let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); - let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); - let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); - - let task = cx.spawn({ - let multibuffer = multibuffer.clone(); - let path = path.clone(); - async move |cx| { - let language = language_registry - .language_for_file_path(&path) - .await - .log_err(); - - new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; - - let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| { - buffer.set_language(language, cx); - buffer.snapshot() - })?; - - buffer_diff - .update(cx, |diff, cx| { - diff.set_base_text( - old_buffer_snapshot, - Some(language_registry), - new_buffer_snapshot, - cx, - ) - })? - .await?; - - multibuffer - .update(cx, |multibuffer, cx| { - let hunk_ranges = { - let buffer = new_buffer.read(cx); - let diff = buffer_diff.read(cx); - diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) - .collect::>() - }; - - multibuffer.set_excerpts_for_path( - PathKey::for_buffer(&new_buffer, cx), - new_buffer.clone(), - hunk_ranges, - editor::DEFAULT_MULTIBUFFER_CONTEXT, - cx, - ); - multibuffer.add_diff(buffer_diff, cx); - }) - .log_err(); - - anyhow::Ok(()) - } - }); - - Self { - multibuffer, - path, - _task: task, - } - } - - fn to_markdown(&self, cx: &App) -> String { - let buffer_text = self - .multibuffer - .read(cx) - .all_buffers() - .iter() - .map(|buffer| buffer.read(cx).text()) - .join("\n"); - format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text) - } -} - #[derive(Debug, Default)] pub struct Plan { pub entries: Vec, @@ -823,6 +725,21 @@ impl AcpThread { Ok(()) } + pub fn set_tool_call_diff( + &mut self, + tool_call_id: &acp::ToolCallId, + diff: Entity, + cx: &mut Context, + ) -> Result<()> { + let (ix, current_call) = self + .tool_call_mut(tool_call_id) + .context("Tool call not found")?; + current_call.content.clear(); + current_call.content.push(ToolCallContent::Diff { diff }); + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + Ok(()) + } + /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { let status = ToolCallStatus::Allowed { diff --git a/crates/acp_thread/src/diff.rs b/crates/acp_thread/src/diff.rs new file mode 100644 index 0000000000..9cc6271360 --- /dev/null +++ b/crates/acp_thread/src/diff.rs @@ -0,0 +1,388 @@ +use agent_client_protocol as acp; +use anyhow::Result; +use buffer_diff::{BufferDiff, BufferDiffSnapshot}; +use editor::{MultiBuffer, PathKey}; +use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task}; +use itertools::Itertools; +use language::{ + Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _, Point, Rope, TextBuffer, +}; +use std::{ + cmp::Reverse, + ops::Range, + path::{Path, PathBuf}, + sync::Arc, +}; +use util::ResultExt; + +pub enum Diff { + Pending(PendingDiff), + Finalized(FinalizedDiff), +} + +impl Diff { + pub fn from_acp( + diff: acp::Diff, + language_registry: Arc, + cx: &mut Context, + ) -> Self { + let acp::Diff { + path, + old_text, + new_text, + } = diff; + + let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly)); + + let new_buffer = cx.new(|cx| Buffer::local(new_text, cx)); + let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx)); + let new_buffer_snapshot = new_buffer.read(cx).text_snapshot(); + let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx)); + + let task = cx.spawn({ + let multibuffer = multibuffer.clone(); + let path = path.clone(); + async move |_, cx| { + let language = language_registry + .language_for_file_path(&path) + .await + .log_err(); + + new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?; + + let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| { + buffer.set_language(language, cx); + buffer.snapshot() + })?; + + buffer_diff + .update(cx, |diff, cx| { + diff.set_base_text( + old_buffer_snapshot, + Some(language_registry), + new_buffer_snapshot, + cx, + ) + })? + .await?; + + multibuffer + .update(cx, |multibuffer, cx| { + let hunk_ranges = { + let buffer = new_buffer.read(cx); + let diff = buffer_diff.read(cx); + diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .collect::>() + }; + + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(&new_buffer, cx), + new_buffer.clone(), + hunk_ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + multibuffer.add_diff(buffer_diff, cx); + }) + .log_err(); + + anyhow::Ok(()) + } + }); + + Self::Finalized(FinalizedDiff { + multibuffer, + path, + _update_diff: task, + }) + } + + pub fn new(buffer: Entity, cx: &mut Context) -> Self { + let buffer_snapshot = buffer.read(cx).snapshot(); + let base_text = buffer_snapshot.text(); + let language_registry = buffer.read(cx).language_registry(); + let text_snapshot = buffer.read(cx).text_snapshot(); + let buffer_diff = cx.new(|cx| { + let mut diff = BufferDiff::new(&text_snapshot, cx); + let _ = diff.set_base_text( + buffer_snapshot.clone(), + language_registry, + text_snapshot, + cx, + ); + diff + }); + + let multibuffer = cx.new(|cx| { + let mut multibuffer = MultiBuffer::without_headers(Capability::ReadOnly); + multibuffer.add_diff(buffer_diff.clone(), cx); + multibuffer + }); + + Self::Pending(PendingDiff { + multibuffer, + base_text: Arc::new(base_text), + _subscription: cx.observe(&buffer, |this, _, cx| { + if let Diff::Pending(diff) = this { + diff.update(cx); + } + }), + buffer, + diff: buffer_diff, + revealed_ranges: Vec::new(), + update_diff: Task::ready(Ok(())), + }) + } + + pub fn reveal_range(&mut self, range: Range, cx: &mut Context) { + if let Self::Pending(diff) = self { + diff.reveal_range(range, cx); + } + } + + pub fn finalize(&mut self, cx: &mut Context) { + if let Self::Pending(diff) = self { + *self = Self::Finalized(diff.finalize(cx)); + } + } + + pub fn multibuffer(&self) -> &Entity { + match self { + Self::Pending(PendingDiff { multibuffer, .. }) => multibuffer, + Self::Finalized(FinalizedDiff { multibuffer, .. }) => multibuffer, + } + } + + pub fn to_markdown(&self, cx: &App) -> String { + let buffer_text = self + .multibuffer() + .read(cx) + .all_buffers() + .iter() + .map(|buffer| buffer.read(cx).text()) + .join("\n"); + let path = match self { + Diff::Pending(PendingDiff { buffer, .. }) => { + buffer.read(cx).file().map(|file| file.path().as_ref()) + } + Diff::Finalized(FinalizedDiff { path, .. }) => Some(path.as_path()), + }; + format!( + "Diff: {}\n```\n{}\n```\n", + path.unwrap_or(Path::new("untitled")).display(), + buffer_text + ) + } +} + +pub struct PendingDiff { + multibuffer: Entity, + base_text: Arc, + buffer: Entity, + diff: Entity, + revealed_ranges: Vec>, + _subscription: Subscription, + update_diff: Task>, +} + +impl PendingDiff { + pub fn update(&mut self, cx: &mut Context) { + let buffer = self.buffer.clone(); + let buffer_diff = self.diff.clone(); + let base_text = self.base_text.clone(); + self.update_diff = cx.spawn(async move |diff, cx| { + let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?; + let diff_snapshot = BufferDiff::update_diff( + buffer_diff.clone(), + text_snapshot.clone(), + Some(base_text), + false, + false, + None, + None, + cx, + ) + .await?; + buffer_diff.update(cx, |diff, cx| { + diff.set_snapshot(diff_snapshot, &text_snapshot, cx) + })?; + diff.update(cx, |diff, cx| { + if let Diff::Pending(diff) = diff { + diff.update_visible_ranges(cx); + } + }) + }); + } + + pub fn reveal_range(&mut self, range: Range, cx: &mut Context) { + self.revealed_ranges.push(range); + self.update_visible_ranges(cx); + } + + fn finalize(&self, cx: &mut Context) -> FinalizedDiff { + let ranges = self.excerpt_ranges(cx); + let base_text = self.base_text.clone(); + let language_registry = self.buffer.read(cx).language_registry().clone(); + + let path = self + .buffer + .read(cx) + .file() + .map(|file| file.path().as_ref()) + .unwrap_or(Path::new("untitled")) + .into(); + + // Replace the buffer in the multibuffer with the snapshot + let buffer = cx.new(|cx| { + let language = self.buffer.read(cx).language().cloned(); + let buffer = TextBuffer::new_normalized( + 0, + cx.entity_id().as_non_zero_u64().into(), + self.buffer.read(cx).line_ending(), + self.buffer.read(cx).as_rope().clone(), + ); + let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite); + buffer.set_language(language, cx); + buffer + }); + + let buffer_diff = cx.spawn({ + let buffer = buffer.clone(); + let language_registry = language_registry.clone(); + async move |_this, cx| { + build_buffer_diff(base_text, &buffer, language_registry, cx).await + } + }); + + let update_diff = cx.spawn(async move |this, cx| { + let buffer_diff = buffer_diff.await?; + this.update(cx, |this, cx| { + this.multibuffer().update(cx, |multibuffer, cx| { + let path_key = PathKey::for_buffer(&buffer, cx); + multibuffer.clear(cx); + multibuffer.set_excerpts_for_path( + path_key, + buffer, + ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + multibuffer.add_diff(buffer_diff.clone(), cx); + }); + + cx.notify(); + }) + }); + + FinalizedDiff { + path, + multibuffer: self.multibuffer.clone(), + _update_diff: update_diff, + } + } + + fn update_visible_ranges(&mut self, cx: &mut Context) { + let ranges = self.excerpt_ranges(cx); + self.multibuffer.update(cx, |multibuffer, cx| { + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(&self.buffer, cx), + self.buffer.clone(), + ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + let end = multibuffer.len(cx); + Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1) + }); + cx.notify(); + } + + fn excerpt_ranges(&self, cx: &App) -> Vec> { + let buffer = self.buffer.read(cx); + let diff = self.diff.read(cx); + let mut ranges = diff + .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .collect::>(); + ranges.extend( + self.revealed_ranges + .iter() + .map(|range| range.to_point(&buffer)), + ); + ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end))); + + // Merge adjacent ranges + let mut ranges = ranges.into_iter().peekable(); + let mut merged_ranges = Vec::new(); + while let Some(mut range) = ranges.next() { + while let Some(next_range) = ranges.peek() { + if range.end >= next_range.start { + range.end = range.end.max(next_range.end); + ranges.next(); + } else { + break; + } + } + + merged_ranges.push(range); + } + merged_ranges + } +} + +pub struct FinalizedDiff { + path: PathBuf, + multibuffer: Entity, + _update_diff: Task>, +} + +async fn build_buffer_diff( + old_text: Arc, + buffer: &Entity, + language_registry: Option>, + cx: &mut AsyncApp, +) -> Result> { + let buffer = cx.update(|cx| buffer.read(cx).snapshot())?; + + let old_text_rope = cx + .background_spawn({ + let old_text = old_text.clone(); + async move { Rope::from(old_text.as_str()) } + }) + .await; + let base_buffer = cx + .update(|cx| { + Buffer::build_snapshot( + old_text_rope, + buffer.language().cloned(), + language_registry, + cx, + ) + })? + .await; + + let diff_snapshot = cx + .update(|cx| { + BufferDiffSnapshot::new_with_base_buffer( + buffer.text.clone(), + Some(old_text), + base_buffer, + cx, + ) + })? + .await; + + let secondary_diff = cx.new(|cx| { + let mut diff = BufferDiff::new(&buffer, cx); + diff.set_snapshot(diff_snapshot.clone(), &buffer, cx); + diff + })?; + + cx.new(|cx| { + let mut diff = BufferDiff::new(&buffer.text, cx); + diff.set_snapshot(diff_snapshot, &buffer, cx); + diff.set_secondary_diff(secondary_diff); + diff + }) +} diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index a75011a671..3e19895a31 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -15,8 +15,10 @@ workspace = true acp_thread.workspace = true agent-client-protocol.workspace = true agent_servers.workspace = true +agent_settings.workspace = true anyhow.workspace = true assistant_tool.workspace = true +assistant_tools.workspace = true cloud_llm_client.workspace = true collections.workspace = true fs.workspace = true @@ -29,6 +31,7 @@ language.workspace = true language_model.workspace = true language_models.workspace = true log.workspace = true +paths.workspace = true project.workspace = true prompt_store.workspace = true rust-embed.workspace = true @@ -53,6 +56,7 @@ gpui = { workspace = true, "features" = ["test-support"] } gpui_tokio.workspace = true language = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] } +lsp = { workspace = true, "features" = ["test-support"] } project = { workspace = true, "features" = ["test-support"] } reqwest_client.workspace = true settings = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 2014d86fb7..e7920e7891 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -1,5 +1,5 @@ use crate::{templates::Templates, AgentResponseEvent, Thread}; -use crate::{FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization}; +use crate::{EditFileTool, FindPathTool, ReadFileTool, ThinkingTool, ToolCallAuthorization}; use acp_thread::ModelSelector; use agent_client_protocol as acp; use anyhow::{anyhow, Context as _, Result}; @@ -412,11 +412,12 @@ impl acp_thread::AgentConnection for NativeAgentConnection { anyhow!("No default model configured. Please configure a default model in settings.") })?; - let thread = cx.new(|_| { + let thread = cx.new(|cx| { let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model); thread.add_tool(ThinkingTool); thread.add_tool(FindPathTool::new(project.clone())); thread.add_tool(ReadFileTool::new(project.clone(), action_log)); + thread.add_tool(EditFileTool::new(cx.entity())); thread }); @@ -564,6 +565,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection { ) })??; } + AgentResponseEvent::ToolCallDiff(tool_call_diff) => { + acp_thread.update(cx, |thread, cx| { + thread.set_tool_call_diff( + &tool_call_diff.tool_call_id, + tool_call_diff.diff, + cx, + ) + })??; + } AgentResponseEvent::Stop(stop_reason) => { log::debug!("Assistant message complete: {:?}", stop_reason); return Ok(acp::PromptResponse { stop_reason }); diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index db743c8429..f13cd1bd67 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -9,5 +9,6 @@ mod tests; pub use agent::*; pub use native_agent_server::NativeAgentServer; +pub use templates::*; pub use thread::*; pub use tools::*; diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 7913f9a24c..b70f54ac0a 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,5 +1,4 @@ use super::*; -use crate::templates::Templates; use acp_thread::AgentConnection; use agent_client_protocol::{self as acp}; use anyhow::Result; @@ -273,7 +272,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { tool_name: ToolRequiringPermission.name().into(), is_error: false, content: "Allowed".into(), - output: None + output: Some("Allowed".into()) }), MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(), diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent2/src/tests/test_tools.rs index fd6e7e941f..d22ff6ace8 100644 --- a/crates/agent2/src/tests/test_tools.rs +++ b/crates/agent2/src/tests/test_tools.rs @@ -14,6 +14,7 @@ pub struct EchoTool; impl AgentTool for EchoTool { type Input = EchoToolInput; + type Output = String; fn name(&self) -> SharedString { "echo".into() @@ -48,6 +49,7 @@ pub struct DelayTool; impl AgentTool for DelayTool { type Input = DelayToolInput; + type Output = String; fn name(&self) -> SharedString { "delay".into() @@ -84,6 +86,7 @@ pub struct ToolRequiringPermission; impl AgentTool for ToolRequiringPermission { type Input = ToolRequiringPermissionInput; + type Output = String; fn name(&self) -> SharedString { "tool_requiring_permission".into() @@ -99,14 +102,11 @@ impl AgentTool for ToolRequiringPermission { fn run( self: Arc, - input: Self::Input, + _input: Self::Input, event_stream: ToolCallEventStream, cx: &mut App, - ) -> Task> - where - Self: Sized, - { - let auth_check = self.authorize(input, event_stream); + ) -> Task> { + let auth_check = event_stream.authorize("Authorize?".into()); cx.foreground_executor().spawn(async move { auth_check.await?; Ok("Allowed".to_string()) @@ -121,6 +121,7 @@ pub struct InfiniteTool; impl AgentTool for InfiniteTool { type Input = InfiniteToolInput; + type Output = String; fn name(&self) -> SharedString { "infinite".into() @@ -171,19 +172,20 @@ pub struct WordListTool; impl AgentTool for WordListTool { type Input = WordListInput; + type Output = String; fn name(&self) -> SharedString { "word_list".into() } - fn initial_title(&self, _input: Self::Input) -> SharedString { - "List of random words".into() - } - fn kind(&self) -> acp::ToolKind { acp::ToolKind::Other } + fn initial_title(&self, _input: Self::Input) -> SharedString { + "List of random words".into() + } + fn run( self: Arc, _input: Self::Input, diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 4b8a65655f..98f2d0651d 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,4 +1,5 @@ -use crate::templates::{SystemPromptTemplate, Template, Templates}; +use crate::{SystemPromptTemplate, Template, Templates}; +use acp_thread::Diff; use agent_client_protocol as acp; use anyhow::{anyhow, Context as _, Result}; use assistant_tool::{adapt_schema_to_format, ActionLog}; @@ -103,6 +104,7 @@ pub enum AgentResponseEvent { ToolCall(acp::ToolCall), ToolCallUpdate(acp::ToolCallUpdate), ToolCallAuthorization(ToolCallAuthorization), + ToolCallDiff(ToolCallDiff), Stop(acp::StopReason), } @@ -113,6 +115,12 @@ pub struct ToolCallAuthorization { pub response: oneshot::Sender, } +#[derive(Debug)] +pub struct ToolCallDiff { + pub tool_call_id: acp::ToolCallId, + pub diff: Entity, +} + pub struct Thread { messages: Vec, completion_mode: CompletionMode, @@ -125,12 +133,13 @@ pub struct Thread { project_context: Rc>, templates: Arc, pub selected_model: Arc, + project: Entity, action_log: Entity, } impl Thread { pub fn new( - _project: Entity, + project: Entity, project_context: Rc>, action_log: Entity, templates: Arc, @@ -145,10 +154,19 @@ impl Thread { project_context, templates, selected_model: default_model, + project, action_log, } } + pub fn project(&self) -> &Entity { + &self.project + } + + pub fn action_log(&self) -> &Entity { + &self.action_log + } + pub fn set_mode(&mut self, mode: CompletionMode) { self.completion_mode = mode; } @@ -315,10 +333,6 @@ impl Thread { events_rx } - pub fn action_log(&self) -> &Entity { - &self.action_log - } - pub fn build_system_message(&self) -> AgentMessage { log::debug!("Building system message"); let prompt = SystemPromptTemplate { @@ -490,15 +504,33 @@ impl Thread { })); }; - let tool_result = self.run_tool(tool, tool_use.clone(), event_stream.clone(), cx); + let tool_event_stream = + ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone()); + tool_event_stream.send_update(acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress), + ..Default::default() + }); + let supports_images = self.selected_model.supports_images(); + let tool_result = tool.run(tool_use.input, tool_event_stream, cx); Some(cx.foreground_executor().spawn(async move { - match tool_result.await { - Ok(tool_output) => LanguageModelToolResult { + let tool_result = tool_result.await.and_then(|output| { + if let LanguageModelToolResultContent::Image(_) = &output.llm_output { + if !supports_images { + return Err(anyhow!( + "Attempted to read an image, but this model doesn't support it.", + )); + } + } + Ok(output) + }); + + match tool_result { + Ok(output) => LanguageModelToolResult { tool_use_id: tool_use.id, tool_name: tool_use.name, is_error: false, - content: LanguageModelToolResultContent::Text(Arc::from(tool_output)), - output: None, + content: output.llm_output, + output: Some(output.raw_output), }, Err(error) => LanguageModelToolResult { tool_use_id: tool_use.id, @@ -511,24 +543,6 @@ impl Thread { })) } - fn run_tool( - &self, - tool: Arc, - tool_use: LanguageModelToolUse, - event_stream: AgentResponseEventStream, - cx: &mut Context, - ) -> Task> { - cx.spawn(async move |_this, cx| { - let tool_event_stream = ToolCallEventStream::new(tool_use.id, event_stream); - tool_event_stream.send_update(acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::InProgress), - ..Default::default() - }); - cx.update(|cx| tool.run(tool_use.input, tool_event_stream, cx))? - .await - }) - } - fn handle_tool_use_json_parse_error_event( &mut self, tool_use_id: LanguageModelToolUseId, @@ -572,7 +586,7 @@ impl Thread { self.messages.last_mut().unwrap() } - fn build_completion_request( + pub(crate) fn build_completion_request( &self, completion_intent: CompletionIntent, cx: &mut App, @@ -662,6 +676,7 @@ where Self: 'static + Sized, { type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; + type Output: for<'de> Deserialize<'de> + Serialize + Into; fn name(&self) -> SharedString; @@ -685,23 +700,13 @@ where schemars::schema_for!(Self::Input) } - /// Allows the tool to authorize a given tool call with the user if necessary - fn authorize( - &self, - input: Self::Input, - event_stream: ToolCallEventStream, - ) -> impl use + Future> { - let json_input = serde_json::json!(&input); - event_stream.authorize(self.initial_title(input).into(), self.kind(), json_input) - } - /// Runs the tool with the provided input. fn run( self: Arc, input: Self::Input, event_stream: ToolCallEventStream, cx: &mut App, - ) -> Task>; + ) -> Task>; fn erase(self) -> Arc { Arc::new(Erased(Arc::new(self))) @@ -710,6 +715,11 @@ where pub struct Erased(T); +pub struct AgentToolOutput { + llm_output: LanguageModelToolResultContent, + raw_output: serde_json::Value, +} + pub trait AnyAgentTool { fn name(&self) -> SharedString; fn description(&self, cx: &mut App) -> SharedString; @@ -721,7 +731,7 @@ pub trait AnyAgentTool { input: serde_json::Value, event_stream: ToolCallEventStream, cx: &mut App, - ) -> Task>; + ) -> Task>; } impl AnyAgentTool for Erased> @@ -756,12 +766,18 @@ where input: serde_json::Value, event_stream: ToolCallEventStream, cx: &mut App, - ) -> Task> { - let parsed_input: Result = serde_json::from_value(input).map_err(Into::into); - match parsed_input { - Ok(input) => self.0.clone().run(input, event_stream, cx), - Err(error) => Task::ready(Err(anyhow!(error))), - } + ) -> Task> { + cx.spawn(async move |cx| { + let input = serde_json::from_value(input)?; + let output = cx + .update(|cx| self.0.clone().run(input, event_stream, cx))? + .await?; + let raw_output = serde_json::to_value(&output)?; + Ok(AgentToolOutput { + llm_output: output.into(), + raw_output, + }) + }) } } @@ -874,6 +890,12 @@ impl AgentResponseEventStream { .ok(); } + fn send_tool_call_diff(&self, tool_call_diff: ToolCallDiff) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::ToolCallDiff(tool_call_diff))) + .ok(); + } + fn send_stop(&self, reason: StopReason) { match reason { StopReason::EndTurn => { @@ -903,13 +925,41 @@ impl AgentResponseEventStream { #[derive(Clone)] pub struct ToolCallEventStream { tool_use_id: LanguageModelToolUseId, + kind: acp::ToolKind, + input: serde_json::Value, stream: AgentResponseEventStream, } impl ToolCallEventStream { - fn new(tool_use_id: LanguageModelToolUseId, stream: AgentResponseEventStream) -> Self { + #[cfg(test)] + pub fn test() -> (Self, ToolCallEventStreamReceiver) { + let (events_tx, events_rx) = + mpsc::unbounded::>(); + + let stream = ToolCallEventStream::new( + &LanguageModelToolUse { + id: "test_id".into(), + name: "test_tool".into(), + raw_input: String::new(), + input: serde_json::Value::Null, + is_input_complete: true, + }, + acp::ToolKind::Other, + AgentResponseEventStream(events_tx), + ); + + (stream, ToolCallEventStreamReceiver(events_rx)) + } + + fn new( + tool_use: &LanguageModelToolUse, + kind: acp::ToolKind, + stream: AgentResponseEventStream, + ) -> Self { Self { - tool_use_id, + tool_use_id: tool_use.id.clone(), + kind, + input: tool_use.input.clone(), stream, } } @@ -918,38 +968,52 @@ impl ToolCallEventStream { self.stream.send_tool_call_update(&self.tool_use_id, fields); } - pub fn authorize( - &self, - title: String, - kind: acp::ToolKind, - input: serde_json::Value, - ) -> impl use<> + Future> { - self.stream - .authorize_tool_call(&self.tool_use_id, title, kind, input) + pub fn send_diff(&self, diff: Entity) { + self.stream.send_tool_call_diff(ToolCallDiff { + tool_call_id: acp::ToolCallId(self.tool_use_id.to_string().into()), + diff, + }); + } + + pub fn authorize(&self, title: String) -> impl use<> + Future> { + self.stream.authorize_tool_call( + &self.tool_use_id, + title, + self.kind.clone(), + self.input.clone(), + ) } } #[cfg(test)] -pub struct TestToolCallEventStream { - stream: ToolCallEventStream, - _events_rx: mpsc::UnboundedReceiver>, -} +pub struct ToolCallEventStreamReceiver( + mpsc::UnboundedReceiver>, +); #[cfg(test)] -impl TestToolCallEventStream { - pub fn new() -> Self { - let (events_tx, events_rx) = - mpsc::unbounded::>(); - - let stream = ToolCallEventStream::new("test".into(), AgentResponseEventStream(events_tx)); - - Self { - stream, - _events_rx: events_rx, +impl ToolCallEventStreamReceiver { + pub async fn expect_tool_authorization(&mut self) -> ToolCallAuthorization { + let event = self.0.next().await; + if let Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth))) = event { + auth + } else { + panic!("Expected ToolCallAuthorization but got: {:?}", event); } } +} - pub fn stream(&self) -> ToolCallEventStream { - self.stream.clone() +#[cfg(test)] +impl std::ops::Deref for ToolCallEventStreamReceiver { + type Target = mpsc::UnboundedReceiver>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(test)] +impl std::ops::DerefMut for ToolCallEventStreamReceiver { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs index 240614c263..5fe13db854 100644 --- a/crates/agent2/src/tools.rs +++ b/crates/agent2/src/tools.rs @@ -1,7 +1,9 @@ +mod edit_file_tool; mod find_path_tool; mod read_file_tool; mod thinking_tool; +pub use edit_file_tool::*; pub use find_path_tool::*; pub use read_file_tool::*; pub use thinking_tool::*; diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs new file mode 100644 index 0000000000..0dbe0be217 --- /dev/null +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -0,0 +1,1361 @@ +use acp_thread::Diff; +use agent_client_protocol as acp; +use anyhow::{anyhow, Context as _, Result}; +use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; +use cloud_llm_client::CompletionIntent; +use collections::HashSet; +use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use indoc::formatdoc; +use language::language_settings::{self, FormatOnSave}; +use language_model::LanguageModelToolResultContent; +use paths; +use project::lsp_store::{FormatTrigger, LspFormatTarget}; +use project::{Project, ProjectPath}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::Settings; +use smol::stream::StreamExt as _; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use ui::SharedString; +use util::ResultExt; + +use crate::{AgentTool, Thread, ToolCallEventStream}; + +/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. +/// +/// Before using this tool: +/// +/// 1. Use the `read_file` tool to understand the file's contents and context +/// +/// 2. Verify the directory path is correct (only applicable when creating new files): +/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct EditFileToolInput { + /// A one-line, user-friendly markdown description of the edit. This will be + /// shown in the UI and also passed to another model to perform the edit. + /// + /// Be terse, but also descriptive in what you want to achieve with this + /// edit. Avoid generic instructions. + /// + /// NEVER mention the file path in this description. + /// + /// Fix API endpoint URLs + /// Update copyright year in `page_footer` + /// + /// Make sure to include this field before all the others in the input object + /// so that we can display it immediately. + pub display_description: String, + + /// The full path of the file to create or modify in the project. + /// + /// WARNING: When specifying which file path need changing, you MUST + /// start each path with one of the project's root directories. + /// + /// The following examples assume we have two root directories in the project: + /// - /a/b/backend + /// - /c/d/frontend + /// + /// + /// `backend/src/main.rs` + /// + /// Notice how the file path starts with `backend`. Without that, the path + /// would be ambiguous and the call would fail! + /// + /// + /// + /// `frontend/db.js` + /// + pub path: PathBuf, + + /// The mode of operation on the file. Possible values: + /// - 'edit': Make granular edits to an existing file. + /// - 'create': Create a new file if it doesn't exist. + /// - 'overwrite': Replace the entire contents of an existing file. + /// + /// When a file already exists or you just created it, prefer editing + /// it as opposed to recreating it from scratch. + pub mode: EditFileMode, +} + +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum EditFileMode { + Edit, + Create, + Overwrite, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct EditFileToolOutput { + input_path: PathBuf, + project_path: PathBuf, + new_text: String, + old_text: Arc, + diff: String, + edit_agent_output: EditAgentOutput, +} + +impl From for LanguageModelToolResultContent { + fn from(output: EditFileToolOutput) -> Self { + if output.diff.is_empty() { + "No edits were made.".into() + } else { + format!( + "Edited {}:\n\n```diff\n{}\n```", + output.input_path.display(), + output.diff + ) + .into() + } + } +} + +pub struct EditFileTool { + thread: Entity, +} + +impl EditFileTool { + pub fn new(thread: Entity) -> Self { + Self { thread } + } + + fn authorize( + &self, + input: &EditFileToolInput, + event_stream: &ToolCallEventStream, + cx: &App, + ) -> Task> { + if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { + return Task::ready(Ok(())); + } + + // If any path component matches the local settings folder, then this could affect + // the editor in ways beyond the project source, so prompt. + let local_settings_folder = paths::local_settings_folder_relative_path(); + let path = Path::new(&input.path); + if path + .components() + .any(|component| component.as_os_str() == local_settings_folder.as_os_str()) + { + return cx.foreground_executor().spawn( + event_stream.authorize(format!("{} (local settings)", input.display_description)), + ); + } + + // It's also possible that the global config dir is configured to be inside the project, + // so check for that edge case too. + if let Ok(canonical_path) = std::fs::canonicalize(&input.path) { + if canonical_path.starts_with(paths::config_dir()) { + return cx.foreground_executor().spawn( + event_stream + .authorize(format!("{} (global settings)", input.display_description)), + ); + } + } + + // Check if path is inside the global config directory + // First check if it's already inside project - if not, try to canonicalize + let thread = self.thread.read(cx); + let project_path = thread.project().read(cx).find_project_path(&input.path, cx); + + // If the path is inside the project, and it's not one of the above edge cases, + // then no confirmation is necessary. Otherwise, confirmation is necessary. + if project_path.is_some() { + Task::ready(Ok(())) + } else { + cx.foreground_executor() + .spawn(event_stream.authorize(input.display_description.clone())) + } + } +} + +impl AgentTool for EditFileTool { + type Input = EditFileToolInput; + type Output = EditFileToolOutput; + + fn name(&self) -> SharedString { + "edit_file".into() + } + + fn kind(&self) -> acp::ToolKind { + acp::ToolKind::Edit + } + + fn initial_title(&self, input: Self::Input) -> SharedString { + input.display_description.into() + } + + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + let project = self.thread.read(cx).project().clone(); + let project_path = match resolve_path(&input, project.clone(), cx) { + Ok(path) => path, + Err(err) => return Task::ready(Err(anyhow!(err))), + }; + + let request = self.thread.update(cx, |thread, cx| { + thread.build_completion_request(CompletionIntent::ToolResults, cx) + }); + let thread = self.thread.read(cx); + let model = thread.selected_model.clone(); + let action_log = thread.action_log().clone(); + + let authorize = self.authorize(&input, &event_stream, cx); + cx.spawn(async move |cx: &mut AsyncApp| { + authorize.await?; + + let edit_format = EditFormat::from_model(model.clone())?; + let edit_agent = EditAgent::new( + model, + project.clone(), + action_log.clone(), + // TODO: move edit agent to this crate so we can use our templates + assistant_tools::templates::Templates::new(), + edit_format, + ); + + let buffer = project + .update(cx, |project, cx| { + project.open_buffer(project_path.clone(), cx) + })? + .await?; + + let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?; + event_stream.send_diff(diff.clone()); + + let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let old_text = cx + .background_spawn({ + let old_snapshot = old_snapshot.clone(); + async move { Arc::new(old_snapshot.text()) } + }) + .await; + + + let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) { + edit_agent.edit( + buffer.clone(), + input.display_description.clone(), + &request, + cx, + ) + } else { + edit_agent.overwrite( + buffer.clone(), + input.display_description.clone(), + &request, + cx, + ) + }; + + let mut hallucinated_old_text = false; + let mut ambiguous_ranges = Vec::new(); + while let Some(event) = events.next().await { + match event { + EditAgentOutputEvent::Edited => {}, + EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true, + EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges, + EditAgentOutputEvent::ResolvingEditRange(range) => { + diff.update(cx, |card, cx| card.reveal_range(range, cx))?; + } + } + } + + // If format_on_save is enabled, format the buffer + let format_on_save_enabled = buffer + .read_with(cx, |buffer, cx| { + let settings = language_settings::language_settings( + buffer.language().map(|l| l.name()), + buffer.file(), + cx, + ); + settings.format_on_save != FormatOnSave::Off + }) + .unwrap_or(false); + + let edit_agent_output = output.await?; + + if format_on_save_enabled { + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + })?; + + let format_task = project.update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, // Don't push to history since the tool did it. + FormatTrigger::Save, + cx, + ) + })?; + format_task.await.log_err(); + } + + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? + .await?; + + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + })?; + + let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let (new_text, unified_diff) = cx + .background_spawn({ + let new_snapshot = new_snapshot.clone(); + let old_text = old_text.clone(); + async move { + let new_text = new_snapshot.text(); + let diff = language::unified_diff(&old_text, &new_text); + (new_text, diff) + } + }) + .await; + + diff.update(cx, |diff, cx| diff.finalize(cx)).ok(); + + let input_path = input.path.display(); + if unified_diff.is_empty() { + anyhow::ensure!( + !hallucinated_old_text, + formatdoc! {" + Some edits were produced but none of them could be applied. + Read the relevant sections of {input_path} again so that + I can perform the requested edits. + "} + ); + anyhow::ensure!( + ambiguous_ranges.is_empty(), + { + let line_numbers = ambiguous_ranges + .iter() + .map(|range| range.start.to_string()) + .collect::>() + .join(", "); + formatdoc! {" + matches more than one position in the file (lines: {line_numbers}). Read the + relevant sections of {input_path} again and extend so + that I can perform the requested edits. + "} + } + ); + } + + Ok(EditFileToolOutput { + input_path: input.path, + project_path: project_path.path.to_path_buf(), + new_text: new_text.clone(), + old_text, + diff: unified_diff, + edit_agent_output, + }) + }) + } +} + +/// Validate that the file path is valid, meaning: +/// +/// - For `edit` and `overwrite`, the path must point to an existing file. +/// - For `create`, the file must not already exist, but it's parent dir must exist. +fn resolve_path( + input: &EditFileToolInput, + project: Entity, + cx: &mut App, +) -> Result { + let project = project.read(cx); + + match input.mode { + EditFileMode::Edit | EditFileMode::Overwrite => { + let path = project + .find_project_path(&input.path, cx) + .context("Can't edit file: path not found")?; + + let entry = project + .entry_for_path(&path, cx) + .context("Can't edit file: path not found")?; + + anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory"); + Ok(path) + } + + EditFileMode::Create => { + if let Some(path) = project.find_project_path(&input.path, cx) { + anyhow::ensure!( + project.entry_for_path(&path, cx).is_none(), + "Can't create file: file already exists" + ); + } + + let parent_path = input + .path + .parent() + .context("Can't create file: incorrect path")?; + + let parent_project_path = project.find_project_path(&parent_path, cx); + + let parent_entry = parent_project_path + .as_ref() + .and_then(|path| project.entry_for_path(&path, cx)) + .context("Can't create file: parent directory doesn't exist")?; + + anyhow::ensure!( + parent_entry.is_dir(), + "Can't create file: parent is not a directory" + ); + + let file_name = input + .path + .file_name() + .context("Can't create file: invalid filename")?; + + let new_file_path = parent_project_path.map(|parent| ProjectPath { + path: Arc::from(parent.path.join(file_name)), + ..parent + }); + + new_file_path.context("Can't create file") + } + } +} + +#[cfg(test)] +mod tests { + use crate::Templates; + + use super::*; + use assistant_tool::ActionLog; + use client::TelemetrySettings; + use fs::Fs; + use gpui::{TestAppContext, UpdateGlobal}; + use language_model::fake_provider::FakeLanguageModel; + use serde_json::json; + use settings::SettingsStore; + use std::rc::Rc; + use util::path; + + #[gpui::test] + async fn test_edit_nonexistent_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = + cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model)); + let result = cx + .update(|cx| { + let input = EditFileToolInput { + display_description: "Some edit".into(), + path: "root/nonexistent_file.txt".into(), + mode: EditFileMode::Edit, + }; + Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + }) + .await; + assert_eq!( + result.unwrap_err().to_string(), + "Can't edit file: path not found" + ); + } + + #[gpui::test] + async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) { + let mode = &EditFileMode::Create; + + let result = test_resolve_path(mode, "root/new.txt", cx); + assert_resolved_path_eq(result.await, "new.txt"); + + let result = test_resolve_path(mode, "new.txt", cx); + assert_resolved_path_eq(result.await, "new.txt"); + + let result = test_resolve_path(mode, "dir/new.txt", cx); + assert_resolved_path_eq(result.await, "dir/new.txt"); + + let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't create file: file already exists" + ); + + let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't create file: parent directory doesn't exist" + ); + } + + #[gpui::test] + async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) { + let mode = &EditFileMode::Edit; + + let path_with_root = "root/dir/subdir/existing.txt"; + let path_without_root = "dir/subdir/existing.txt"; + let result = test_resolve_path(mode, path_with_root, cx); + assert_resolved_path_eq(result.await, path_without_root); + + let result = test_resolve_path(mode, path_without_root, cx); + assert_resolved_path_eq(result.await, path_without_root); + + let result = test_resolve_path(mode, "root/nonexistent.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't edit file: path not found" + ); + + let result = test_resolve_path(mode, "root/dir", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't edit file: path is a directory" + ); + } + + async fn test_resolve_path( + mode: &EditFileMode, + path: &str, + cx: &mut TestAppContext, + ) -> anyhow::Result { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "dir": { + "subdir": { + "existing.txt": "hello" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let input = EditFileToolInput { + display_description: "Some edit".into(), + path: path.into(), + mode: mode.clone(), + }; + + let result = cx.update(|cx| resolve_path(&input, project, cx)); + result + } + + fn assert_resolved_path_eq(path: anyhow::Result, expected: &str) { + let actual = path + .expect("Should return valid path") + .path + .to_str() + .unwrap() + .replace("\\", "/"); // Naive Windows paths normalization + assert_eq!(actual, expected); + } + + #[gpui::test] + async fn test_format_on_save(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + // Set up a Rust language with LSP formatting support + let rust_language = Arc::new(language::Language::new( + language::LanguageConfig { + name: "Rust".into(), + matcher: language::LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + // Register the language and fake LSP + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(rust_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Rust", + language::FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + document_formatting_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + // Create the file + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n"; + const FORMATTED_CONTENT: &str = + "This file was formatted by the fake formatter in the test.\n"; + + // Get the fake language server and set up formatting handler + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::({ + |_, _| async move { + Ok(Some(vec![lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)), + new_text: FORMATTED_CONTENT.to_string(), + }])) + } + }); + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + + // First, test with format_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::On); + settings.defaults.formatter = + Some(language::language_settings::SelectedFormatter::Auto); + }, + ); + }); + }); + + // Have the model stream unformatted content + let edit_result = { + let edit_task = cx.update(|cx| { + let input = EditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }; + Arc::new(EditFileTool { + thread: thread.clone(), + }) + .run(input, ToolCallEventStream::test().0, cx) + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify it was formatted automatically + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + new_content.replace("\r\n", "\n"), + FORMATTED_CONTENT, + "Code should be formatted when format_on_save is enabled" + ); + + let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count()); + + assert_eq!( + stale_buffer_count, 0, + "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \ + This causes the agent to think the file was modified externally when it was just formatted.", + stale_buffer_count + ); + + // Next, test with format_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.format_on_save = Some(FormatOnSave::Off); + }, + ); + }); + }); + + // Stream unformatted edits again + let edit_result = { + let edit_task = cx.update(|cx| { + let input = EditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }; + Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx) + }); + + // Stream the unformatted content + cx.executor().run_until_parked(); + model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Verify the file was not formatted + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + new_content.replace("\r\n", "\n"), + UNFORMATTED_CONTENT, + "Code should not be formatted when format_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + // Create a simple file with trailing whitespace + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + + // First, test with remove_trailing_whitespace_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.remove_trailing_whitespace_on_save = Some(true); + }, + ); + }); + }); + + const CONTENT_WITH_TRAILING_WHITESPACE: &str = + "fn main() { \n println!(\"Hello!\"); \n}\n"; + + // Have the model stream content that contains trailing whitespace + let edit_result = { + let edit_task = cx.update(|cx| { + let input = EditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }; + Arc::new(EditFileTool { + thread: thread.clone(), + }) + .run(input, ToolCallEventStream::test().0, cx) + }); + + // Stream the content with trailing whitespace + cx.executor().run_until_parked(); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Read the file to verify trailing whitespace was removed automatically + assert_eq!( + // Ignore carriage returns on Windows + fs.load(path!("/root/src/main.rs").as_ref()) + .await + .unwrap() + .replace("\r\n", "\n"), + "fn main() {\n println!(\"Hello!\");\n}\n", + "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled" + ); + + // Next, test with remove_trailing_whitespace_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings::( + cx, + |settings| { + settings.defaults.remove_trailing_whitespace_on_save = Some(false); + }, + ); + }); + }); + + // Stream edits again with trailing whitespace + let edit_result = { + let edit_task = cx.update(|cx| { + let input = EditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Overwrite, + }; + Arc::new(EditFileTool { + thread: thread.clone(), + }) + .run(input, ToolCallEventStream::test().0, cx) + }); + + // Stream the content with trailing whitespace + cx.executor().run_until_parked(); + model.send_last_completion_stream_text_chunk( + CONTENT_WITH_TRAILING_WHITESPACE.to_string(), + ); + model.end_last_completion_stream(); + + edit_task.await + }; + assert!(edit_result.is_ok()); + + // Wait for any async operations (e.g. formatting) to complete + cx.executor().run_until_parked(); + + // Verify the file still has trailing whitespace + // Read the file again - it should still have trailing whitespace + let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + // Ignore carriage returns on Windows + final_content.replace("\r\n", "\n"), + CONTENT_WITH_TRAILING_WHITESPACE, + "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_authorize(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + fs.insert_tree("/root", json!({})).await; + + // Test 1: Path with .zed component should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 1".into(), + path: ".zed/settings.json".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + let event = stream_rx.expect_tool_authorization().await; + assert_eq!(event.tool_call.title, "test 1 (local settings)"); + + // Test 2: Path outside project should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 2".into(), + path: "/etc/hosts".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + let event = stream_rx.expect_tool_authorization().await; + assert_eq!(event.tool_call.title, "test 2"); + + // Test 3: Relative path without .zed should not require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 3".into(), + path: "root/src/main.rs".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + + // Test 4: Path with .zed in the middle should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 4".into(), + path: "root/.zed/tasks.json".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + let event = stream_rx.expect_tool_authorization().await; + assert_eq!(event.tool_call.title, "test 4 (local settings)"); + + // Test 5: When always_allow_tool_actions is enabled, no confirmation needed + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.always_allow_tool_actions = true; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 5.1".into(), + path: ".zed/settings.json".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "test 5.2".into(), + path: "/etc/hosts".into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + } + + #[gpui::test] + async fn test_authorize_global_config(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project, + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test global config paths - these should require confirmation if they exist and are outside the project + let test_cases = vec![ + ( + "/etc/hosts", + true, + "System file should require confirmation", + ), + ( + "/usr/local/bin/script", + true, + "System bin file should require confirmation", + ), + ( + "project/normal_file.rs", + false, + "Normal project file should not require confirmation", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: path.into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + if should_confirm { + stream_rx.expect_tool_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + + // Create multiple worktree directories + fs.insert_tree( + "/workspace/frontend", + json!({ + "src": { + "main.js": "console.log('frontend');" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/backend", + json!({ + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/shared", + json!({ + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + + // Create project with multiple worktrees + let project = Project::test( + fs.clone(), + [ + path!("/workspace/frontend").as_ref(), + path!("/workspace/backend").as_ref(), + path!("/workspace/shared").as_ref(), + ], + cx, + ) + .await; + + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project.clone(), + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test files in different worktrees + let test_cases = vec![ + ("frontend/src/main.js", false, "File in first worktree"), + ("backend/src/main.rs", false, "File in second worktree"), + ( + "shared/.zed/settings.json", + true, + ".zed file in third worktree", + ), + ("/etc/hosts", true, "Absolute path outside all worktrees"), + ( + "../outside/file.txt", + true, + "Relative path outside worktrees", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: path.into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + if should_confirm { + stream_rx.expect_tool_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + + #[gpui::test] + async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".zed": { + "settings.json": "{}" + }, + "src": { + ".zed": { + "local.json": "{}" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project.clone(), + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test edge cases + let test_cases = vec![ + // Empty path - find_project_path returns Some for empty paths + ("", false, "Empty path is treated as project root"), + // Root directory + ("/", true, "Root directory should be outside project"), + // Parent directory references - find_project_path resolves these + ( + "project/../other", + false, + "Path with .. is resolved by find_project_path", + ), + ( + "project/./src/file.rs", + false, + "Path with . should work normally", + ), + // Windows-style paths (if on Windows) + #[cfg(target_os = "windows")] + ("C:\\Windows\\System32\\hosts", true, "Windows system path"), + #[cfg(target_os = "windows")] + ("project\\src\\main.rs", false, "Windows-style project path"), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: path.into(), + mode: EditFileMode::Edit, + }, + &stream_tx, + cx, + ) + }); + + if should_confirm { + stream_rx.expect_tool_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + + #[gpui::test] + async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "existing.txt": "content", + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|_| { + Thread::new( + project.clone(), + Rc::default(), + action_log.clone(), + Templates::new(), + model.clone(), + ) + }); + let tool = Arc::new(EditFileTool { thread }); + + // Test different EditFileMode values + let modes = vec![ + EditFileMode::Edit, + EditFileMode::Create, + EditFileMode::Overwrite, + ]; + + for mode in modes { + // Test .zed path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit settings".into(), + path: "project/.zed/settings.json".into(), + mode: mode.clone(), + }, + &stream_tx, + cx, + ) + }); + + stream_rx.expect_tool_authorization().await; + + // Test outside path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: "/outside/file.txt".into(), + mode: mode.clone(), + }, + &stream_tx, + cx, + ) + }); + + stream_rx.expect_tool_authorization().await; + + // Test normal path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &EditFileToolInput { + display_description: "Edit file".into(), + path: "project/normal.txt".into(), + mode: mode.clone(), + }, + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + } + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + TelemetrySettings::register(cx); + agent_settings::AgentSettings::register(cx); + Project::init_settings(cx); + }); + } +} diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent2/src/tools/find_path_tool.rs index e840fec78c..24bdcded8c 100644 --- a/crates/agent2/src/tools/find_path_tool.rs +++ b/crates/agent2/src/tools/find_path_tool.rs @@ -1,6 +1,8 @@ +use crate::{AgentTool, ToolCallEventStream}; use agent_client_protocol as acp; use anyhow::{anyhow, Result}; use gpui::{App, AppContext, Entity, SharedString, Task}; +use language_model::LanguageModelToolResultContent; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -8,8 +10,6 @@ use std::fmt::Write; use std::{cmp, path::PathBuf, sync::Arc}; use util::paths::PathMatcher; -use crate::{AgentTool, ToolCallEventStream}; - /// Fast file path pattern matching tool that works with any codebase size /// /// - Supports glob patterns like "**/*.js" or "src/**/*.ts" @@ -39,8 +39,35 @@ pub struct FindPathToolInput { } #[derive(Debug, Serialize, Deserialize)] -struct FindPathToolOutput { - paths: Vec, +pub struct FindPathToolOutput { + offset: usize, + current_matches_page: Vec, + all_matches_len: usize, +} + +impl From for LanguageModelToolResultContent { + fn from(output: FindPathToolOutput) -> Self { + if output.current_matches_page.is_empty() { + "No matches found".into() + } else { + let mut llm_output = format!("Found {} total matches.", output.all_matches_len); + if output.all_matches_len > RESULTS_PER_PAGE { + write!( + &mut llm_output, + "\nShowing results {}-{} (provide 'offset' parameter for more results):", + output.offset + 1, + output.offset + output.current_matches_page.len() + ) + .unwrap(); + } + + for mat in output.current_matches_page { + write!(&mut llm_output, "\n{}", mat.display()).unwrap(); + } + + llm_output.into() + } + } } const RESULTS_PER_PAGE: usize = 50; @@ -57,6 +84,7 @@ impl FindPathTool { impl AgentTool for FindPathTool { type Input = FindPathToolInput; + type Output = FindPathToolOutput; fn name(&self) -> SharedString { "find_path".into() @@ -75,7 +103,7 @@ impl AgentTool for FindPathTool { input: Self::Input, event_stream: ToolCallEventStream, cx: &mut App, - ) -> Task> { + ) -> Task> { let search_paths_task = search_paths(&input.glob, self.project.clone(), cx); cx.background_spawn(async move { @@ -113,26 +141,11 @@ impl AgentTool for FindPathTool { ..Default::default() }); - if matches.is_empty() { - Ok("No matches found".into()) - } else { - let mut message = format!("Found {} total matches.", matches.len()); - if matches.len() > RESULTS_PER_PAGE { - write!( - &mut message, - "\nShowing results {}-{} (provide 'offset' parameter for more results):", - input.offset + 1, - input.offset + paginated_matches.len() - ) - .unwrap(); - } - - for mat in matches.iter().skip(input.offset).take(RESULTS_PER_PAGE) { - write!(&mut message, "\n{}", mat.display()).unwrap(); - } - - Ok(message) - } + Ok(FindPathToolOutput { + offset: input.offset, + current_matches_page: paginated_matches.to_vec(), + all_matches_len: matches.len(), + }) }) } } diff --git a/crates/agent2/src/tools/read_file_tool.rs b/crates/agent2/src/tools/read_file_tool.rs index 30794ccdad..3d91e3dc74 100644 --- a/crates/agent2/src/tools/read_file_tool.rs +++ b/crates/agent2/src/tools/read_file_tool.rs @@ -1,10 +1,11 @@ use agent_client_protocol::{self as acp}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use assistant_tool::{outline, ActionLog}; use gpui::{Entity, Task}; use indoc::formatdoc; use language::{Anchor, Point}; -use project::{AgentLocation, Project, WorktreeSettings}; +use language_model::{LanguageModelImage, LanguageModelToolResultContent}; +use project::{image_store, AgentLocation, ImageItem, Project, WorktreeSettings}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; @@ -59,6 +60,7 @@ impl ReadFileTool { impl AgentTool for ReadFileTool { type Input = ReadFileToolInput; + type Output = LanguageModelToolResultContent; fn name(&self) -> SharedString { "read_file".into() @@ -91,9 +93,9 @@ impl AgentTool for ReadFileTool { fn run( self: Arc, input: Self::Input, - event_stream: ToolCallEventStream, + _event_stream: ToolCallEventStream, cx: &mut App, - ) -> Task> { + ) -> Task> { let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else { return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))); }; @@ -132,51 +134,27 @@ impl AgentTool for ReadFileTool { let file_path = input.path.clone(); - event_stream.send_update(acp::ToolCallUpdateFields { - locations: Some(vec![acp::ToolCallLocation { - path: project_path.path.to_path_buf(), - line: input.start_line, - // TODO (tracked): use full range - }]), - ..Default::default() - }); + if image_store::is_image_file(&self.project, &project_path, cx) { + return cx.spawn(async move |cx| { + let image_entity: Entity = cx + .update(|cx| { + self.project.update(cx, |project, cx| { + project.open_image(project_path.clone(), cx) + }) + })? + .await?; - // TODO (tracked): images - // if image_store::is_image_file(&self.project, &project_path, cx) { - // let model = &self.thread.read(cx).selected_model; + let image = + image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?; - // if !model.supports_images() { - // return Task::ready(Err(anyhow!( - // "Attempted to read an image, but Zed doesn't currently support sending images to {}.", - // model.name().0 - // ))) - // .into(); - // } + let language_model_image = cx + .update(|cx| LanguageModelImage::from_image(image, cx))? + .await + .context("processing image")?; - // return cx.spawn(async move |cx| -> Result { - // let image_entity: Entity = cx - // .update(|cx| { - // self.project.update(cx, |project, cx| { - // project.open_image(project_path.clone(), cx) - // }) - // })? - // .await?; - - // let image = - // image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?; - - // let language_model_image = cx - // .update(|cx| LanguageModelImage::from_image(image, cx))? - // .await - // .context("processing image")?; - - // Ok(ToolResultOutput { - // content: ToolResultContent::Image(language_model_image), - // output: None, - // }) - // }); - // } - // + Ok(language_model_image.into()) + }); + } let project = self.project.clone(); let action_log = self.action_log.clone(); @@ -244,7 +222,7 @@ impl AgentTool for ReadFileTool { })?; } - Ok(result) + Ok(result.into()) } else { // No line ranges specified, so check file size to see if it's too big. let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?; @@ -257,7 +235,7 @@ impl AgentTool for ReadFileTool { log.buffer_read(buffer, cx); })?; - Ok(result) + Ok(result.into()) } else { // File is too big, so return the outline // and a suggestion to read again with line numbers. @@ -276,7 +254,8 @@ impl AgentTool for ReadFileTool { Alternatively, you can fall back to the `grep` tool (if available) to search the file for specific content." - }) + } + .into()) } } }) @@ -285,8 +264,6 @@ impl AgentTool for ReadFileTool { #[cfg(test)] mod test { - use crate::TestToolCallEventStream; - use super::*; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher}; @@ -304,7 +281,7 @@ mod test { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); let tool = Arc::new(ReadFileTool::new(project, action_log)); - let event_stream = TestToolCallEventStream::new(); + let (event_stream, _) = ToolCallEventStream::test(); let result = cx .update(|cx| { @@ -313,7 +290,7 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, event_stream.stream(), cx) + tool.run(input, event_stream, cx) }) .await; assert_eq!( @@ -321,6 +298,7 @@ mod test { "root/nonexistent_file.txt not found" ); } + #[gpui::test] async fn test_read_small_file(cx: &mut TestAppContext) { init_test(cx); @@ -336,7 +314,6 @@ mod test { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); let tool = Arc::new(ReadFileTool::new(project, action_log)); - let event_stream = TestToolCallEventStream::new(); let result = cx .update(|cx| { let input = ReadFileToolInput { @@ -344,10 +321,10 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, event_stream.stream(), cx) + tool.run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "This is a small file content"); + assert_eq!(result.unwrap(), "This is a small file content".into()); } #[gpui::test] @@ -367,18 +344,18 @@ mod test { language_registry.add(Arc::new(rust_lang())); let action_log = cx.new(|_| ActionLog::new(project.clone())); let tool = Arc::new(ReadFileTool::new(project, action_log)); - let event_stream = TestToolCallEventStream::new(); - let content = cx + let result = cx .update(|cx| { let input = ReadFileToolInput { path: "root/large_file.rs".into(), start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await .unwrap(); + let content = result.to_str().unwrap(); assert_eq!( content.lines().skip(4).take(6).collect::>(), @@ -399,10 +376,11 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, event_stream.stream(), cx) + tool.run(input, ToolCallEventStream::test().0, cx) }) - .await; - let content = result.unwrap(); + .await + .unwrap(); + let content = result.to_str().unwrap(); let expected_content = (0..1000) .flat_map(|i| { vec![ @@ -438,7 +416,6 @@ mod test { let action_log = cx.new(|_| ActionLog::new(project.clone())); let tool = Arc::new(ReadFileTool::new(project, action_log)); - let event_stream = TestToolCallEventStream::new(); let result = cx .update(|cx| { let input = ReadFileToolInput { @@ -446,10 +423,10 @@ mod test { start_line: Some(2), end_line: Some(4), }; - tool.run(input, event_stream.stream(), cx) + tool.run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4"); + assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4".into()); } #[gpui::test] @@ -467,7 +444,6 @@ mod test { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); let tool = Arc::new(ReadFileTool::new(project, action_log)); - let event_stream = TestToolCallEventStream::new(); // start_line of 0 should be treated as 1 let result = cx @@ -477,10 +453,10 @@ mod test { start_line: Some(0), end_line: Some(2), }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "Line 1\nLine 2"); + assert_eq!(result.unwrap(), "Line 1\nLine 2".into()); // end_line of 0 should result in at least 1 line let result = cx @@ -490,10 +466,10 @@ mod test { start_line: Some(1), end_line: Some(0), }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "Line 1"); + assert_eq!(result.unwrap(), "Line 1".into()); // when start_line > end_line, should still return at least 1 line let result = cx @@ -503,10 +479,10 @@ mod test { start_line: Some(3), end_line: Some(2), }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; - assert_eq!(result.unwrap(), "Line 3"); + assert_eq!(result.unwrap(), "Line 3".into()); } fn init_test(cx: &mut TestAppContext) { @@ -612,7 +588,6 @@ mod test { let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; let action_log = cx.new(|_| ActionLog::new(project.clone())); let tool = Arc::new(ReadFileTool::new(project, action_log)); - let event_stream = TestToolCallEventStream::new(); // Reading a file outside the project worktree should fail let result = cx @@ -622,7 +597,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; assert!( @@ -638,7 +613,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; assert!( @@ -654,7 +629,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; assert!( @@ -669,7 +644,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; assert!( @@ -685,7 +660,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; assert!( @@ -700,7 +675,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; assert!( @@ -715,7 +690,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; assert!( @@ -731,11 +706,11 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; assert!(result.is_ok(), "Should be able to read normal files"); - assert_eq!(result.unwrap(), "Normal file content"); + assert_eq!(result.unwrap(), "Normal file content".into()); // Path traversal attempts with .. should fail let result = cx @@ -745,7 +720,7 @@ mod test { start_line: None, end_line: None, }; - tool.run(input, event_stream.stream(), cx) + tool.run(input, ToolCallEventStream::test().0, cx) }) .await; assert!( @@ -826,7 +801,6 @@ mod test { let action_log = cx.new(|_| ActionLog::new(project.clone())); let tool = Arc::new(ReadFileTool::new(project.clone(), action_log.clone())); - let event_stream = TestToolCallEventStream::new(); // Test reading allowed files in worktree1 let result = cx @@ -836,12 +810,15 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await .unwrap(); - assert_eq!(result, "fn main() { println!(\"Hello from worktree1\"); }"); + assert_eq!( + result, + "fn main() { println!(\"Hello from worktree1\"); }".into() + ); // Test reading private file in worktree1 should fail let result = cx @@ -851,7 +828,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; @@ -872,7 +849,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; @@ -893,14 +870,14 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await .unwrap(); assert_eq!( result, - "export function greet() { return 'Hello from worktree2'; }" + "export function greet() { return 'Hello from worktree2'; }".into() ); // Test reading private file in worktree2 should fail @@ -911,7 +888,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; @@ -932,7 +909,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; @@ -954,7 +931,7 @@ mod test { start_line: None, end_line: None, }; - tool.clone().run(input, event_stream.stream(), cx) + tool.clone().run(input, ToolCallEventStream::test().0, cx) }) .await; diff --git a/crates/agent2/src/tools/thinking_tool.rs b/crates/agent2/src/tools/thinking_tool.rs index bb85d8eceb..d85370e7e5 100644 --- a/crates/agent2/src/tools/thinking_tool.rs +++ b/crates/agent2/src/tools/thinking_tool.rs @@ -20,6 +20,7 @@ pub struct ThinkingTool; impl AgentTool for ThinkingTool { type Input = ThinkingToolInput; + type Output = String; fn name(&self) -> SharedString { "thinking".into() diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 3d1fbba45d..7f4e7e7208 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -42,7 +42,7 @@ use workspace::{CollaboratorId, Workspace}; use zed_actions::agent::{Chat, NextHistoryMessage, PreviousHistoryMessage}; use ::acp_thread::{ - AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, Diff, + AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, LoadError, MentionPath, ThreadStatus, ToolCall, ToolCallContent, ToolCallStatus, }; @@ -732,7 +732,11 @@ impl AcpThreadView { cx: &App, ) -> Option>> { let entry = self.thread()?.read(cx).entries().get(entry_ix)?; - Some(entry.diffs().map(|diff| diff.multibuffer.clone())) + Some( + entry + .diffs() + .map(|diff| diff.read(cx).multibuffer().clone()), + ) } fn authenticate( @@ -1314,10 +1318,9 @@ impl AcpThreadView { Empty.into_any_element() } } - ToolCallContent::Diff { - diff: Diff { multibuffer, .. }, - .. - } => self.render_diff_editor(multibuffer), + ToolCallContent::Diff { diff, .. } => { + self.render_diff_editor(&diff.read(cx).multibuffer()) + } } } diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 90bb2e9b7c..bf668e6918 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -2,7 +2,7 @@ mod copy_path_tool; mod create_directory_tool; mod delete_path_tool; mod diagnostics_tool; -mod edit_agent; +pub mod edit_agent; mod edit_file_tool; mod fetch_tool; mod find_path_tool; @@ -14,7 +14,7 @@ mod open_tool; mod project_notifications_tool; mod read_file_tool; mod schema; -mod templates; +pub mod templates; mod terminal_tool; mod thinking_tool; mod ui; diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index 715d106a26..dcb14a48f3 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -29,7 +29,6 @@ use serde::{Deserialize, Serialize}; use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll}; use streaming_diff::{CharOperation, StreamingDiff}; use streaming_fuzzy_matcher::StreamingFuzzyMatcher; -use util::debug_panic; #[derive(Serialize)] struct CreateFilePromptTemplate { @@ -682,11 +681,6 @@ impl EditAgent { if last_message.content.is_empty() { conversation.messages.pop(); } - } else { - debug_panic!( - "Last message must be an Assistant tool calling! Got {:?}", - last_message.content - ); } } diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index dce9f49abd..311521019d 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -120,8 +120,6 @@ struct PartialInput { display_description: String, } -const DEFAULT_UI_TEXT: &str = "Editing file"; - impl Tool for EditFileTool { fn name(&self) -> String { "edit_file".into() @@ -211,22 +209,6 @@ impl Tool for EditFileTool { } } - fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { - if let Some(input) = serde_json::from_value::(input.clone()).ok() { - let description = input.display_description.trim(); - if !description.is_empty() { - return description.to_string(); - } - - let path = input.path.trim(); - if !path.is_empty() { - return path.to_string(); - } - } - - DEFAULT_UI_TEXT.to_string() - } - fn run( self: Arc, input: serde_json::Value, @@ -1370,73 +1352,6 @@ mod tests { assert_eq!(actual, expected); } - #[test] - fn still_streaming_ui_text_with_path() { - let input = json!({ - "path": "src/main.rs", - "display_description": "", - "old_string": "old code", - "new_string": "new code" - }); - - assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs"); - } - - #[test] - fn still_streaming_ui_text_with_description() { - let input = json!({ - "path": "", - "display_description": "Fix error handling", - "old_string": "old code", - "new_string": "new code" - }); - - assert_eq!( - EditFileTool.still_streaming_ui_text(&input), - "Fix error handling", - ); - } - - #[test] - fn still_streaming_ui_text_with_path_and_description() { - let input = json!({ - "path": "src/main.rs", - "display_description": "Fix error handling", - "old_string": "old code", - "new_string": "new code" - }); - - assert_eq!( - EditFileTool.still_streaming_ui_text(&input), - "Fix error handling", - ); - } - - #[test] - fn still_streaming_ui_text_no_path_or_description() { - let input = json!({ - "path": "", - "display_description": "", - "old_string": "old code", - "new_string": "new code" - }); - - assert_eq!( - EditFileTool.still_streaming_ui_text(&input), - DEFAULT_UI_TEXT, - ); - } - - #[test] - fn still_streaming_ui_text_with_null() { - let input = serde_json::Value::Null; - - assert_eq!( - EditFileTool.still_streaming_ui_text(&input), - DEFAULT_UI_TEXT, - ); - } - fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index dc485e9937..edce3d03b7 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -297,6 +297,12 @@ impl From for LanguageModelToolResultContent { } } +impl From for LanguageModelToolResultContent { + fn from(image: LanguageModelImage) -> Self { + Self::Image(image) + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub enum MessageContent { Text(String),