From 95d78ff8d573c88d51a8f77e039165c1de0ff16a Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Fri, 6 Jun 2025 17:47:21 +0200 Subject: [PATCH] context server: Make requests type safe (#32254) This changes the context server crate so that the input/output for a request are encoded at the type level, similar to how it is done for LSP requests. This also makes it easier to write tests that mock context servers, e.g. you can write something like this now when using the `test-support` feature of the `context-server` crate: ```rust create_fake_transport("mcp-1", cx.background_executor()) .on_request::(|_params| { PromptsListResponse { prompts: vec![/* some prompts */], .. } }) ``` Release Notes: - N/A --- crates/agent/src/context_server_tool.rs | 10 +- crates/agent/src/thread_store.rs | 8 +- .../src/context_store.rs | 9 +- .../src/context_server_command.rs | 42 ++-- crates/context_server/Cargo.toml | 3 + crates/context_server/src/context_server.rs | 2 + crates/context_server/src/protocol.rs | 139 +---------- crates/context_server/src/test.rs | 118 +++++++++ crates/context_server/src/types.rs | 195 ++++++++------- crates/project/Cargo.toml | 1 + crates/project/src/context_server_store.rs | 226 +++--------------- 11 files changed, 320 insertions(+), 433 deletions(-) create mode 100644 crates/context_server/src/test.rs diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index e4461f94de..2de43d157f 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -104,7 +104,15 @@ impl Tool for ContextServerTool { tool_name, arguments ); - let response = protocol.run_tool(tool_name, arguments).await?; + let response = protocol + .request::( + context_server::types::CallToolParams { + name: tool_name, + arguments, + meta: None, + }, + ) + .await?; let mut result = String::new(); for content in response.content { diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index cb3a0d3c63..5d5cf21d93 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -566,10 +566,14 @@ impl ThreadStore { }; if protocol.capable(context_server::protocol::ServerCapability::Tools) { - if let Some(tools) = protocol.list_tools().await.log_err() { + if let Some(response) = protocol + .request::(()) + .await + .log_err() + { let tool_ids = tool_working_set .update(cx, |tool_working_set, _| { - tools + response .tools .into_iter() .map(|tool| { diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index 7af97b62a9..7965ee592b 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -864,8 +864,13 @@ impl ContextStore { }; if protocol.capable(context_server::protocol::ServerCapability::Prompts) { - if let Some(prompts) = protocol.list_prompts().await.log_err() { - let slash_command_ids = prompts + if let Some(response) = protocol + .request::(()) + .await + .log_err() + { + let slash_command_ids = response + .prompts .into_iter() .filter(assistant_slash_commands::acceptable_prompt) .map(|prompt| { diff --git a/crates/assistant_slash_commands/src/context_server_command.rs b/crates/assistant_slash_commands/src/context_server_command.rs index 9b0ac18426..509076c167 100644 --- a/crates/assistant_slash_commands/src/context_server_command.rs +++ b/crates/assistant_slash_commands/src/context_server_command.rs @@ -86,20 +86,26 @@ impl SlashCommand for ContextServerSlashCommand { cx.foreground_executor().spawn(async move { let protocol = server.client().context("Context server not initialized")?; - let completion_result = protocol - .completion( - context_server::types::CompletionReference::Prompt( - context_server::types::PromptReference { - r#type: context_server::types::PromptReferenceType::Prompt, - name: prompt_name, + let response = protocol + .request::( + context_server::types::CompletionCompleteParams { + reference: context_server::types::CompletionReference::Prompt( + context_server::types::PromptReference { + ty: context_server::types::PromptReferenceType::Prompt, + name: prompt_name, + }, + ), + argument: context_server::types::CompletionArgument { + name: arg_name, + value: arg_value, }, - ), - arg_name, - arg_value, + meta: None, + }, ) .await?; - let completions = completion_result + let completions = response + .completion .values .into_iter() .map(|value| ArgumentCompletion { @@ -138,10 +144,18 @@ impl SlashCommand for ContextServerSlashCommand { if let Some(server) = store.get_running_server(&server_id) { cx.foreground_executor().spawn(async move { let protocol = server.client().context("Context server not initialized")?; - let result = protocol.run_prompt(&prompt_name, prompt_args).await?; + let response = protocol + .request::( + context_server::types::PromptsGetParams { + name: prompt_name.clone(), + arguments: Some(prompt_args), + meta: None, + }, + ) + .await?; anyhow::ensure!( - result + response .messages .iter() .all(|msg| matches!(msg.role, context_server::types::Role::User)), @@ -149,7 +163,7 @@ impl SlashCommand for ContextServerSlashCommand { ); // Extract text from user messages into a single prompt string - let mut prompt = result + let mut prompt = response .messages .into_iter() .filter_map(|msg| match msg.content { @@ -167,7 +181,7 @@ impl SlashCommand for ContextServerSlashCommand { range: 0..(prompt.len()), icon: IconName::ZedAssistant, label: SharedString::from( - result + response .description .unwrap_or(format!("Result from {}", prompt_name)), ), diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 62a5354b39..96bb9e071f 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -11,6 +11,9 @@ workspace = true [lib] path = "src/context_server.rs" +[features] +test-support = [] + [dependencies] anyhow.workspace = true async-trait.workspace = true diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 19f2f75541..387235307a 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -1,5 +1,7 @@ pub mod client; pub mod protocol; +#[cfg(any(test, feature = "test-support"))] +pub mod test; pub mod transport; pub mod types; diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 782a1a4a67..233df048d6 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -6,10 +6,9 @@ //! of messages. use anyhow::Result; -use collections::HashMap; use crate::client::Client; -use crate::types; +use crate::types::{self, Request}; pub struct ModelContextProtocol { inner: Client, @@ -43,7 +42,7 @@ impl ModelContextProtocol { let response: types::InitializeResponse = self .inner - .request(types::RequestType::Initialize.as_str(), params) + .request(types::request::Initialize::METHOD, params) .await?; anyhow::ensure!( @@ -94,137 +93,7 @@ impl InitializedContextServerProtocol { } } - fn check_capability(&self, capability: ServerCapability) -> Result<()> { - anyhow::ensure!( - self.capable(capability), - "Server does not support {capability:?} capability" - ); - Ok(()) - } - - /// List the MCP prompts. - pub async fn list_prompts(&self) -> Result> { - self.check_capability(ServerCapability::Prompts)?; - - let response: types::PromptsListResponse = self - .inner - .request( - types::RequestType::PromptsList.as_str(), - serde_json::json!({}), - ) - .await?; - - Ok(response.prompts) - } - - /// List the MCP resources. - pub async fn list_resources(&self) -> Result { - self.check_capability(ServerCapability::Resources)?; - - let response: types::ResourcesListResponse = self - .inner - .request( - types::RequestType::ResourcesList.as_str(), - serde_json::json!({}), - ) - .await?; - - Ok(response) - } - - /// Executes a prompt with the given arguments and returns the result. - pub async fn run_prompt>( - &self, - prompt: P, - arguments: HashMap, - ) -> Result { - self.check_capability(ServerCapability::Prompts)?; - - let params = types::PromptsGetParams { - name: prompt.as_ref().to_string(), - arguments: Some(arguments), - meta: None, - }; - - let response: types::PromptsGetResponse = self - .inner - .request(types::RequestType::PromptsGet.as_str(), params) - .await?; - - Ok(response) - } - - pub async fn completion>( - &self, - reference: types::CompletionReference, - argument: P, - value: P, - ) -> Result { - let params = types::CompletionCompleteParams { - r#ref: reference, - argument: types::CompletionArgument { - name: argument.into(), - value: value.into(), - }, - meta: None, - }; - let result: types::CompletionCompleteResponse = self - .inner - .request(types::RequestType::CompletionComplete.as_str(), params) - .await?; - - let completion = types::Completion { - values: result.completion.values, - total: types::CompletionTotal::from_options( - result.completion.has_more, - result.completion.total, - ), - }; - - Ok(completion) - } - - /// List MCP tools. - pub async fn list_tools(&self) -> Result { - self.check_capability(ServerCapability::Tools)?; - - let response = self - .inner - .request::(types::RequestType::ListTools.as_str(), ()) - .await?; - - Ok(response) - } - - /// Executes a tool with the given arguments - pub async fn run_tool>( - &self, - tool: P, - arguments: Option>, - ) -> Result { - self.check_capability(ServerCapability::Tools)?; - - let params = types::CallToolParams { - name: tool.as_ref().to_string(), - arguments, - meta: None, - }; - - let response: types::CallToolResponse = self - .inner - .request(types::RequestType::CallTool.as_str(), params) - .await?; - - Ok(response) - } -} - -impl InitializedContextServerProtocol { - pub async fn request( - &self, - method: &str, - params: impl serde::Serialize, - ) -> Result { - self.inner.request(method, params).await + pub async fn request(&self, params: T::Params) -> Result { + self.inner.request(T::METHOD, params).await } } diff --git a/crates/context_server/src/test.rs b/crates/context_server/src/test.rs new file mode 100644 index 0000000000..d882a56984 --- /dev/null +++ b/crates/context_server/src/test.rs @@ -0,0 +1,118 @@ +use anyhow::Context as _; +use collections::HashMap; +use futures::{Stream, StreamExt as _, lock::Mutex}; +use gpui::BackgroundExecutor; +use std::{pin::Pin, sync::Arc}; + +use crate::{ + transport::Transport, + types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities}, +}; + +pub fn create_fake_transport( + name: impl Into, + executor: BackgroundExecutor, +) -> FakeTransport { + let name = name.into(); + FakeTransport::new(executor).on_request::(move |_params| { + create_initialize_response(name.clone()) + }) +} + +fn create_initialize_response(server_name: String) -> InitializeResponse { + InitializeResponse { + protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()), + server_info: Implementation { + name: server_name, + version: "1.0.0".to_string(), + }, + capabilities: ServerCapabilities::default(), + meta: None, + } +} + +pub struct FakeTransport { + request_handlers: + HashMap<&'static str, Arc serde_json::Value + Send + Sync>>, + tx: futures::channel::mpsc::UnboundedSender, + rx: Arc>>, + executor: BackgroundExecutor, +} + +impl FakeTransport { + pub fn new(executor: BackgroundExecutor) -> Self { + let (tx, rx) = futures::channel::mpsc::unbounded(); + Self { + request_handlers: Default::default(), + tx, + rx: Arc::new(Mutex::new(rx)), + executor, + } + } + + pub fn on_request( + mut self, + handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static, + ) -> Self { + self.request_handlers.insert( + T::METHOD, + Arc::new(move |value| { + let params = value.get("params").expect("Missing parameters").clone(); + let params: T::Params = + serde_json::from_value(params).expect("Invalid parameters received"); + let response = handler(params); + serde_json::to_value(response).unwrap() + }), + ); + self + } +} + +#[async_trait::async_trait] +impl Transport for FakeTransport { + async fn send(&self, message: String) -> anyhow::Result<()> { + if let Ok(msg) = serde_json::from_str::(&message) { + let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); + + if let Some(method) = msg.get("method") { + let method = method.as_str().expect("Invalid method received"); + if let Some(handler) = self.request_handlers.get(method) { + let payload = handler(msg); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": payload + }); + self.tx + .unbounded_send(response.to_string()) + .context("sending a message")?; + } else { + log::debug!("No handler registered for MCP request '{method}'"); + } + } + } + Ok(()) + } + + fn receive(&self) -> Pin + Send>> { + let rx = self.rx.clone(); + let executor = self.executor.clone(); + Box::pin(futures::stream::unfold(rx, move |rx| { + let executor = executor.clone(); + async move { + let mut rx_guard = rx.lock().await; + executor.simulate_random_delay().await; + if let Some(message) = rx_guard.next().await { + drop(rx_guard); + Some((message, rx)) + } else { + None + } + } + })) + } + + fn receive_err(&self) -> Pin + Send>> { + Box::pin(futures::stream::empty()) + } +} diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 83f08218f3..9c36c40228 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -1,76 +1,92 @@ use collections::HashMap; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05"; -pub enum RequestType { - Initialize, - CallTool, - ResourcesUnsubscribe, - ResourcesSubscribe, - ResourcesRead, - ResourcesList, - LoggingSetLevel, - PromptsGet, - PromptsList, - CompletionComplete, - Ping, - ListTools, - ListResourceTemplates, - ListRoots, +pub mod request { + use super::*; + + macro_rules! request { + ($method:expr, $name:ident, $params:ty, $response:ty) => { + pub struct $name; + + impl Request for $name { + type Params = $params; + type Response = $response; + const METHOD: &'static str = $method; + } + }; + } + + request!( + "initialize", + Initialize, + InitializeParams, + InitializeResponse + ); + request!("tools/call", CallTool, CallToolParams, CallToolResponse); + request!( + "resources/unsubscribe", + ResourcesUnsubscribe, + ResourcesUnsubscribeParams, + () + ); + request!( + "resources/subscribe", + ResourcesSubscribe, + ResourcesSubscribeParams, + () + ); + request!( + "resources/read", + ResourcesRead, + ResourcesReadParams, + ResourcesReadResponse + ); + request!("resources/list", ResourcesList, (), ResourcesListResponse); + request!( + "logging/setLevel", + LoggingSetLevel, + LoggingSetLevelParams, + () + ); + request!( + "prompts/get", + PromptsGet, + PromptsGetParams, + PromptsGetResponse + ); + request!("prompts/list", PromptsList, (), PromptsListResponse); + request!( + "completion/complete", + CompletionComplete, + CompletionCompleteParams, + CompletionCompleteResponse + ); + request!("ping", Ping, (), ()); + request!("tools/list", ListTools, (), ListToolsResponse); + request!( + "resources/templates/list", + ListResourceTemplates, + (), + ListResourceTemplatesResponse + ); + request!("roots/list", ListRoots, (), ListRootsResponse); } -impl RequestType { - pub fn as_str(&self) -> &'static str { - match self { - RequestType::Initialize => "initialize", - RequestType::CallTool => "tools/call", - RequestType::ResourcesUnsubscribe => "resources/unsubscribe", - RequestType::ResourcesSubscribe => "resources/subscribe", - RequestType::ResourcesRead => "resources/read", - RequestType::ResourcesList => "resources/list", - RequestType::LoggingSetLevel => "logging/setLevel", - RequestType::PromptsGet => "prompts/get", - RequestType::PromptsList => "prompts/list", - RequestType::CompletionComplete => "completion/complete", - RequestType::Ping => "ping", - RequestType::ListTools => "tools/list", - RequestType::ListResourceTemplates => "resources/templates/list", - RequestType::ListRoots => "roots/list", - } - } -} - -impl TryFrom<&str> for RequestType { - type Error = (); - - fn try_from(s: &str) -> Result { - match s { - "initialize" => Ok(RequestType::Initialize), - "tools/call" => Ok(RequestType::CallTool), - "resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe), - "resources/subscribe" => Ok(RequestType::ResourcesSubscribe), - "resources/read" => Ok(RequestType::ResourcesRead), - "resources/list" => Ok(RequestType::ResourcesList), - "logging/setLevel" => Ok(RequestType::LoggingSetLevel), - "prompts/get" => Ok(RequestType::PromptsGet), - "prompts/list" => Ok(RequestType::PromptsList), - "completion/complete" => Ok(RequestType::CompletionComplete), - "ping" => Ok(RequestType::Ping), - "tools/list" => Ok(RequestType::ListTools), - "resources/templates/list" => Ok(RequestType::ListResourceTemplates), - "roots/list" => Ok(RequestType::ListRoots), - _ => Err(()), - } - } +pub trait Request { + type Params: DeserializeOwned + Serialize + Send + Sync + 'static; + type Response: DeserializeOwned + Serialize + Send + Sync + 'static; + const METHOD: &'static str; } #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct ProtocolVersion(pub String); -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InitializeParams { pub protocol_version: ProtocolVersion, @@ -80,7 +96,7 @@ pub struct InitializeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CallToolParams { pub name: String, @@ -90,7 +106,7 @@ pub struct CallToolParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesUnsubscribeParams { pub uri: Url, @@ -98,7 +114,7 @@ pub struct ResourcesUnsubscribeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesSubscribeParams { pub uri: Url, @@ -106,7 +122,7 @@ pub struct ResourcesSubscribeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesReadParams { pub uri: Url, @@ -114,7 +130,7 @@ pub struct ResourcesReadParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct LoggingSetLevelParams { pub level: LoggingLevel, @@ -122,7 +138,7 @@ pub struct LoggingSetLevelParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsGetParams { pub name: String, @@ -132,37 +148,40 @@ pub struct PromptsGetParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionCompleteParams { - pub r#ref: CompletionReference, + #[serde(rename = "ref")] + pub reference: CompletionReference, pub argument: CompletionArgument, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum CompletionReference { Prompt(PromptReference), Resource(ResourceReference), } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptReference { - pub r#type: PromptReferenceType, + #[serde(rename = "type")] + pub ty: PromptReferenceType, pub name: String, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourceReference { - pub r#type: PromptReferenceType, + #[serde(rename = "type")] + pub ty: PromptReferenceType, pub uri: Url, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum PromptReferenceType { #[serde(rename = "ref/prompt")] @@ -171,7 +190,7 @@ pub enum PromptReferenceType { Resource, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionArgument { pub name: String, @@ -188,7 +207,7 @@ pub struct InitializeResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesReadResponse { pub contents: Vec, @@ -196,14 +215,14 @@ pub struct ResourcesReadResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum ResourceContentsType { Text(TextResourceContents), Blob(BlobResourceContents), } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesListResponse { pub resources: Vec, @@ -220,7 +239,7 @@ pub struct SamplingMessage { pub content: MessageContent, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CreateMessageRequest { pub messages: Vec, @@ -296,7 +315,7 @@ pub struct MessageAnnotations { pub priority: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsGetResponse { #[serde(skip_serializing_if = "Option::is_none")] @@ -306,7 +325,7 @@ pub struct PromptsGetResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsListResponse { pub prompts: Vec, @@ -316,7 +335,7 @@ pub struct PromptsListResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionCompleteResponse { pub completion: CompletionResult, @@ -324,7 +343,7 @@ pub struct CompletionCompleteResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionResult { pub values: Vec, @@ -336,7 +355,7 @@ pub struct CompletionResult { pub meta: Option>, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Prompt { pub name: String, @@ -346,7 +365,7 @@ pub struct Prompt { pub arguments: Option>, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptArgument { pub name: String, @@ -509,7 +528,7 @@ pub struct ModelHint { pub name: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum NotificationType { Initialized, @@ -589,7 +608,7 @@ pub struct Completion { pub total: CompletionTotal, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CallToolResponse { pub content: Vec, @@ -620,7 +639,7 @@ pub struct ListToolsResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListResourceTemplatesResponse { pub resource_templates: Vec, @@ -630,7 +649,7 @@ pub struct ListResourceTemplatesResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListRootsResponse { pub roots: Vec, diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 7e506d2184..f208af54d7 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -91,6 +91,7 @@ workspace-hack.workspace = true [dev-dependencies] client = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } +context_server = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] } dap = { workspace = true, features = ["test-support"] } dap_adapters = { workspace = true, features = ["test-support"] } diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index aac9d5d460..34d6abb96c 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -499,17 +499,10 @@ impl ContextServerStore { mod tests { use super::*; use crate::{FakeFs, Project, project_settings::ProjectSettings}; - use context_server::{ - transport::Transport, - types::{ - self, Implementation, InitializeResponse, ProtocolVersion, RequestType, - ServerCapabilities, - }, - }; - use futures::{Stream, StreamExt as _, lock::Mutex}; - use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _}; + use context_server::test::create_fake_transport; + use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use serde_json::json; - use std::{cell::RefCell, pin::Pin, rc::Rc}; + use std::{cell::RefCell, rc::Rc}; use util::path; #[gpui::test] @@ -532,33 +525,17 @@ mod tests { ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) }); - let server_1_id = ContextServerId("mcp-1".into()); - let server_2_id = ContextServerId("mcp-2".into()); + let server_1_id = ContextServerId(SERVER_1_ID.into()); + let server_2_id = ContextServerId(SERVER_2_ID.into()); - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-1".to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-2".to_string())) - } - _ => None, - }, - )); - - let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone())); - let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone())); + let server_1 = Arc::new(ContextServer::new( + server_1_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_2 = Arc::new(ContextServer::new( + server_2_id.clone(), + Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())), + )); store .update(cx, |store, cx| store.start_server(server_1, cx)) @@ -627,33 +604,17 @@ mod tests { ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) }); - let server_1_id = ContextServerId("mcp-1".into()); - let server_2_id = ContextServerId("mcp-2".into()); + let server_1_id = ContextServerId(SERVER_1_ID.into()); + let server_2_id = ContextServerId(SERVER_2_ID.into()); - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-1".to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-2".to_string())) - } - _ => None, - }, - )); - - let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone())); - let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone())); + let server_1 = Arc::new(ContextServer::new( + server_1_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_2 = Arc::new(ContextServer::new( + server_2_id.clone(), + Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())), + )); let _server_events = assert_server_events( &store, @@ -702,30 +663,14 @@ mod tests { let server_id = ContextServerId(SERVER_1_ID.into()); - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(SERVER_1_ID.to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(SERVER_1_ID.to_string())) - } - _ => None, - }, - )); - - let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1)); - let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2)); + let server_with_same_id_1 = Arc::new(ContextServer::new( + server_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_with_same_id_2 = Arc::new(ContextServer::new( + server_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); // If we start another server with the same id, we should report that we stopped the previous one let _server_events = assert_server_events( @@ -794,16 +739,10 @@ mod tests { let store = cx.new(|cx| { ContextServerStore::test_maintain_server_loop( Box::new(move |id, _| { - let transport = FakeTransport::new(executor.clone(), { - let id = id.0.clone(); - move |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(id.clone().to_string())) - } - _ => None, - } - }); - Arc::new(ContextServer::new(id.clone(), Arc::new(transport))) + Arc::new(ContextServer::new( + id.clone(), + Arc::new(create_fake_transport(id.0.to_string(), executor.clone())), + )) }), registry.clone(), project.read(cx).worktree_store(), @@ -1033,99 +972,4 @@ mod tests { (fs, project) } - - fn create_initialize_response(server_name: String) -> serde_json::Value { - serde_json::to_value(&InitializeResponse { - protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()), - server_info: Implementation { - name: server_name, - version: "1.0.0".to_string(), - }, - capabilities: ServerCapabilities::default(), - meta: None, - }) - .unwrap() - } - - struct FakeTransport { - on_request: Arc< - dyn Fn(u64, Option, serde_json::Value) -> Option - + Send - + Sync, - >, - tx: futures::channel::mpsc::UnboundedSender, - rx: Arc>>, - executor: BackgroundExecutor, - } - - impl FakeTransport { - fn new( - executor: BackgroundExecutor, - on_request: impl Fn( - u64, - Option, - serde_json::Value, - ) -> Option - + 'static - + Send - + Sync, - ) -> Self { - let (tx, rx) = futures::channel::mpsc::unbounded(); - Self { - on_request: Arc::new(on_request), - tx, - rx: Arc::new(Mutex::new(rx)), - executor, - } - } - } - - #[async_trait::async_trait] - impl Transport for FakeTransport { - async fn send(&self, message: String) -> Result<()> { - if let Ok(msg) = serde_json::from_str::(&message) { - let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); - - if let Some(method) = msg.get("method") { - let request_type = method - .as_str() - .and_then(|method| types::RequestType::try_from(method).ok()); - if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) { - let response = serde_json::json!({ - "jsonrpc": "2.0", - "id": id, - "result": payload - }); - - self.tx - .unbounded_send(response.to_string()) - .context("sending a message")?; - } - } - } - Ok(()) - } - - fn receive(&self) -> Pin + Send>> { - let rx = self.rx.clone(); - let executor = self.executor.clone(); - Box::pin(futures::stream::unfold(rx, move |rx| { - let executor = executor.clone(); - async move { - let mut rx_guard = rx.lock().await; - executor.simulate_random_delay().await; - if let Some(message) = rx_guard.next().await { - drop(rx_guard); - Some((message, rx)) - } else { - None - } - } - })) - } - - fn receive_err(&self) -> Pin + Send>> { - Box::pin(futures::stream::empty()) - } - } }