Add ability to register tools in McpServer
(#35068)
Makes it easier to add tools to a server by implementing a trait Release Notes: - N/A
This commit is contained in:
parent
b446d66be7
commit
15c9da4ea4
4 changed files with 408 additions and 244 deletions
|
@ -2,6 +2,7 @@ mod mcp_server;
|
|||
pub mod tools;
|
||||
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
|
@ -332,10 +333,16 @@ async fn spawn_claude(
|
|||
&format!(
|
||||
"mcp__{}__{}",
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::PERMISSION_TOOL
|
||||
mcp_server::PermissionTool::NAME,
|
||||
),
|
||||
"--allowedTools",
|
||||
"mcp__zed__Read,mcp__zed__Edit",
|
||||
&format!(
|
||||
"mcp__{}__{},mcp__{}__{}",
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::EditTool::NAME,
|
||||
mcp_server::SERVER_NAME,
|
||||
mcp_server::ReadTool::NAME
|
||||
),
|
||||
"--disallowedTools",
|
||||
"Read,Edit",
|
||||
])
|
||||
|
|
|
@ -1,49 +1,24 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
|
||||
use acp_thread::AcpThread;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context, Result};
|
||||
use collections::HashMap;
|
||||
use context_server::listener::{McpServerTool, ToolResponse};
|
||||
use context_server::types::{
|
||||
CallToolParams, CallToolResponse, Implementation, InitializeParams, InitializeResponse,
|
||||
ListToolsResponse, ProtocolVersion, ServerCapabilities, Tool, ToolAnnotations,
|
||||
ToolResponseContent, ToolsCapabilities, requests,
|
||||
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
||||
ToolAnnotations, ToolResponseContent, ToolsCapabilities, requests,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Task, WeakEntity};
|
||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
|
||||
|
||||
pub struct ClaudeZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
}
|
||||
|
||||
pub const SERVER_NAME: &str = "zed";
|
||||
pub const READ_TOOL: &str = "Read";
|
||||
pub const EDIT_TOOL: &str = "Edit";
|
||||
pub const PERMISSION_TOOL: &str = "Confirmation";
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
struct PermissionToolParams {
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
tool_use_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior,
|
||||
updated_input: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum PermissionToolBehavior {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
impl ClaudeZedMcpServer {
|
||||
pub async fn new(
|
||||
|
@ -52,9 +27,15 @@ impl ClaudeZedMcpServer {
|
|||
) -> Result<Self> {
|
||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||
mcp_server.handle_request::<requests::ListTools>(Self::handle_list_tools);
|
||||
mcp_server.handle_request::<requests::CallTool>(move |request, cx| {
|
||||
Self::handle_call_tool(request, thread_rx.clone(), cx)
|
||||
|
||||
mcp_server.add_tool(PermissionTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(ReadTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(EditTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
|
||||
Ok(Self { server: mcp_server })
|
||||
|
@ -96,193 +77,6 @@ impl ClaudeZedMcpServer {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_list_tools(_: (), cx: &App) -> Task<Result<ListToolsResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(ListToolsResponse {
|
||||
tools: vec![
|
||||
Tool {
|
||||
name: PERMISSION_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(PermissionToolParams).into(),
|
||||
description: None,
|
||||
annotations: None,
|
||||
},
|
||||
Tool {
|
||||
name: READ_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(ReadToolParams).into(),
|
||||
description: Some("Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents.".to_string()),
|
||||
annotations: Some(ToolAnnotations {
|
||||
title: Some("Read file".to_string()),
|
||||
read_only_hint: Some(true),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
// if time passes the contents might change, but it's not going to do anything different
|
||||
// true or false seem too strong, let's try a none.
|
||||
idempotent_hint: None,
|
||||
}),
|
||||
},
|
||||
Tool {
|
||||
name: EDIT_TOOL.into(),
|
||||
input_schema: schemars::schema_for!(EditToolParams).into(),
|
||||
description: Some("Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better.".to_string()),
|
||||
annotations: Some(ToolAnnotations {
|
||||
title: Some("Edit file".to_string()),
|
||||
read_only_hint: Some(false),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
idempotent_hint: Some(false),
|
||||
}),
|
||||
},
|
||||
],
|
||||
next_cursor: None,
|
||||
meta: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_call_tool(
|
||||
request: CallToolParams,
|
||||
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
cx: &App,
|
||||
) -> Task<Result<CallToolResponse>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
if request.name.as_str() == PERMISSION_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
let result = Self::handle_permissions_tool_call(input, thread, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&result)?,
|
||||
}],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else if request.name.as_str() == READ_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
let content = Self::handle_read_tool_call(input, thread, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content,
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else if request.name.as_str() == EDIT_TOOL {
|
||||
let input =
|
||||
serde_json::from_value(request.arguments.context("Arguments required")?)?;
|
||||
|
||||
Self::handle_edit_tool_call(input, thread, cx).await?;
|
||||
Ok(CallToolResponse {
|
||||
content: vec![],
|
||||
is_error: None,
|
||||
meta: None,
|
||||
})
|
||||
} else {
|
||||
anyhow::bail!("Unsupported tool");
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_read_tool_call(
|
||||
ReadToolParams {
|
||||
abs_path,
|
||||
offset,
|
||||
limit,
|
||||
}: ReadToolParams,
|
||||
thread: Entity<AcpThread>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Vec<ToolResponseContent>>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(abs_path, offset, limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(vec![ToolResponseContent::Text { text: content }])
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_edit_tool_call(
|
||||
params: EditToolParams,
|
||||
thread: Entity<AcpThread>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<()>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let content = thread
|
||||
.update(cx, |threads, cx| {
|
||||
threads.read_text_file(params.abs_path.clone(), None, None, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let new_content = content.replace(¶ms.old_text, ¶ms.new_text);
|
||||
if new_content == content {
|
||||
return Err(anyhow::anyhow!("The old_text was not found in the content"));
|
||||
}
|
||||
|
||||
thread
|
||||
.update(cx, |threads, cx| {
|
||||
threads.write_text_file(params.abs_path, new_content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_permissions_tool_call(
|
||||
params: PermissionToolParams,
|
||||
thread: Entity<AcpThread>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<PermissionToolResponse>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let claude_tool = ClaudeTool::infer(¶ms.tool_name, params.input.clone());
|
||||
|
||||
let tool_call_id =
|
||||
acp::ToolCallId(params.tool_use_id.context("Tool ID required")?.into());
|
||||
|
||||
let allow_option_id = acp::PermissionOptionId("allow".into());
|
||||
let reject_option_id = acp::PermissionOptionId("reject".into());
|
||||
|
||||
let chosen_option = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(
|
||||
claude_tool.as_acp(tool_call_id),
|
||||
vec![
|
||||
acp::PermissionOption {
|
||||
id: allow_option_id.clone(),
|
||||
label: "Allow".into(),
|
||||
kind: acp::PermissionOptionKind::AllowOnce,
|
||||
},
|
||||
acp::PermissionOption {
|
||||
id: reject_option_id,
|
||||
label: "Reject".into(),
|
||||
kind: acp::PermissionOptionKind::RejectOnce,
|
||||
},
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
if chosen_option == allow_option_id {
|
||||
Ok(PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Allow,
|
||||
updated_input: params.input,
|
||||
})
|
||||
} else {
|
||||
Ok(PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Deny,
|
||||
updated_input: params.input,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
@ -299,3 +93,187 @@ pub struct McpServerConfig {
|
|||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
// Tools
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PermissionTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, JsonSchema, Debug)]
|
||||
pub struct PermissionToolParams {
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
tool_use_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior,
|
||||
updated_input: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum PermissionToolBehavior {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
impl McpServerTool for PermissionTool {
|
||||
type Input = PermissionToolParams;
|
||||
const NAME: &'static str = "Confirmation";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Request permission for tool calls"
|
||||
}
|
||||
|
||||
async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let claude_tool = ClaudeTool::infer(&input.tool_name, input.input.clone());
|
||||
let tool_call_id = acp::ToolCallId(input.tool_use_id.context("Tool ID required")?.into());
|
||||
let allow_option_id = acp::PermissionOptionId("allow".into());
|
||||
let reject_option_id = acp::PermissionOptionId("reject".into());
|
||||
|
||||
let chosen_option = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(
|
||||
claude_tool.as_acp(tool_call_id),
|
||||
vec![
|
||||
acp::PermissionOption {
|
||||
id: allow_option_id.clone(),
|
||||
label: "Allow".into(),
|
||||
kind: acp::PermissionOptionKind::AllowOnce,
|
||||
},
|
||||
acp::PermissionOption {
|
||||
id: reject_option_id.clone(),
|
||||
label: "Reject".into(),
|
||||
kind: acp::PermissionOptionKind::RejectOnce,
|
||||
},
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let response = if chosen_option == allow_option_id {
|
||||
PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Allow,
|
||||
updated_input: input.input,
|
||||
}
|
||||
} else {
|
||||
PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Deny,
|
||||
updated_input: input.input,
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&response)?,
|
||||
}],
|
||||
structured_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReadTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for ReadTool {
|
||||
type Input = ReadToolParams;
|
||||
const NAME: &'static str = "Read";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Read the contents of a file. In sessions with mcp__zed__Read always use it instead of Read as it contains the most up-to-date contents."
|
||||
}
|
||||
|
||||
fn annotations(&self) -> ToolAnnotations {
|
||||
ToolAnnotations {
|
||||
title: Some("Read file".to_string()),
|
||||
read_only_hint: Some(true),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
idempotent_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(input.abs_path, input.offset, input.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![ToolResponseContent::Text { text: content }],
|
||||
structured_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EditTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for EditTool {
|
||||
type Input = EditToolParams;
|
||||
const NAME: &'static str = "Edit";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Edits a file. In sessions with mcp__zed__Edit always use it instead of Edit as it will show the diff to the user better."
|
||||
}
|
||||
|
||||
fn annotations(&self) -> ToolAnnotations {
|
||||
ToolAnnotations {
|
||||
title: Some("Edit file".to_string()),
|
||||
read_only_hint: Some(false),
|
||||
destructive_hint: Some(false),
|
||||
open_world_hint: Some(false),
|
||||
idempotent_hint: Some(false),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(input.abs_path.clone(), None, None, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let new_content = content.replace(&input.old_text, &input.new_text);
|
||||
if new_content == content {
|
||||
return Err(anyhow::anyhow!("The old_text was not found in the content"));
|
||||
}
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(input.abs_path, new_content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ use futures::{
|
|||
};
|
||||
use gpui::{App, AppContext, AsyncApp, Task};
|
||||
use net::async_net::{UnixListener, UnixStream};
|
||||
use schemars::JsonSchema;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::{json, value::RawValue};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{
|
||||
|
@ -20,16 +22,28 @@ use util::ResultExt;
|
|||
|
||||
use crate::{
|
||||
client::{CspResult, RequestId, Response},
|
||||
types::Request,
|
||||
types::{
|
||||
CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
|
||||
ToolResponseContent,
|
||||
requests::{CallTool, ListTools},
|
||||
},
|
||||
};
|
||||
|
||||
pub struct McpServer {
|
||||
socket_path: PathBuf,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
|
||||
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
|
||||
_server_task: Task<()>,
|
||||
}
|
||||
|
||||
type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
|
||||
struct RegisteredTool {
|
||||
tool: Tool,
|
||||
handler: ToolHandler,
|
||||
}
|
||||
|
||||
type ToolHandler =
|
||||
Box<dyn Fn(Option<serde_json::Value>, &mut AsyncApp) -> Task<Result<ToolResponse>>>;
|
||||
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
|
||||
|
||||
impl McpServer {
|
||||
pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
|
||||
|
@ -43,12 +57,14 @@ impl McpServer {
|
|||
|
||||
cx.spawn(async move |cx| {
|
||||
let (temp_dir, socket_path, listener) = task.await?;
|
||||
let tools = Rc::new(RefCell::new(HashMap::default()));
|
||||
let handlers = Rc::new(RefCell::new(HashMap::default()));
|
||||
let server_task = cx.spawn({
|
||||
let tools = tools.clone();
|
||||
let handlers = handlers.clone();
|
||||
async move |cx| {
|
||||
while let Ok((stream, _)) = listener.accept().await {
|
||||
Self::serve_connection(stream, handlers.clone(), cx);
|
||||
Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
|
||||
}
|
||||
drop(temp_dir)
|
||||
}
|
||||
|
@ -56,11 +72,40 @@ impl McpServer {
|
|||
Ok(Self {
|
||||
socket_path,
|
||||
_server_task: server_task,
|
||||
handlers: handlers.clone(),
|
||||
tools,
|
||||
handlers: handlers,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
|
||||
let registered_tool = RegisteredTool {
|
||||
tool: Tool {
|
||||
name: T::NAME.into(),
|
||||
description: Some(tool.description().into()),
|
||||
input_schema: schemars::schema_for!(T::Input).into(),
|
||||
annotations: Some(tool.annotations()),
|
||||
},
|
||||
handler: Box::new({
|
||||
let tool = tool.clone();
|
||||
move |input_value, cx| {
|
||||
let input = match input_value {
|
||||
Some(input) => serde_json::from_value(input),
|
||||
None => serde_json::from_value(serde_json::Value::Null),
|
||||
};
|
||||
|
||||
let tool = tool.clone();
|
||||
match input {
|
||||
Ok(input) => cx.spawn(async move |cx| tool.run(input, cx).await),
|
||||
Err(err) => Task::ready(Err(err.into())),
|
||||
}
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
self.tools.borrow_mut().insert(T::NAME, registered_tool);
|
||||
}
|
||||
|
||||
pub fn handle_request<R: Request>(
|
||||
&mut self,
|
||||
f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
|
||||
|
@ -120,7 +165,8 @@ impl McpServer {
|
|||
|
||||
fn serve_connection(
|
||||
stream: UnixStream,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
|
||||
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
|
||||
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let (read, write) = smol::io::split(stream);
|
||||
|
@ -135,7 +181,13 @@ impl McpServer {
|
|||
let Some(request_id) = request.id.clone() else {
|
||||
continue;
|
||||
};
|
||||
if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
|
||||
|
||||
if request.method == CallTool::METHOD {
|
||||
Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
|
||||
.await;
|
||||
} else if request.method == ListTools::METHOD {
|
||||
Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
|
||||
} else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
|
||||
let outgoing_tx = outgoing_tx.clone();
|
||||
|
||||
if let Some(task) = cx
|
||||
|
@ -149,25 +201,122 @@ impl McpServer {
|
|||
.detach();
|
||||
}
|
||||
} else {
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response::<()> {
|
||||
jsonrpc: "2.0",
|
||||
id: request.id.unwrap(),
|
||||
value: CspResult::Error(Some(crate::client::Error {
|
||||
message: format!("unhandled method {}", request.method),
|
||||
code: -32601,
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.ok();
|
||||
Self::send_err(
|
||||
request_id,
|
||||
format!("unhandled method {}", request.method),
|
||||
&outgoing_tx,
|
||||
);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn handle_list_tools(
|
||||
request_id: RequestId,
|
||||
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
|
||||
outgoing_tx: &UnboundedSender<String>,
|
||||
) {
|
||||
let response = ListToolsResponse {
|
||||
tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
|
||||
next_cursor: None,
|
||||
meta: None,
|
||||
};
|
||||
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response {
|
||||
jsonrpc: "2.0",
|
||||
id: request_id,
|
||||
value: CspResult::Ok(Some(response)),
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
|
||||
async fn handle_call_tool(
|
||||
request_id: RequestId,
|
||||
params: Option<Box<RawValue>>,
|
||||
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
|
||||
outgoing_tx: &UnboundedSender<String>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
|
||||
Some(params) => serde_json::from_str(params.get()),
|
||||
None => serde_json::from_value(serde_json::Value::Null),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(params) => {
|
||||
if let Some(tool) = tools.borrow().get(¶ms.name.as_ref()) {
|
||||
let outgoing_tx = outgoing_tx.clone();
|
||||
|
||||
let task = (tool.handler)(params.arguments, cx);
|
||||
cx.spawn(async move |_| {
|
||||
let response = match task.await {
|
||||
Ok(result) => CallToolResponse {
|
||||
content: result.content,
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
structured_content: result.structured_content,
|
||||
},
|
||||
Err(err) => CallToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: err.to_string(),
|
||||
}],
|
||||
is_error: Some(true),
|
||||
meta: None,
|
||||
structured_content: None,
|
||||
},
|
||||
};
|
||||
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response {
|
||||
jsonrpc: "2.0",
|
||||
id: request_id,
|
||||
value: CspResult::Ok(Some(response)),
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
} else {
|
||||
Self::send_err(
|
||||
request_id,
|
||||
format!("Tool not found: {}", params.name),
|
||||
&outgoing_tx,
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
Self::send_err(request_id, err.to_string(), &outgoing_tx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_err(
|
||||
request_id: RequestId,
|
||||
message: impl Into<String>,
|
||||
outgoing_tx: &UnboundedSender<String>,
|
||||
) {
|
||||
outgoing_tx
|
||||
.unbounded_send(
|
||||
serde_json::to_string(&Response::<()> {
|
||||
jsonrpc: "2.0",
|
||||
id: request_id,
|
||||
value: CspResult::Error(Some(crate::client::Error {
|
||||
message: message.into(),
|
||||
code: -32601,
|
||||
})),
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
|
||||
async fn handle_io(
|
||||
mut outgoing_rx: UnboundedReceiver<String>,
|
||||
incoming_tx: UnboundedSender<RawRequest>,
|
||||
|
@ -216,6 +365,34 @@ impl McpServer {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait McpServerTool {
|
||||
type Input: DeserializeOwned + JsonSchema;
|
||||
const NAME: &'static str;
|
||||
|
||||
fn description(&self) -> &'static str;
|
||||
|
||||
fn annotations(&self) -> ToolAnnotations {
|
||||
ToolAnnotations {
|
||||
title: None,
|
||||
read_only_hint: None,
|
||||
destructive_hint: None,
|
||||
idempotent_hint: None,
|
||||
open_world_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> impl Future<Output = Result<ToolResponse>>;
|
||||
}
|
||||
|
||||
pub struct ToolResponse {
|
||||
pub content: Vec<ToolResponseContent>,
|
||||
pub structured_content: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct RawRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
|
|
@ -495,7 +495,7 @@ pub struct RootsCapabilities {
|
|||
pub list_changed: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Tool {
|
||||
pub name: String,
|
||||
|
@ -506,7 +506,7 @@ pub struct Tool {
|
|||
pub annotations: Option<ToolAnnotations>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolAnnotations {
|
||||
/// A human-readable title for the tool.
|
||||
|
@ -679,6 +679,8 @@ pub struct CallToolResponse {
|
|||
pub is_error: Option<bool>,
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub structured_content: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue