diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 793ef35be2..2be7ea7a12 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -3,9 +3,12 @@ mod diff; mod mention; mod terminal; +use collections::HashSet; pub use connection::*; pub use diff::*; +use language::language_settings::FormatOnSave; pub use mention::*; +use project::lsp_store::{FormatTrigger, LspFormatTarget}; use serde::{Deserialize, Serialize}; pub use terminal::*; @@ -1051,6 +1054,22 @@ impl AcpThread { }) } + pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> { + self.entries + .iter() + .enumerate() + .rev() + .find_map(|(index, tool_call)| { + if let AgentThreadEntry::ToolCall(tool_call) = tool_call + && &tool_call.id == id + { + Some((index, tool_call)) + } else { + None + } + }) + } + pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context) { let project = self.project.clone(); let Some((_, tool_call)) = self.tool_call_mut(&id) else { @@ -1601,30 +1620,59 @@ impl AcpThread { .collect::>() }) .await; - cx.update(|cx| { - project.update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: edits - .last() - .map(|(range, _)| range.end) - .unwrap_or(Anchor::MIN), - }), - cx, - ); - }); + project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: edits + .last() + .map(|(range, _)| range.end) + .unwrap_or(Anchor::MIN), + }), + cx, + ); + })?; + + let format_on_save = cx.update(|cx| { action_log.update(cx, |action_log, cx| { action_log.buffer_read(buffer.clone(), cx); }); - buffer.update(cx, |buffer, cx| { + + let format_on_save = buffer.update(cx, |buffer, cx| { buffer.edit(edits, None, cx); + + let settings = language::language_settings::language_settings( + buffer.language().map(|l| l.name()), + buffer.file(), + cx, + ); + + settings.format_on_save != FormatOnSave::Off }); action_log.update(cx, |action_log, cx| { action_log.buffer_edited(buffer.clone(), cx); }); + format_on_save })?; + + if format_on_save { + let format_task = project.update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, + FormatTrigger::Save, + cx, + ) + })?; + format_task.await.log_err(); + + action_log.update(cx, |action_log, cx| { + action_log.buffer_edited(buffer.clone(), cx); + })?; + } + project .update(cx, |project, cx| project.save_buffer(buffer, cx))? .await diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index cbc874057a..8cd6980ae1 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -29,6 +29,7 @@ futures.workspace = true gpui.workspace = true indoc.workspace = true itertools.workspace = true +language.workspace = true language_model.workspace = true language_models.workspace = true log.workspace = true diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index e214ee875c..a53c81d4c4 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -1,5 +1,9 @@ +mod edit_tool; mod mcp_server; +mod permission_tool; +mod read_tool; pub mod tools; +mod write_tool; use action_log::ActionLog; use collections::HashMap; @@ -351,18 +355,16 @@ fn spawn_claude( &format!( "mcp__{}__{}", mcp_server::SERVER_NAME, - mcp_server::PermissionTool::NAME, + permission_tool::PermissionTool::NAME, ), "--allowedTools", &format!( - "mcp__{}__{},mcp__{}__{}", + "mcp__{}__{}", mcp_server::SERVER_NAME, - mcp_server::EditTool::NAME, - mcp_server::SERVER_NAME, - mcp_server::ReadTool::NAME + read_tool::ReadTool::NAME ), "--disallowedTools", - "Read,Edit", + "Read,Write,Edit,MultiEdit", ]) .args(match mode { ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()], @@ -470,9 +472,16 @@ impl ClaudeAgentSession { let content = content.to_string(); thread .update(cx, |thread, cx| { + let id = acp::ToolCallId(tool_use_id.into()); + let set_new_content = !content.is_empty() + && thread.tool_call(&id).is_none_or(|(_, tool_call)| { + // preserve rich diff if we have one + tool_call.diffs().next().is_none() + }); + thread.update_tool_call( acp::ToolCallUpdate { - id: acp::ToolCallId(tool_use_id.into()), + id, fields: acp::ToolCallUpdateFields { status: if turn_state.borrow().is_canceled() { // Do not set to completed if turn was canceled @@ -480,7 +489,7 @@ impl ClaudeAgentSession { } else { Some(acp::ToolCallStatus::Completed) }, - content: (!content.is_empty()) + content: set_new_content .then(|| vec![content.into()]), ..Default::default() }, diff --git a/crates/agent_servers/src/claude/edit_tool.rs b/crates/agent_servers/src/claude/edit_tool.rs new file mode 100644 index 0000000000..a8d93c3f3d --- /dev/null +++ b/crates/agent_servers/src/claude/edit_tool.rs @@ -0,0 +1,178 @@ +use acp_thread::AcpThread; +use anyhow::Result; +use context_server::{ + listener::{McpServerTool, ToolResponse}, + types::{ToolAnnotations, ToolResponseContent}, +}; +use gpui::{AsyncApp, WeakEntity}; +use language::unified_diff; +use util::markdown::MarkdownCodeBlock; + +use crate::tools::EditToolParams; + +#[derive(Clone)] +pub struct EditTool { + thread_rx: watch::Receiver>, +} + +impl EditTool { + pub fn new(thread_rx: watch::Receiver>) -> Self { + Self { thread_rx } + } +} + +impl McpServerTool for EditTool { + type Input = EditToolParams; + type Output = (); + + const NAME: &'static str = "Edit"; + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Edit file".to_string()), + read_only_hint: Some(false), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: Some(false), + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path.clone(), None, None, true, cx) + })? + .await?; + + let (new_content, diff) = cx + .background_executor() + .spawn(async move { + let new_content = content.replace(&input.old_text, &input.new_text); + if new_content == content { + return Err(anyhow::anyhow!("Failed to find `old_text`",)); + } + let diff = unified_diff(&content, &new_content); + + Ok((new_content, diff)) + }) + .await?; + + thread + .update(cx, |thread, cx| { + thread.write_text_file(input.abs_path, new_content, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { + text: MarkdownCodeBlock { + tag: "diff", + text: diff.as_str().trim_end_matches('\n'), + } + .to_string(), + }], + structured_content: (), + }) + } +} + +#[cfg(test)] +mod tests { + use std::rc::Rc; + + use acp_thread::{AgentConnection, StubAgentConnection}; + use gpui::{Entity, TestAppContext}; + use indoc::indoc; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + use super::*; + + #[gpui::test] + async fn old_text_not_found(cx: &mut TestAppContext) { + let (_thread, tool) = init_test(cx).await; + + let result = tool + .run( + EditToolParams { + abs_path: path!("/root/file.txt").into(), + old_text: "hi".into(), + new_text: "bye".into(), + }, + &mut cx.to_async(), + ) + .await; + + assert_eq!(result.unwrap_err().to_string(), "Failed to find `old_text`"); + } + + #[gpui::test] + async fn found_and_replaced(cx: &mut TestAppContext) { + let (_thread, tool) = init_test(cx).await; + + let result = tool + .run( + EditToolParams { + abs_path: path!("/root/file.txt").into(), + old_text: "hello".into(), + new_text: "hi".into(), + }, + &mut cx.to_async(), + ) + .await; + + assert_eq!( + result.unwrap().content[0].text().unwrap(), + indoc! { + r" + ```diff + @@ -1,1 +1,1 @@ + -hello + +hi + ``` + " + } + ); + } + + async fn init_test(cx: &mut TestAppContext) -> (Entity, EditTool) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + + let connection = Rc::new(StubAgentConnection::new()); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "file.txt": "hello" + }), + ) + .await; + let project = Project::test(fs, [path!("/root").as_ref()], cx).await; + let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid()); + + let thread = cx + .update(|cx| connection.new_thread(project, path!("/test").as_ref(), cx)) + .await + .unwrap(); + + thread_tx.send(thread.downgrade()).unwrap(); + + (thread, EditTool::new(thread_rx)) + } +} diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index 3086752850..6442c784b5 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -1,23 +1,22 @@ use std::path::PathBuf; use std::sync::Arc; -use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; +use crate::claude::edit_tool::EditTool; +use crate::claude::permission_tool::PermissionTool; +use crate::claude::read_tool::ReadTool; +use crate::claude::write_tool::WriteTool; use acp_thread::AcpThread; -use agent_client_protocol as acp; -use agent_settings::AgentSettings; -use anyhow::{Context, Result}; +#[cfg(not(test))] +use anyhow::Context as _; +use anyhow::Result; use collections::HashMap; -use context_server::listener::{McpServerTool, ToolResponse}; use context_server::types::{ Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, - ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, + ToolsCapabilities, requests, }; use gpui::{App, AsyncApp, Task, WeakEntity}; use project::Fs; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use settings::{Settings as _, update_settings_file}; -use util::debug_panic; +use serde::Serialize; pub struct ClaudeZedMcpServer { server: context_server::listener::McpServer, @@ -34,16 +33,10 @@ impl ClaudeZedMcpServer { let mut mcp_server = context_server::listener::McpServer::new(cx).await?; mcp_server.handle_request::(Self::handle_initialize); - mcp_server.add_tool(PermissionTool { - thread_rx: thread_rx.clone(), - fs: fs.clone(), - }); - mcp_server.add_tool(ReadTool { - thread_rx: thread_rx.clone(), - }); - mcp_server.add_tool(EditTool { - thread_rx: thread_rx.clone(), - }); + mcp_server.add_tool(PermissionTool::new(fs.clone(), thread_rx.clone())); + mcp_server.add_tool(ReadTool::new(thread_rx.clone())); + mcp_server.add_tool(EditTool::new(thread_rx.clone())); + mcp_server.add_tool(WriteTool::new(thread_rx.clone())); Ok(Self { server: mcp_server }) } @@ -104,249 +97,3 @@ pub struct McpServerConfig { #[serde(skip_serializing_if = "Option::is_none")] pub env: Option>, } - -// Tools - -#[derive(Clone)] -pub struct PermissionTool { - fs: Arc, - thread_rx: watch::Receiver>, -} - -#[derive(Deserialize, JsonSchema, Debug)] -pub struct PermissionToolParams { - tool_name: String, - input: serde_json::Value, - tool_use_id: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct PermissionToolResponse { - behavior: PermissionToolBehavior, - updated_input: serde_json::Value, -} - -#[derive(Serialize)] -#[serde(rename_all = "snake_case")] -enum PermissionToolBehavior { - Allow, - Deny, -} - -impl McpServerTool for PermissionTool { - type Input = PermissionToolParams; - type Output = (); - - const NAME: &'static str = "Confirmation"; - - fn description(&self) -> &'static str { - "Request permission for tool calls" - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - if agent_settings::AgentSettings::try_read_global(cx, |settings| { - settings.always_allow_tool_actions - }) - .unwrap_or(false) - { - let response = PermissionToolResponse { - behavior: PermissionToolBehavior::Allow, - updated_input: input.input, - }; - - return Ok(ToolResponse { - content: vec![ToolResponseContent::Text { - text: serde_json::to_string(&response)?, - }], - structured_content: (), - }); - } - - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); - let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); - - const ALWAYS_ALLOW: &str = "always_allow"; - const ALLOW: &str = "allow"; - const REJECT: &str = "reject"; - - let chosen_option = thread - .update(cx, |thread, cx| { - thread.request_tool_call_authorization( - claude_tool.as_acp(tool_call_id).into(), - vec![ - acp::PermissionOption { - id: acp::PermissionOptionId(ALWAYS_ALLOW.into()), - name: "Always Allow".into(), - kind: acp::PermissionOptionKind::AllowAlways, - }, - acp::PermissionOption { - id: acp::PermissionOptionId(ALLOW.into()), - name: "Allow".into(), - kind: acp::PermissionOptionKind::AllowOnce, - }, - acp::PermissionOption { - id: acp::PermissionOptionId(REJECT.into()), - name: "Reject".into(), - kind: acp::PermissionOptionKind::RejectOnce, - }, - ], - cx, - ) - })?? - .await?; - - let response = match chosen_option.0.as_ref() { - ALWAYS_ALLOW => { - cx.update(|cx| { - update_settings_file::(self.fs.clone(), cx, |settings, _| { - settings.set_always_allow_tool_actions(true); - }); - })?; - - PermissionToolResponse { - behavior: PermissionToolBehavior::Allow, - updated_input: input.input, - } - } - ALLOW => PermissionToolResponse { - behavior: PermissionToolBehavior::Allow, - updated_input: input.input, - }, - REJECT => PermissionToolResponse { - behavior: PermissionToolBehavior::Deny, - updated_input: input.input, - }, - opt => { - debug_panic!("Unexpected option: {}", opt); - PermissionToolResponse { - behavior: PermissionToolBehavior::Deny, - updated_input: input.input, - } - } - }; - - Ok(ToolResponse { - content: vec![ToolResponseContent::Text { - text: serde_json::to_string(&response)?, - }], - structured_content: (), - }) - } -} - -#[derive(Clone)] -pub struct ReadTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for ReadTool { - type Input = ReadToolParams; - type Output = (); - - const NAME: &'static str = "Read"; - - fn description(&self) -> &'static str { - "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents." - } - - fn annotations(&self) -> ToolAnnotations { - ToolAnnotations { - title: Some("Read file".to_string()), - read_only_hint: Some(true), - destructive_hint: Some(false), - open_world_hint: Some(false), - idempotent_hint: None, - } - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![ToolResponseContent::Text { text: content }], - structured_content: (), - }) - } -} - -#[derive(Clone)] -pub struct EditTool { - thread_rx: watch::Receiver>, -} - -impl McpServerTool for EditTool { - type Input = EditToolParams; - type Output = (); - - const NAME: &'static str = "Edit"; - - fn description(&self) -> &'static str { - "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better." - } - - fn annotations(&self) -> ToolAnnotations { - ToolAnnotations { - title: Some("Edit file".to_string()), - read_only_hint: Some(false), - destructive_hint: Some(false), - open_world_hint: Some(false), - idempotent_hint: Some(false), - } - } - - async fn run( - &self, - input: Self::Input, - cx: &mut AsyncApp, - ) -> Result> { - let mut thread_rx = self.thread_rx.clone(); - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(input.abs_path.clone(), None, None, true, cx) - })? - .await?; - - let new_content = content.replace(&input.old_text, &input.new_text); - if new_content == content { - return Err(anyhow::anyhow!("The old_text was not found in the content")); - } - - thread - .update(cx, |thread, cx| { - thread.write_text_file(input.abs_path, new_content, cx) - })? - .await?; - - Ok(ToolResponse { - content: vec![], - structured_content: (), - }) - } -} diff --git a/crates/agent_servers/src/claude/permission_tool.rs b/crates/agent_servers/src/claude/permission_tool.rs new file mode 100644 index 0000000000..96a24105e8 --- /dev/null +++ b/crates/agent_servers/src/claude/permission_tool.rs @@ -0,0 +1,158 @@ +use std::sync::Arc; + +use acp_thread::AcpThread; +use agent_client_protocol as acp; +use agent_settings::AgentSettings; +use anyhow::{Context as _, Result}; +use context_server::{ + listener::{McpServerTool, ToolResponse}, + types::ToolResponseContent, +}; +use gpui::{AsyncApp, WeakEntity}; +use project::Fs; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings as _, update_settings_file}; +use util::debug_panic; + +use crate::tools::ClaudeTool; + +#[derive(Clone)] +pub struct PermissionTool { + fs: Arc, + thread_rx: watch::Receiver>, +} + +/// Request permission for tool calls +#[derive(Deserialize, JsonSchema, Debug)] +pub struct PermissionToolParams { + tool_name: String, + input: serde_json::Value, + tool_use_id: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionToolResponse { + behavior: PermissionToolBehavior, + updated_input: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +enum PermissionToolBehavior { + Allow, + Deny, +} + +impl PermissionTool { + pub fn new(fs: Arc, thread_rx: watch::Receiver>) -> Self { + Self { fs, thread_rx } + } +} + +impl McpServerTool for PermissionTool { + type Input = PermissionToolParams; + type Output = (); + + const NAME: &'static str = "Confirmation"; + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + if agent_settings::AgentSettings::try_read_global(cx, |settings| { + settings.always_allow_tool_actions + }) + .unwrap_or(false) + { + let response = PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: input.input, + }; + + return Ok(ToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&response)?, + }], + structured_content: (), + }); + } + + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); + let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); + + const ALWAYS_ALLOW: &str = "always_allow"; + const ALLOW: &str = "allow"; + const REJECT: &str = "reject"; + + let chosen_option = thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization( + claude_tool.as_acp(tool_call_id).into(), + vec![ + acp::PermissionOption { + id: acp::PermissionOptionId(ALWAYS_ALLOW.into()), + name: "Always Allow".into(), + kind: acp::PermissionOptionKind::AllowAlways, + }, + acp::PermissionOption { + id: acp::PermissionOptionId(ALLOW.into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: acp::PermissionOptionId(REJECT.into()), + name: "Reject".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + cx, + ) + })?? + .await?; + + let response = match chosen_option.0.as_ref() { + ALWAYS_ALLOW => { + cx.update(|cx| { + update_settings_file::(self.fs.clone(), cx, |settings, _| { + settings.set_always_allow_tool_actions(true); + }); + })?; + + PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: input.input, + } + } + ALLOW => PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: input.input, + }, + REJECT => PermissionToolResponse { + behavior: PermissionToolBehavior::Deny, + updated_input: input.input, + }, + opt => { + debug_panic!("Unexpected option: {}", opt); + PermissionToolResponse { + behavior: PermissionToolBehavior::Deny, + updated_input: input.input, + } + } + }; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&response)?, + }], + structured_content: (), + }) + } +} diff --git a/crates/agent_servers/src/claude/read_tool.rs b/crates/agent_servers/src/claude/read_tool.rs new file mode 100644 index 0000000000..cbe25876b3 --- /dev/null +++ b/crates/agent_servers/src/claude/read_tool.rs @@ -0,0 +1,59 @@ +use acp_thread::AcpThread; +use anyhow::Result; +use context_server::{ + listener::{McpServerTool, ToolResponse}, + types::{ToolAnnotations, ToolResponseContent}, +}; +use gpui::{AsyncApp, WeakEntity}; + +use crate::tools::ReadToolParams; + +#[derive(Clone)] +pub struct ReadTool { + thread_rx: watch::Receiver>, +} + +impl ReadTool { + pub fn new(thread_rx: watch::Receiver>) -> Self { + Self { thread_rx } + } +} + +impl McpServerTool for ReadTool { + type Input = ReadToolParams; + type Output = (); + + const NAME: &'static str = "Read"; + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Read file".to_string()), + read_only_hint: Some(true), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: None, + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { text: content }], + structured_content: (), + }) + } +} diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index 7ca150c0bd..3be10ed94c 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -34,6 +34,7 @@ impl ClaudeTool { // Known tools "mcp__zed__Read" => Self::ReadFile(serde_json::from_value(input).log_err()), "mcp__zed__Edit" => Self::Edit(serde_json::from_value(input).log_err()), + "mcp__zed__Write" => Self::Write(serde_json::from_value(input).log_err()), "MultiEdit" => Self::MultiEdit(serde_json::from_value(input).log_err()), "Write" => Self::Write(serde_json::from_value(input).log_err()), "LS" => Self::Ls(serde_json::from_value(input).log_err()), @@ -93,7 +94,7 @@ impl ClaudeTool { } Self::MultiEdit(None) => "Multi Edit".into(), Self::Write(Some(params)) => { - format!("Write {}", params.file_path.display()) + format!("Write {}", params.abs_path.display()) } Self::Write(None) => "Write".into(), Self::Glob(Some(params)) => { @@ -153,7 +154,7 @@ impl ClaudeTool { }], Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { diff: acp::Diff { - path: params.file_path.clone(), + path: params.abs_path.clone(), old_text: None, new_text: params.content.clone(), }, @@ -229,7 +230,10 @@ impl ClaudeTool { line: None, }] } - Self::Write(Some(WriteToolParams { file_path, .. })) => { + Self::Write(Some(WriteToolParams { + abs_path: file_path, + .. + })) => { vec![acp::ToolCallLocation { path: file_path.clone(), line: None, @@ -302,6 +306,20 @@ impl ClaudeTool { } } +/// Edit a file. +/// +/// In sessions with mcp__zed__Edit always use it instead of Edit as it will +/// allow the user to conveniently review changes. +/// +/// File editing instructions: +/// - The `old_text` param must match existing file content, including indentation. +/// - The `old_text` param must come from the actual file, not an outline. +/// - The `old_text` section must not be empty. +/// - Be minimal with replacements: +/// - For unique lines, include only those lines. +/// - For non-unique lines, include enough context to identify them. +/// - Do not escape quotes, newlines, or other characters. +/// - Only edit the specified file. #[derive(Deserialize, JsonSchema, Debug)] pub struct EditToolParams { /// The absolute path to the file to read. @@ -312,6 +330,11 @@ pub struct EditToolParams { pub new_text: String, } +/// Reads the content of the given file in the project. +/// +/// Never attempt to read a path that hasn't been previously mentioned. +/// +/// In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents. #[derive(Deserialize, JsonSchema, Debug)] pub struct ReadToolParams { /// The absolute path to the file to read. @@ -324,11 +347,15 @@ pub struct ReadToolParams { pub limit: Option, } +/// Writes content to the specified file in the project. +/// +/// In sessions with mcp__zed__Write always use it instead of Write as it will +/// allow the user to conveniently review changes. #[derive(Deserialize, JsonSchema, Debug)] pub struct WriteToolParams { - /// Absolute path for new file - pub file_path: PathBuf, - /// File content + /// The absolute path of the file to write. + pub abs_path: PathBuf, + /// The full content to write. pub content: String, } diff --git a/crates/agent_servers/src/claude/write_tool.rs b/crates/agent_servers/src/claude/write_tool.rs new file mode 100644 index 0000000000..39479a9c38 --- /dev/null +++ b/crates/agent_servers/src/claude/write_tool.rs @@ -0,0 +1,59 @@ +use acp_thread::AcpThread; +use anyhow::Result; +use context_server::{ + listener::{McpServerTool, ToolResponse}, + types::ToolAnnotations, +}; +use gpui::{AsyncApp, WeakEntity}; + +use crate::tools::WriteToolParams; + +#[derive(Clone)] +pub struct WriteTool { + thread_rx: watch::Receiver>, +} + +impl WriteTool { + pub fn new(thread_rx: watch::Receiver>) -> Self { + Self { thread_rx } + } +} + +impl McpServerTool for WriteTool { + type Input = WriteToolParams; + type Output = (); + + const NAME: &'static str = "Write"; + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Write file".to_string()), + read_only_hint: Some(false), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: Some(false), + } + } + + async fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> Result> { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + thread + .update(cx, |thread, cx| { + thread.write_text_file(input.abs_path, input.content, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![], + structured_content: (), + }) + } +} diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 6f4b5c1369..1b44cefbd2 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -14,6 +14,7 @@ use serde::de::DeserializeOwned; use serde_json::{json, value::RawValue}; use smol::stream::StreamExt; use std::{ + any::TypeId, cell::RefCell, path::{Path, PathBuf}, rc::Rc, @@ -87,18 +88,26 @@ impl McpServer { settings.inline_subschemas = true; let mut generator = settings.into_generator(); - let output_schema = generator.root_schema_for::(); - let unit_schema = generator.root_schema_for::(); + let input_schema = generator.root_schema_for::(); + + let description = input_schema + .get("description") + .and_then(|desc| desc.as_str()) + .map(|desc| desc.to_string()); + debug_assert!( + description.is_some(), + "Input schema struct must include a doc comment for the tool description" + ); let registered_tool = RegisteredTool { tool: Tool { name: T::NAME.into(), - description: Some(tool.description().into()), - input_schema: generator.root_schema_for::().into(), - output_schema: if output_schema == unit_schema { + description, + input_schema: input_schema.into(), + output_schema: if TypeId::of::() == TypeId::of::<()>() { None } else { - Some(output_schema.into()) + Some(generator.root_schema_for::().into()) }, annotations: Some(tool.annotations()), }, @@ -399,8 +408,6 @@ pub trait McpServerTool { const NAME: &'static str; - fn description(&self) -> &'static str; - fn annotations(&self) -> ToolAnnotations { ToolAnnotations { title: None, @@ -418,6 +425,7 @@ pub trait McpServerTool { ) -> impl Future>>; } +#[derive(Debug)] pub struct ToolResponse { pub content: Vec, pub structured_content: T, diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index e92a18c763..03aca4f3ca 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -711,6 +711,16 @@ pub enum ToolResponseContent { Resource { resource: ResourceContents }, } +impl ToolResponseContent { + pub fn text(&self) -> Option<&str> { + if let ToolResponseContent::Text { text } = self { + Some(text) + } else { + None + } + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListToolsResponse {