From 15c9da4ea4e740faa22cf7c740bed666fdf9ff4a Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 24 Jul 2025 23:19:20 -0300 Subject: [PATCH] Add ability to register tools in `McpServer` (#35068) Makes it easier to add tools to a server by implementing a trait Release Notes: - N/A --- crates/agent_servers/src/claude.rs | 11 +- crates/agent_servers/src/claude/mcp_server.rs | 418 +++++++++--------- crates/context_server/src/listener.rs | 217 ++++++++- crates/context_server/src/types.rs | 6 +- 4 files changed, 408 insertions(+), 244 deletions(-) diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 5f35b4af73..d63d8c43cf 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -2,6 +2,7 @@ mod mcp_server; pub mod tools; use collections::HashMap; +use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; @@ -332,10 +333,16 @@ async fn spawn_claude( &format!( "mcp__{}__{}", mcp_server::SERVER_NAME, - mcp_server::PERMISSION_TOOL + mcp_server::PermissionTool::NAME, ), "--allowedTools", - "mcp__zed__Read,mcp__zed__Edit", + &format!( + "mcp__{}__{},mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::EditTool::NAME, + mcp_server::SERVER_NAME, + mcp_server::ReadTool::NAME + ), "--disallowedTools", "Read,Edit", ]) diff --git a/crates/agent_servers/src/claude/mcp_server.rs b/crates/agent_servers/src/claude/mcp_server.rs index 0a39a02931..4272a972dc 100644 --- a/crates/agent_servers/src/claude/mcp_server.rs +++ b/crates/agent_servers/src/claude/mcp_server.rs @@ -1,49 +1,24 @@ use std::path::PathBuf; +use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; use acp_thread::AcpThread; use agent_client_protocol as acp; use anyhow::{Context, Result}; use collections::HashMap; +use context_server::listener::{McpServerTool, ToolResponse}; use context_server::types::{ - CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse, - ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations, - ToolResponseContent, ToolsCapabilities, requests, + Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities, + ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests, }; -use gpui::{App, AsyncApp, Entity, Task, WeakEntity}; +use gpui::{App, AsyncApp, Task, WeakEntity}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams}; - pub struct ClaudeZedMcpServer { server: context_server::listener::McpServer, } pub const SERVER_NAME: &str = "zed"; -pub const READ_TOOL: &str = "Read"; -pub const EDIT_TOOL: &str = "Edit"; -pub const PERMISSION_TOOL: &str = "Confirmation"; - -#[derive(Deserialize, JsonSchema, Debug)] -struct PermissionToolParams { - tool_name: String, - input: serde_json::Value, - tool_use_id: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct PermissionToolResponse { - behavior: PermissionToolBehavior, - updated_input: serde_json::Value, -} - -#[derive(Serialize)] -#[serde(rename_all = "snake_case")] -enum PermissionToolBehavior { - Allow, - Deny, -} impl ClaudeZedMcpServer { pub async fn new( @@ -52,9 +27,15 @@ impl ClaudeZedMcpServer { ) -> Result { let mut mcp_server = context_server::listener::McpServer::new(cx).await?; mcp_server.handle_request::(Self::handle_initialize); - mcp_server.handle_request::(Self::handle_list_tools); - mcp_server.handle_request::(move |request, cx| { - Self::handle_call_tool(request, thread_rx.clone(), cx) + + mcp_server.add_tool(PermissionTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(ReadTool { + thread_rx: thread_rx.clone(), + }); + mcp_server.add_tool(EditTool { + thread_rx: thread_rx.clone(), }); Ok(Self { server: mcp_server }) @@ -96,193 +77,6 @@ impl ClaudeZedMcpServer { }) }) } - - fn handle_list_tools(_: (), cx: &App) -> Task> { - cx.foreground_executor().spawn(async move { - Ok(ListToolsResponse { - tools: vec![ - Tool { - name: PERMISSION_TOOL.into(), - input_schema: schemars::schema_for!(PermissionToolParams).into(), - description: None, - annotations: None, - }, - Tool { - name: READ_TOOL.into(), - input_schema: schemars::schema_for!(ReadToolParams).into(), - description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()), - annotations: Some(ToolAnnotations { - title: Some("Read file".to_string()), - read_only_hint: Some(true), - destructive_hint: Some(false), - open_world_hint: Some(false), - // if time passes the contents might change, but it's not going to do anything different - // true or false seem too strong, let's try a none. - idempotent_hint: None, - }), - }, - Tool { - name: EDIT_TOOL.into(), - input_schema: schemars::schema_for!(EditToolParams).into(), - description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()), - annotations: Some(ToolAnnotations { - title: Some("Edit file".to_string()), - read_only_hint: Some(false), - destructive_hint: Some(false), - open_world_hint: Some(false), - idempotent_hint: Some(false), - }), - }, - ], - next_cursor: None, - meta: None, - }) - }) - } - - fn handle_call_tool( - request: CallToolParams, - mut thread_rx: watch::Receiver>, - cx: &App, - ) -> Task> { - cx.spawn(async move |cx| { - let Some(thread) = thread_rx.recv().await?.upgrade() else { - anyhow::bail!("Thread closed"); - }; - - if request.name.as_str() == PERMISSION_TOOL { - let input = - serde_json::from_value(request.arguments.context("Arguments required")?)?; - - let result = Self::handle_permissions_tool_call(input, thread, cx).await?; - Ok(CallToolResponse { - content: vec![ToolResponseContent::Text { - text: serde_json::to_string(&result)?, - }], - is_error: None, - meta: None, - }) - } else if request.name.as_str() == READ_TOOL { - let input = - serde_json::from_value(request.arguments.context("Arguments required")?)?; - - let content = Self::handle_read_tool_call(input, thread, cx).await?; - Ok(CallToolResponse { - content, - is_error: None, - meta: None, - }) - } else if request.name.as_str() == EDIT_TOOL { - let input = - serde_json::from_value(request.arguments.context("Arguments required")?)?; - - Self::handle_edit_tool_call(input, thread, cx).await?; - Ok(CallToolResponse { - content: vec![], - is_error: None, - meta: None, - }) - } else { - anyhow::bail!("Unsupported tool"); - } - }) - } - - fn handle_read_tool_call( - ReadToolParams { - abs_path, - offset, - limit, - }: ReadToolParams, - thread: Entity, - cx: &AsyncApp, - ) -> Task>> { - cx.spawn(async move |cx| { - let content = thread - .update(cx, |thread, cx| { - thread.read_text_file(abs_path, offset, limit, false, cx) - })? - .await?; - - Ok(vec![ToolResponseContent::Text { text: content }]) - }) - } - - fn handle_edit_tool_call( - params: EditToolParams, - thread: Entity, - cx: &AsyncApp, - ) -> Task> { - cx.spawn(async move |cx| { - let content = thread - .update(cx, |threads, cx| { - threads.read_text_file(params.abs_path.clone(), None, None, true, cx) - })? - .await?; - - let new_content = content.replace(¶ms.old_text, ¶ms.new_text); - if new_content == content { - return Err(anyhow::anyhow!("The old_text was not found in the content")); - } - - thread - .update(cx, |threads, cx| { - threads.write_text_file(params.abs_path, new_content, cx) - })? - .await?; - - Ok(()) - }) - } - - fn handle_permissions_tool_call( - params: PermissionToolParams, - thread: Entity, - cx: &AsyncApp, - ) -> Task> { - cx.spawn(async move |cx| { - let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone()); - - let tool_call_id = - acp::ToolCallId(params.tool_use_id.context("Tool ID required")?.into()); - - let allow_option_id = acp::PermissionOptionId("allow".into()); - let reject_option_id = acp::PermissionOptionId("reject".into()); - - let chosen_option = thread - .update(cx, |thread, cx| { - thread.request_tool_call_permission( - claude_tool.as_acp(tool_call_id), - vec![ - acp::PermissionOption { - id: allow_option_id.clone(), - label: "Allow".into(), - kind: acp::PermissionOptionKind::AllowOnce, - }, - acp::PermissionOption { - id: reject_option_id, - label: "Reject".into(), - kind: acp::PermissionOptionKind::RejectOnce, - }, - ], - cx, - ) - })? - .await?; - - if chosen_option == allow_option_id { - Ok(PermissionToolResponse { - behavior: PermissionToolBehavior::Allow, - updated_input: params.input, - }) - } else { - Ok(PermissionToolResponse { - behavior: PermissionToolBehavior::Deny, - updated_input: params.input, - }) - } - }) - } } #[derive(Serialize)] @@ -299,3 +93,187 @@ pub struct McpServerConfig { #[serde(skip_serializing_if = "Option::is_none")] pub env: Option>, } + +// Tools + +#[derive(Clone)] +pub struct PermissionTool { + thread_rx: watch::Receiver>, +} + +#[derive(Deserialize, JsonSchema, Debug)] +pub struct PermissionToolParams { + tool_name: String, + input: serde_json::Value, + tool_use_id: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionToolResponse { + behavior: PermissionToolBehavior, + updated_input: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +enum PermissionToolBehavior { + Allow, + Deny, +} + +impl McpServerTool for PermissionTool { + type Input = PermissionToolParams; + const NAME: &'static str = "Confirmation"; + + fn description(&self) -> &'static str { + "Request permission for tool calls" + } + + async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone()); + let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into()); + let allow_option_id = acp::PermissionOptionId("allow".into()); + let reject_option_id = acp::PermissionOptionId("reject".into()); + + let chosen_option = thread + .update(cx, |thread, cx| { + thread.request_tool_call_permission( + claude_tool.as_acp(tool_call_id), + vec![ + acp::PermissionOption { + id: allow_option_id.clone(), + label: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + }, + acp::PermissionOption { + id: reject_option_id.clone(), + label: "Reject".into(), + kind: acp::PermissionOptionKind::RejectOnce, + }, + ], + cx, + ) + })? + .await?; + + let response = if chosen_option == allow_option_id { + PermissionToolResponse { + behavior: PermissionToolBehavior::Allow, + updated_input: input.input, + } + } else { + PermissionToolResponse { + behavior: PermissionToolBehavior::Deny, + updated_input: input.input, + } + }; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { + text: serde_json::to_string(&response)?, + }], + structured_content: None, + }) + } +} + +#[derive(Clone)] +pub struct ReadTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for ReadTool { + type Input = ReadToolParams; + const NAME: &'static str = "Read"; + + fn description(&self) -> &'static str { + "Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents." + } + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Read file".to_string()), + read_only_hint: Some(true), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: None, + } + } + + async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![ToolResponseContent::Text { text: content }], + structured_content: None, + }) + } +} + +#[derive(Clone)] +pub struct EditTool { + thread_rx: watch::Receiver>, +} + +impl McpServerTool for EditTool { + type Input = EditToolParams; + const NAME: &'static str = "Edit"; + + fn description(&self) -> &'static str { + "Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better." + } + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: Some("Edit file".to_string()), + read_only_hint: Some(false), + destructive_hint: Some(false), + open_world_hint: Some(false), + idempotent_hint: Some(false), + } + } + + async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result { + let mut thread_rx = self.thread_rx.clone(); + let Some(thread) = thread_rx.recv().await?.upgrade() else { + anyhow::bail!("Thread closed"); + }; + + let content = thread + .update(cx, |thread, cx| { + thread.read_text_file(input.abs_path.clone(), None, None, true, cx) + })? + .await?; + + let new_content = content.replace(&input.old_text, &input.new_text); + if new_content == content { + return Err(anyhow::anyhow!("The old_text was not found in the content")); + } + + thread + .update(cx, |thread, cx| { + thread.write_text_file(input.abs_path, new_content, cx) + })? + .await?; + + Ok(ToolResponse { + content: vec![], + structured_content: None, + }) + } +} diff --git a/crates/context_server/src/listener.rs b/crates/context_server/src/listener.rs index 9295ad979c..087395a961 100644 --- a/crates/context_server/src/listener.rs +++ b/crates/context_server/src/listener.rs @@ -9,6 +9,8 @@ use futures::{ }; use gpui::{App, AppContext, AsyncApp, Task}; use net::async_net::{UnixListener, UnixStream}; +use schemars::JsonSchema; +use serde::de::DeserializeOwned; use serde_json::{json, value::RawValue}; use smol::stream::StreamExt; use std::{ @@ -20,16 +22,28 @@ use util::ResultExt; use crate::{ client::{CspResult, RequestId, Response}, - types::Request, + types::{ + CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations, + ToolResponseContent, + requests::{CallTool, ListTools}, + }, }; pub struct McpServer { socket_path: PathBuf, - handlers: Rc>>, + tools: Rc>>, + handlers: Rc>>, _server_task: Task<()>, } -type McpHandler = Box>, &App) -> Task>; +struct RegisteredTool { + tool: Tool, + handler: ToolHandler, +} + +type ToolHandler = + Box, &mut AsyncApp) -> Task>>; +type RequestHandler = Box>, &App) -> Task>; impl McpServer { pub fn new(cx: &AsyncApp) -> Task> { @@ -43,12 +57,14 @@ impl McpServer { cx.spawn(async move |cx| { let (temp_dir, socket_path, listener) = task.await?; + let tools = Rc::new(RefCell::new(HashMap::default())); let handlers = Rc::new(RefCell::new(HashMap::default())); let server_task = cx.spawn({ + let tools = tools.clone(); let handlers = handlers.clone(); async move |cx| { while let Ok((stream, _)) = listener.accept().await { - Self::serve_connection(stream, handlers.clone(), cx); + Self::serve_connection(stream, tools.clone(), handlers.clone(), cx); } drop(temp_dir) } @@ -56,11 +72,40 @@ impl McpServer { Ok(Self { socket_path, _server_task: server_task, - handlers: handlers.clone(), + tools, + handlers: handlers, }) }) } + pub fn add_tool(&mut self, tool: T) { + let registered_tool = RegisteredTool { + tool: Tool { + name: T::NAME.into(), + description: Some(tool.description().into()), + input_schema: schemars::schema_for!(T::Input).into(), + annotations: Some(tool.annotations()), + }, + handler: Box::new({ + let tool = tool.clone(); + move |input_value, cx| { + let input = match input_value { + Some(input) => serde_json::from_value(input), + None => serde_json::from_value(serde_json::Value::Null), + }; + + let tool = tool.clone(); + match input { + Ok(input) => cx.spawn(async move |cx| tool.run(input, cx).await), + Err(err) => Task::ready(Err(err.into())), + } + } + }), + }; + + self.tools.borrow_mut().insert(T::NAME, registered_tool); + } + pub fn handle_request( &mut self, f: impl Fn(R::Params, &App) -> Task> + 'static, @@ -120,7 +165,8 @@ impl McpServer { fn serve_connection( stream: UnixStream, - handlers: Rc>>, + tools: Rc>>, + handlers: Rc>>, cx: &mut AsyncApp, ) { let (read, write) = smol::io::split(stream); @@ -135,7 +181,13 @@ impl McpServer { let Some(request_id) = request.id.clone() else { continue; }; - if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { + + if request.method == CallTool::METHOD { + Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx) + .await; + } else if request.method == ListTools::METHOD { + Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx); + } else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) { let outgoing_tx = outgoing_tx.clone(); if let Some(task) = cx @@ -149,25 +201,122 @@ impl McpServer { .detach(); } } else { - outgoing_tx - .unbounded_send( - serde_json::to_string(&Response::<()> { - jsonrpc: "2.0", - id: request.id.unwrap(), - value: CspResult::Error(Some(crate::client::Error { - message: format!("unhandled method {}", request.method), - code: -32601, - })), - }) - .unwrap(), - ) - .ok(); + Self::send_err( + request_id, + format!("unhandled method {}", request.method), + &outgoing_tx, + ); } } }) .detach(); } + fn handle_list_tools( + request_id: RequestId, + tools: &Rc>>, + outgoing_tx: &UnboundedSender, + ) { + let response = ListToolsResponse { + tools: tools.borrow().values().map(|t| t.tool.clone()).collect(), + next_cursor: None, + meta: None, + }; + + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Ok(Some(response)), + }) + .unwrap_or_default(), + ) + .ok(); + } + + async fn handle_call_tool( + request_id: RequestId, + params: Option>, + tools: &Rc>>, + outgoing_tx: &UnboundedSender, + cx: &mut AsyncApp, + ) { + let result: Result = match params.as_ref() { + Some(params) => serde_json::from_str(params.get()), + None => serde_json::from_value(serde_json::Value::Null), + }; + + match result { + Ok(params) => { + if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) { + let outgoing_tx = outgoing_tx.clone(); + + let task = (tool.handler)(params.arguments, cx); + cx.spawn(async move |_| { + let response = match task.await { + Ok(result) => CallToolResponse { + content: result.content, + is_error: Some(false), + meta: None, + structured_content: result.structured_content, + }, + Err(err) => CallToolResponse { + content: vec![ToolResponseContent::Text { + text: err.to_string(), + }], + is_error: Some(true), + meta: None, + structured_content: None, + }, + }; + + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Ok(Some(response)), + }) + .unwrap_or_default(), + ) + .ok(); + }) + .detach(); + } else { + Self::send_err( + request_id, + format!("Tool not found: {}", params.name), + &outgoing_tx, + ); + } + } + Err(err) => { + Self::send_err(request_id, err.to_string(), &outgoing_tx); + } + } + } + + fn send_err( + request_id: RequestId, + message: impl Into, + outgoing_tx: &UnboundedSender, + ) { + outgoing_tx + .unbounded_send( + serde_json::to_string(&Response::<()> { + jsonrpc: "2.0", + id: request_id, + value: CspResult::Error(Some(crate::client::Error { + message: message.into(), + code: -32601, + })), + }) + .unwrap(), + ) + .ok(); + } + async fn handle_io( mut outgoing_rx: UnboundedReceiver, incoming_tx: UnboundedSender, @@ -216,6 +365,34 @@ impl McpServer { } } +pub trait McpServerTool { + type Input: DeserializeOwned + JsonSchema; + const NAME: &'static str; + + fn description(&self) -> &'static str; + + fn annotations(&self) -> ToolAnnotations { + ToolAnnotations { + title: None, + read_only_hint: None, + destructive_hint: None, + idempotent_hint: None, + open_world_hint: None, + } + } + + fn run( + &self, + input: Self::Input, + cx: &mut AsyncApp, + ) -> impl Future>; +} + +pub struct ToolResponse { + pub content: Vec, + pub structured_content: Option, +} + #[derive(Serialize, Deserialize)] struct RawRequest { #[serde(skip_serializing_if = "Option::is_none")] diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index f92c86aa3c..c95d9008bc 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -495,7 +495,7 @@ pub struct RootsCapabilities { pub list_changed: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Tool { pub name: String, @@ -506,7 +506,7 @@ pub struct Tool { pub annotations: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ToolAnnotations { /// A human-readable title for the tool. @@ -679,6 +679,8 @@ pub struct CallToolResponse { pub is_error: Option, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub structured_content: Option, } #[derive(Debug, Serialize, Deserialize)]