This commit is contained in:
Cole Miller 2025-07-31 23:45:28 +00:00 committed by GitHub
parent 2b36d4ec94
commit 73387532ce
313 changed files with 6560 additions and 20270 deletions

View file

@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use collections::HashMap;
use futures::{FutureExt, StreamExt, channel::oneshot, future, select};
use futures::{FutureExt, StreamExt, channel::oneshot, select};
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, Task};
use parking_lot::Mutex;
use postage::barrier;
@ -10,19 +10,15 @@ use smol::channel;
use std::{
fmt,
path::PathBuf,
pin::pin,
sync::{
Arc,
atomic::{AtomicI32, Ordering::SeqCst},
},
time::{Duration, Instant},
};
use util::{ResultExt, TryFutureExt};
use util::TryFutureExt;
use crate::{
transport::{StdioTransport, Transport},
types::{CancelledParams, ClientNotification, Notification as _, notifications::Cancelled},
};
use crate::transport::{StdioTransport, Transport};
const JSON_RPC_VERSION: &str = "2.0";
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
@ -36,7 +32,6 @@ pub const INTERNAL_ERROR: i32 = -32603;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
@ -83,15 +78,6 @@ pub struct Request<'a, T> {
pub params: T,
}
#[derive(Serialize, Deserialize)]
pub struct AnyRequest<'a> {
pub jsonrpc: &'a str,
pub id: RequestId,
pub method: &'a str,
#[serde(skip_serializing_if = "is_null_value")]
pub params: Option<&'a RawValue>,
}
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
jsonrpc: &'a str,
@ -191,23 +177,15 @@ impl Client {
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
let receive_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
let request_handlers = request_handlers.clone();
let transport = transport.clone();
async move |cx| {
Self::handle_input(
transport,
notification_handlers,
request_handlers,
response_handlers,
cx,
)
.log_err()
.await
Self::handle_input(transport, notification_handlers, response_handlers, cx)
.log_err()
.await
}
});
let receive_err_task = cx.spawn({
@ -253,24 +231,13 @@ impl Client {
async fn handle_input(
transport: Arc<dyn Transport>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
cx: &mut AsyncApp,
) -> anyhow::Result<()> {
let mut receiver = transport.receive();
while let Some(message) = receiver.next().await {
log::trace!("recv: {}", &message);
if let Ok(request) = serde_json::from_str::<AnyRequest>(&message) {
let mut request_handlers = request_handlers.lock();
if let Some(handler) = request_handlers.get_mut(request.method) {
handler(
request.id,
request.params.unwrap_or(RawValue::NULL),
cx.clone(),
);
}
} else if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
if let Some(handlers) = response_handlers.lock().as_mut() {
if let Some(handler) = handlers.remove(&response.id) {
handler(Ok(message.to_string()));
@ -281,8 +248,6 @@ impl Client {
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
handler(notification.params.unwrap_or(Value::Null), cx.clone());
}
} else {
log::error!("Unhandled JSON from context_server: {}", message);
}
}
@ -330,17 +295,6 @@ impl Client {
&self,
method: &str,
params: impl Serialize,
) -> Result<T> {
self.request_with(method, params, None, Some(REQUEST_TIMEOUT))
.await
}
pub async fn request_with<T: DeserializeOwned>(
&self,
method: &str,
params: impl Serialize,
cancel_rx: Option<oneshot::Receiver<()>>,
timeout: Option<Duration>,
) -> Result<T> {
let id = self.next_id.fetch_add(1, SeqCst);
let request = serde_json::to_string(&Request {
@ -376,23 +330,7 @@ impl Client {
handle_response?;
send?;
let mut timeout_fut = pin!(
match timeout {
Some(timeout) => future::Either::Left(executor.timer(timeout)),
None => future::Either::Right(future::pending()),
}
.fuse()
);
let mut cancel_fut = pin!(
match cancel_rx {
Some(rx) => future::Either::Left(async {
rx.await.log_err();
}),
None => future::Either::Right(future::pending()),
}
.fuse()
);
let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
select! {
response = rx.fuse() => {
let elapsed = started.elapsed();
@ -411,18 +349,8 @@ impl Client {
Err(_) => anyhow::bail!("cancelled")
}
}
_ = cancel_fut => {
self.notify(
Cancelled::METHOD,
ClientNotification::Cancelled(CancelledParams {
request_id: RequestId::Int(id),
reason: None
})
).log_err();
anyhow::bail!(RequestCanceled)
}
_ = timeout_fut => {
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", timeout.unwrap());
_ = timeout => {
log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
anyhow::bail!("Context server request timeout");
}
}
@ -452,17 +380,6 @@ impl Client {
}
}
#[derive(Debug)]
pub struct RequestCanceled;
impl std::error::Error for RequestCanceled {}
impl std::fmt::Display for RequestCanceled {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Context server request was canceled")
}
}
impl fmt::Display for ContextServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)

View file

@ -9,8 +9,6 @@ use futures::{
};
use gpui::{App, AppContext, AsyncApp, Task};
use net::async_net::{UnixListener, UnixStream};
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde_json::{json, value::RawValue};
use smol::stream::StreamExt;
use std::{
@ -22,32 +20,16 @@ use util::ResultExt;
use crate::{
client::{CspResult, RequestId, Response},
types::{
CallToolParams, CallToolResponse, ListToolsResponse, Request, Tool, ToolAnnotations,
ToolResponseContent,
requests::{CallTool, ListTools},
},
types::Request,
};
pub struct McpServer {
socket_path: PathBuf,
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
_server_task: Task<()>,
}
struct RegisteredTool {
tool: Tool,
handler: ToolHandler,
}
type ToolHandler = Box<
dyn Fn(
Option<serde_json::Value>,
&mut AsyncApp,
) -> Task<Result<ToolResponse<serde_json::Value>>>,
>;
type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
type McpHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
impl McpServer {
pub fn new(cx: &AsyncApp) -> Task<Result<Self>> {
@ -61,14 +43,12 @@ impl McpServer {
cx.spawn(async move |cx| {
let (temp_dir, socket_path, listener) = task.await?;
let tools = Rc::new(RefCell::new(HashMap::default()));
let handlers = Rc::new(RefCell::new(HashMap::default()));
let server_task = cx.spawn({
let tools = tools.clone();
let handlers = handlers.clone();
async move |cx| {
while let Ok((stream, _)) = listener.accept().await {
Self::serve_connection(stream, tools.clone(), handlers.clone(), cx);
Self::serve_connection(stream, handlers.clone(), cx);
}
drop(temp_dir)
}
@ -76,56 +56,11 @@ impl McpServer {
Ok(Self {
socket_path,
_server_task: server_task,
tools,
handlers: handlers,
handlers: handlers.clone(),
})
})
}
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
let output_schema = schemars::schema_for!(T::Output);
let unit_schema = schemars::schema_for!(());
let registered_tool = RegisteredTool {
tool: Tool {
name: T::NAME.into(),
description: Some(tool.description().into()),
input_schema: schemars::schema_for!(T::Input).into(),
output_schema: if output_schema == unit_schema {
None
} else {
Some(output_schema.into())
},
annotations: Some(tool.annotations()),
},
handler: Box::new({
let tool = tool.clone();
move |input_value, cx| {
let input = match input_value {
Some(input) => serde_json::from_value(input),
None => serde_json::from_value(serde_json::Value::Null),
};
let tool = tool.clone();
match input {
Ok(input) => cx.spawn(async move |cx| {
let output = tool.run(input, cx).await?;
Ok(ToolResponse {
content: output.content,
structured_content: serde_json::to_value(output.structured_content)
.unwrap_or_default(),
})
}),
Err(err) => Task::ready(Err(err.into())),
}
}
}),
};
self.tools.borrow_mut().insert(T::NAME, registered_tool);
}
pub fn handle_request<R: Request>(
&mut self,
f: impl Fn(R::Params, &App) -> Task<Result<R::Response>> + 'static,
@ -185,8 +120,7 @@ impl McpServer {
fn serve_connection(
stream: UnixStream,
tools: Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
handlers: Rc<RefCell<HashMap<&'static str, RequestHandler>>>,
handlers: Rc<RefCell<HashMap<&'static str, McpHandler>>>,
cx: &mut AsyncApp,
) {
let (read, write) = smol::io::split(stream);
@ -201,13 +135,7 @@ impl McpServer {
let Some(request_id) = request.id.clone() else {
continue;
};
if request.method == CallTool::METHOD {
Self::handle_call_tool(request_id, request.params, &tools, &outgoing_tx, cx)
.await;
} else if request.method == ListTools::METHOD {
Self::handle_list_tools(request.id.unwrap(), &tools, &outgoing_tx);
} else if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
if let Some(handler) = handlers.borrow().get(&request.method.as_ref()) {
let outgoing_tx = outgoing_tx.clone();
if let Some(task) = cx
@ -221,126 +149,25 @@ impl McpServer {
.detach();
}
} else {
Self::send_err(
request_id,
format!("unhandled method {}", request.method),
&outgoing_tx,
);
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response::<()> {
jsonrpc: "2.0",
id: request.id.unwrap(),
value: CspResult::Error(Some(crate::client::Error {
message: format!("unhandled method {}", request.method),
code: -32601,
})),
})
.unwrap(),
)
.ok();
}
}
})
.detach();
}
fn handle_list_tools(
request_id: RequestId,
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
outgoing_tx: &UnboundedSender<String>,
) {
let response = ListToolsResponse {
tools: tools.borrow().values().map(|t| t.tool.clone()).collect(),
next_cursor: None,
meta: None,
};
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Ok(Some(response)),
})
.unwrap_or_default(),
)
.ok();
}
async fn handle_call_tool(
request_id: RequestId,
params: Option<Box<RawValue>>,
tools: &Rc<RefCell<HashMap<&'static str, RegisteredTool>>>,
outgoing_tx: &UnboundedSender<String>,
cx: &mut AsyncApp,
) {
let result: Result<CallToolParams, serde_json::Error> = match params.as_ref() {
Some(params) => serde_json::from_str(params.get()),
None => serde_json::from_value(serde_json::Value::Null),
};
match result {
Ok(params) => {
if let Some(tool) = tools.borrow().get(&params.name.as_ref()) {
let outgoing_tx = outgoing_tx.clone();
let task = (tool.handler)(params.arguments, cx);
cx.spawn(async move |_| {
let response = match task.await {
Ok(result) => CallToolResponse {
content: result.content,
is_error: Some(false),
meta: None,
structured_content: if result.structured_content.is_null() {
None
} else {
Some(result.structured_content)
},
},
Err(err) => CallToolResponse {
content: vec![ToolResponseContent::Text {
text: err.to_string(),
}],
is_error: Some(true),
meta: None,
structured_content: None,
},
};
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Ok(Some(response)),
})
.unwrap_or_default(),
)
.ok();
})
.detach();
} else {
Self::send_err(
request_id,
format!("Tool not found: {}", params.name),
&outgoing_tx,
);
}
}
Err(err) => {
Self::send_err(request_id, err.to_string(), &outgoing_tx);
}
}
}
fn send_err(
request_id: RequestId,
message: impl Into<String>,
outgoing_tx: &UnboundedSender<String>,
) {
outgoing_tx
.unbounded_send(
serde_json::to_string(&Response::<()> {
jsonrpc: "2.0",
id: request_id,
value: CspResult::Error(Some(crate::client::Error {
message: message.into(),
code: -32601,
})),
})
.unwrap(),
)
.ok();
}
async fn handle_io(
mut outgoing_rx: UnboundedReceiver<String>,
incoming_tx: UnboundedSender<RawRequest>,
@ -389,37 +216,7 @@ impl McpServer {
}
}
pub trait McpServerTool {
type Input: DeserializeOwned + JsonSchema;
type Output: Serialize + JsonSchema;
const NAME: &'static str;
fn description(&self) -> &'static str;
fn annotations(&self) -> ToolAnnotations {
ToolAnnotations {
title: None,
read_only_hint: None,
destructive_hint: None,
idempotent_hint: None,
open_world_hint: None,
}
}
fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
}
pub struct ToolResponse<T> {
pub content: Vec<ToolResponseContent>,
pub structured_content: T,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Serialize, Deserialize)]
struct RawRequest {
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<RequestId>,

View file

@ -5,12 +5,7 @@
//! read/write messages and the types from types.rs for serialization/deserialization
//! of messages.
use std::time::Duration;
use anyhow::Result;
use futures::channel::oneshot;
use gpui::AsyncApp;
use serde_json::Value;
use crate::client::Client;
use crate::types::{self, Notification, Request};
@ -100,25 +95,7 @@ impl InitializedContextServerProtocol {
self.inner.request(T::METHOD, params).await
}
pub async fn request_with<T: Request>(
&self,
params: T::Params,
cancel_rx: Option<oneshot::Receiver<()>>,
timeout: Option<Duration>,
) -> Result<T::Response> {
self.inner
.request_with(T::METHOD, params, cancel_rx, timeout)
.await
}
pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
self.inner.notify(T::METHOD, params)
}
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncApp),
{
self.inner.on_notification(method, f);
}
}

View file

@ -3,8 +3,6 @@ use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::client::RequestId;
pub const LATEST_PROTOCOL_VERSION: &str = "2025-03-26";
pub const VERSION_2024_11_05: &str = "2024-11-05";
@ -102,7 +100,6 @@ pub mod notifications {
notification!("notifications/initialized", Initialized, ());
notification!("notifications/progress", Progress, ProgressParams);
notification!("notifications/message", Message, MessageParams);
notification!("notifications/cancelled", Cancelled, CancelledParams);
notification!(
"notifications/resources/updated",
ResourcesUpdated,
@ -495,20 +492,18 @@ pub struct RootsCapabilities {
pub list_changed: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub input_schema: serde_json::Value,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_schema: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub annotations: Option<ToolAnnotations>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolAnnotations {
/// A human-readable title for the tool.
@ -622,15 +617,11 @@ pub enum ClientNotification {
Initialized,
Progress(ProgressParams),
RootsListChanged,
Cancelled(CancelledParams),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CancelledParams {
pub request_id: RequestId,
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
Cancelled {
request_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
reason: Option<String>,
},
}
#[derive(Debug, Serialize, Deserialize)]
@ -682,20 +673,6 @@ pub struct CallToolResponse {
pub is_error: Option<bool>,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub structured_content: Option<serde_json::Value>,
}
impl CallToolResponse {
pub fn text_contents(&self) -> String {
let mut text = String::new();
for chunk in &self.content {
if let ToolResponseContent::Text { text: chunk } = chunk {
text.push_str(&chunk)
};
}
text
}
}
#[derive(Debug, Serialize, Deserialize)]