Add ability to register tools in McpServer (#35068)

Makes it easier to add tools to a server by implementing a trait

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2025-07-24 23:19:20 -03:00 committed by GitHub
parent b446d66be7
commit 15c9da4ea4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 408 additions and 244 deletions

View file

@ -9,6 +9,8 @@ use futures::{
};
use gpui::{App, AppContext, AsyncApp, Task};
use net::async_net::{UnixListener, UnixStream};
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde_json::{json, value::RawValue};
use smol::stream::StreamExt;
use std::{
@ -20,16 +22,28 @@ use util::ResultExt;
use crate::{
client::{CspResult, RequestId, Response},
types::Request,
types::{
CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
ToolResponseContent,
requests::{CallTool, ListTools},
},
};
pub struct McpServer {
socket_path: PathBuf,
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
_server_task: Task<()>,
}
type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
struct RegisteredTool {
tool: Tool,
handler: ToolHandler,
}
type ToolHandler =
Box<dyn Fn(Option<serde_json::Value>, &mut AsyncApp) -> Task<Result<ToolResponse>>>;
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
impl McpServer {
pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
@ -43,12 +57,14 @@ impl McpServer {
cx.spawn(async move |cx| {
let (temp_dir, socket_path, listener) = task.await?;
let tools = Rc::new(RefCell::new(HashMap::default()));
let handlers = Rc::new(RefCell::new(HashMap::default()));
let server_task = cx.spawn({
let tools = tools.clone();
let handlers = handlers.clone();
async move |cx| {
while let Ok((stream, _)) = listener.accept().await {
Self::serve_connection(stream, handlers.clone(), cx);
Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
}
drop(temp_dir)
}
@ -56,11 +72,40 @@ impl McpServer {
Ok(Self {
socket_path,
_server_task: server_task,
handlers: handlers.clone(),
tools,
handlers: handlers,
})
})
}
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
let registered_tool = RegisteredTool {
tool: Tool {
name: T::NAME.into(),
description: Some(tool.description().into()),
input_schema: schemars::schema_for!(T::Input).into(),
annotations: Some(tool.annotations()),
},
handler: Box::new({
let tool = tool.clone();
move |input_value, cx| {
let input = match input_value {
Some(input) => serde_json::from_value(input),
None => serde_json::from_value(serde_json::Value::Null),
};
let tool = tool.clone();
match input {
Ok(input) => cx.spawn(async move |cx| tool.run(input, cx).await),
Err(err) => Task::ready(Err(err.into())),
}
}
}),
};
self.tools.borrow_mut().insert(T::NAME, registered_tool);
}
pub fn handle_request<R: Request>(
&mut self,
f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
@ -120,7 +165,8 @@ impl McpServer {
fn serve_connection(
stream: UnixStream,
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
cx: &mut AsyncApp,
) {
let (read, write) = smol::io::split(stream);
@ -135,7 +181,13 @@ impl McpServer {
let Some(request_id) = request.id.clone() else {
continue;
};
if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
if request.method == CallTool::METHOD {
Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
.await;
} else if request.method == ListTools::METHOD {
Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
} else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
let outgoing_tx = outgoing_tx.clone();
if let Some(task) = cx
@ -149,25 +201,122 @@ impl McpServer {
.detach();
}
} else {
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response::<()> {
jsonrpc: "2.0",
id: request.id.unwrap(),
value: CspResult::Error(Some(crate::client::Error {
message: format!("unhandled method {}", request.method),
code: -32601,
})),
})
.unwrap(),
)
.ok();
Self::send_err(
request_id,
format!("unhandled method {}", request.method),
&outgoing_tx,
);
}
}
})
.detach();
}
fn handle_list_tools(
request_id: RequestId,
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
outgoing_tx: &UnboundedSender<String>,
) {
let response = ListToolsResponse {
tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
next_cursor: None,
meta: None,
};
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Ok(Some(response)),
})
.unwrap_or_default(),
)
.ok();
}
async fn handle_call_tool(
request_id: RequestId,
params: Option<Box<RawValue>>,
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
outgoing_tx: &UnboundedSender<String>,
cx: &mut AsyncApp,
) {
let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
Some(params) => serde_json::from_str(params.get()),
None => serde_json::from_value(serde_json::Value::Null),
};
match result {
Ok(params) => {
if let Some(tool) = tools.borrow().get(&params.name.as_ref()) {
let outgoing_tx = outgoing_tx.clone();
let task = (tool.handler)(params.arguments, cx);
cx.spawn(async move |_| {
let response = match task.await {
Ok(result) => CallToolResponse {
content: result.content,
is_error: Some(false),
meta: None,
structured_content: result.structured_content,
},
Err(err) => CallToolResponse {
content: vec![ToolResponseContent::Text {
text: err.to_string(),
}],
is_error: Some(true),
meta: None,
structured_content: None,
},
};
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Ok(Some(response)),
})
.unwrap_or_default(),
)
.ok();
})
.detach();
} else {
Self::send_err(
request_id,
format!("Tool not found: {}", params.name),
&outgoing_tx,
);
}
}
Err(err) => {
Self::send_err(request_id, err.to_string(), &outgoing_tx);
}
}
}
fn send_err(
request_id: RequestId,
message: impl Into<String>,
outgoing_tx: &UnboundedSender<String>,
) {
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response::<()> {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Error(Some(crate::client::Error {
message: message.into(),
code: -32601,
})),
})
.unwrap(),
)
.ok();
}
async fn handle_io(
mut outgoing_rx: UnboundedReceiver<String>,
incoming_tx: UnboundedSender<RawRequest>,
@ -216,6 +365,34 @@ impl McpServer {
}
}
pub trait McpServerTool {
type Input: DeserializeOwned + JsonSchema;
const NAME: &'static str;
fn description(&self) -> &'static str;
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: None,
read_only_hint: None,
destructive_hint: None,
idempotent_hint: None,
open_world_hint: None,
}
}
fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> impl Future<Output = Result<ToolResponse>>;
}
pub struct ToolResponse {
pub content: Vec<ToolResponseContent>,
pub structured_content: Option<serde_json::Value>,
}
#[derive(Serialize, Deserialize)]
struct RawRequest {
#[serde(skip_serializing_if = "Option::is_none")]

View file

@ -495,7 +495,7 @@ pub struct RootsCapabilities {
pub list_changed: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub name: String,
@ -506,7 +506,7 @@ pub struct Tool {
pub annotations: Option<ToolAnnotations>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolAnnotations {
/// A human-readable title for the tool.
@ -679,6 +679,8 @@ pub struct CallToolResponse {
pub is_error: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub structured_content: Option<serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]