diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index e4461f94de..2de43d157f 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -104,7 +104,15 @@ impl Tool for ContextServerTool { tool_name, arguments ); - let response = protocol.run_tool(tool_name, arguments).await?; + let response = protocol + .request::( + context_server::types::CallToolParams { + name: tool_name, + arguments, + meta: None, + }, + ) + .await?; let mut result = String::new(); for content in response.content { diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index cb3a0d3c63..5d5cf21d93 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -566,10 +566,14 @@ impl ThreadStore { }; if protocol.capable(context_server::protocol::ServerCapability::Tools) { - if let Some(tools) = protocol.list_tools().await.log_err() { + if let Some(response) = protocol + .request::(()) + .await + .log_err() + { let tool_ids = tool_working_set .update(cx, |tool_working_set, _| { - tools + response .tools .into_iter() .map(|tool| { diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index 7af97b62a9..7965ee592b 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -864,8 +864,13 @@ impl ContextStore { }; if protocol.capable(context_server::protocol::ServerCapability::Prompts) { - if let Some(prompts) = protocol.list_prompts().await.log_err() { - let slash_command_ids = prompts + if let Some(response) = protocol + .request::(()) + .await + .log_err() + { + let slash_command_ids = response + .prompts .into_iter() .filter(assistant_slash_commands::acceptable_prompt) .map(|prompt| { diff --git a/crates/assistant_slash_commands/src/context_server_command.rs b/crates/assistant_slash_commands/src/context_server_command.rs index 9b0ac18426..509076c167 100644 --- a/crates/assistant_slash_commands/src/context_server_command.rs +++ b/crates/assistant_slash_commands/src/context_server_command.rs @@ -86,20 +86,26 @@ impl SlashCommand for ContextServerSlashCommand { cx.foreground_executor().spawn(async move { let protocol = server.client().context("Context server not initialized")?; - let completion_result = protocol - .completion( - context_server::types::CompletionReference::Prompt( - context_server::types::PromptReference { - r#type: context_server::types::PromptReferenceType::Prompt, - name: prompt_name, + let response = protocol + .request::( + context_server::types::CompletionCompleteParams { + reference: context_server::types::CompletionReference::Prompt( + context_server::types::PromptReference { + ty: context_server::types::PromptReferenceType::Prompt, + name: prompt_name, + }, + ), + argument: context_server::types::CompletionArgument { + name: arg_name, + value: arg_value, }, - ), - arg_name, - arg_value, + meta: None, + }, ) .await?; - let completions = completion_result + let completions = response + .completion .values .into_iter() .map(|value| ArgumentCompletion { @@ -138,10 +144,18 @@ impl SlashCommand for ContextServerSlashCommand { if let Some(server) = store.get_running_server(&server_id) { cx.foreground_executor().spawn(async move { let protocol = server.client().context("Context server not initialized")?; - let result = protocol.run_prompt(&prompt_name, prompt_args).await?; + let response = protocol + .request::( + context_server::types::PromptsGetParams { + name: prompt_name.clone(), + arguments: Some(prompt_args), + meta: None, + }, + ) + .await?; anyhow::ensure!( - result + response .messages .iter() .all(|msg| matches!(msg.role, context_server::types::Role::User)), @@ -149,7 +163,7 @@ impl SlashCommand for ContextServerSlashCommand { ); // Extract text from user messages into a single prompt string - let mut prompt = result + let mut prompt = response .messages .into_iter() .filter_map(|msg| match msg.content { @@ -167,7 +181,7 @@ impl SlashCommand for ContextServerSlashCommand { range: 0..(prompt.len()), icon: IconName::ZedAssistant, label: SharedString::from( - result + response .description .unwrap_or(format!("Result from {}", prompt_name)), ), diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 62a5354b39..96bb9e071f 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -11,6 +11,9 @@ workspace = true [lib] path = "src/context_server.rs" +[features] +test-support = [] + [dependencies] anyhow.workspace = true async-trait.workspace = true diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 19f2f75541..387235307a 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -1,5 +1,7 @@ pub mod client; pub mod protocol; +#[cfg(any(test, feature = "test-support"))] +pub mod test; pub mod transport; pub mod types; diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 782a1a4a67..233df048d6 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -6,10 +6,9 @@ //! of messages. use anyhow::Result; -use collections::HashMap; use crate::client::Client; -use crate::types; +use crate::types::{self, Request}; pub struct ModelContextProtocol { inner: Client, @@ -43,7 +42,7 @@ impl ModelContextProtocol { let response: types::InitializeResponse = self .inner - .request(types::RequestType::Initialize.as_str(), params) + .request(types::request::Initialize::METHOD, params) .await?; anyhow::ensure!( @@ -94,137 +93,7 @@ impl InitializedContextServerProtocol { } } - fn check_capability(&self, capability: ServerCapability) -> Result<()> { - anyhow::ensure!( - self.capable(capability), - "Server does not support {capability:?} capability" - ); - Ok(()) - } - - /// List the MCP prompts. - pub async fn list_prompts(&self) -> Result> { - self.check_capability(ServerCapability::Prompts)?; - - let response: types::PromptsListResponse = self - .inner - .request( - types::RequestType::PromptsList.as_str(), - serde_json::json!({}), - ) - .await?; - - Ok(response.prompts) - } - - /// List the MCP resources. - pub async fn list_resources(&self) -> Result { - self.check_capability(ServerCapability::Resources)?; - - let response: types::ResourcesListResponse = self - .inner - .request( - types::RequestType::ResourcesList.as_str(), - serde_json::json!({}), - ) - .await?; - - Ok(response) - } - - /// Executes a prompt with the given arguments and returns the result. - pub async fn run_prompt>( - &self, - prompt: P, - arguments: HashMap, - ) -> Result { - self.check_capability(ServerCapability::Prompts)?; - - let params = types::PromptsGetParams { - name: prompt.as_ref().to_string(), - arguments: Some(arguments), - meta: None, - }; - - let response: types::PromptsGetResponse = self - .inner - .request(types::RequestType::PromptsGet.as_str(), params) - .await?; - - Ok(response) - } - - pub async fn completion>( - &self, - reference: types::CompletionReference, - argument: P, - value: P, - ) -> Result { - let params = types::CompletionCompleteParams { - r#ref: reference, - argument: types::CompletionArgument { - name: argument.into(), - value: value.into(), - }, - meta: None, - }; - let result: types::CompletionCompleteResponse = self - .inner - .request(types::RequestType::CompletionComplete.as_str(), params) - .await?; - - let completion = types::Completion { - values: result.completion.values, - total: types::CompletionTotal::from_options( - result.completion.has_more, - result.completion.total, - ), - }; - - Ok(completion) - } - - /// List MCP tools. - pub async fn list_tools(&self) -> Result { - self.check_capability(ServerCapability::Tools)?; - - let response = self - .inner - .request::(types::RequestType::ListTools.as_str(), ()) - .await?; - - Ok(response) - } - - /// Executes a tool with the given arguments - pub async fn run_tool>( - &self, - tool: P, - arguments: Option>, - ) -> Result { - self.check_capability(ServerCapability::Tools)?; - - let params = types::CallToolParams { - name: tool.as_ref().to_string(), - arguments, - meta: None, - }; - - let response: types::CallToolResponse = self - .inner - .request(types::RequestType::CallTool.as_str(), params) - .await?; - - Ok(response) - } -} - -impl InitializedContextServerProtocol { - pub async fn request( - &self, - method: &str, - params: impl serde::Serialize, - ) -> Result { - self.inner.request(method, params).await + pub async fn request(&self, params: T::Params) -> Result { + self.inner.request(T::METHOD, params).await } } diff --git a/crates/context_server/src/test.rs b/crates/context_server/src/test.rs new file mode 100644 index 0000000000..d882a56984 --- /dev/null +++ b/crates/context_server/src/test.rs @@ -0,0 +1,118 @@ +use anyhow::Context as _; +use collections::HashMap; +use futures::{Stream, StreamExt as _, lock::Mutex}; +use gpui::BackgroundExecutor; +use std::{pin::Pin, sync::Arc}; + +use crate::{ + transport::Transport, + types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities}, +}; + +pub fn create_fake_transport( + name: impl Into, + executor: BackgroundExecutor, +) -> FakeTransport { + let name = name.into(); + FakeTransport::new(executor).on_request::(move |_params| { + create_initialize_response(name.clone()) + }) +} + +fn create_initialize_response(server_name: String) -> InitializeResponse { + InitializeResponse { + protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()), + server_info: Implementation { + name: server_name, + version: "1.0.0".to_string(), + }, + capabilities: ServerCapabilities::default(), + meta: None, + } +} + +pub struct FakeTransport { + request_handlers: + HashMap<&'static str, Arc serde_json::Value + Send + Sync>>, + tx: futures::channel::mpsc::UnboundedSender, + rx: Arc>>, + executor: BackgroundExecutor, +} + +impl FakeTransport { + pub fn new(executor: BackgroundExecutor) -> Self { + let (tx, rx) = futures::channel::mpsc::unbounded(); + Self { + request_handlers: Default::default(), + tx, + rx: Arc::new(Mutex::new(rx)), + executor, + } + } + + pub fn on_request( + mut self, + handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static, + ) -> Self { + self.request_handlers.insert( + T::METHOD, + Arc::new(move |value| { + let params = value.get("params").expect("Missing parameters").clone(); + let params: T::Params = + serde_json::from_value(params).expect("Invalid parameters received"); + let response = handler(params); + serde_json::to_value(response).unwrap() + }), + ); + self + } +} + +#[async_trait::async_trait] +impl Transport for FakeTransport { + async fn send(&self, message: String) -> anyhow::Result<()> { + if let Ok(msg) = serde_json::from_str::(&message) { + let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); + + if let Some(method) = msg.get("method") { + let method = method.as_str().expect("Invalid method received"); + if let Some(handler) = self.request_handlers.get(method) { + let payload = handler(msg); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": payload + }); + self.tx + .unbounded_send(response.to_string()) + .context("sending a message")?; + } else { + log::debug!("No handler registered for MCP request '{method}'"); + } + } + } + Ok(()) + } + + fn receive(&self) -> Pin + Send>> { + let rx = self.rx.clone(); + let executor = self.executor.clone(); + Box::pin(futures::stream::unfold(rx, move |rx| { + let executor = executor.clone(); + async move { + let mut rx_guard = rx.lock().await; + executor.simulate_random_delay().await; + if let Some(message) = rx_guard.next().await { + drop(rx_guard); + Some((message, rx)) + } else { + None + } + } + })) + } + + fn receive_err(&self) -> Pin + Send>> { + Box::pin(futures::stream::empty()) + } +} diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 83f08218f3..9c36c40228 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -1,76 +1,92 @@ use collections::HashMap; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05"; -pub enum RequestType { - Initialize, - CallTool, - ResourcesUnsubscribe, - ResourcesSubscribe, - ResourcesRead, - ResourcesList, - LoggingSetLevel, - PromptsGet, - PromptsList, - CompletionComplete, - Ping, - ListTools, - ListResourceTemplates, - ListRoots, +pub mod request { + use super::*; + + macro_rules! request { + ($method:expr, $name:ident, $params:ty, $response:ty) => { + pub struct $name; + + impl Request for $name { + type Params = $params; + type Response = $response; + const METHOD: &'static str = $method; + } + }; + } + + request!( + "initialize", + Initialize, + InitializeParams, + InitializeResponse + ); + request!("tools/call", CallTool, CallToolParams, CallToolResponse); + request!( + "resources/unsubscribe", + ResourcesUnsubscribe, + ResourcesUnsubscribeParams, + () + ); + request!( + "resources/subscribe", + ResourcesSubscribe, + ResourcesSubscribeParams, + () + ); + request!( + "resources/read", + ResourcesRead, + ResourcesReadParams, + ResourcesReadResponse + ); + request!("resources/list", ResourcesList, (), ResourcesListResponse); + request!( + "logging/setLevel", + LoggingSetLevel, + LoggingSetLevelParams, + () + ); + request!( + "prompts/get", + PromptsGet, + PromptsGetParams, + PromptsGetResponse + ); + request!("prompts/list", PromptsList, (), PromptsListResponse); + request!( + "completion/complete", + CompletionComplete, + CompletionCompleteParams, + CompletionCompleteResponse + ); + request!("ping", Ping, (), ()); + request!("tools/list", ListTools, (), ListToolsResponse); + request!( + "resources/templates/list", + ListResourceTemplates, + (), + ListResourceTemplatesResponse + ); + request!("roots/list", ListRoots, (), ListRootsResponse); } -impl RequestType { - pub fn as_str(&self) -> &'static str { - match self { - RequestType::Initialize => "initialize", - RequestType::CallTool => "tools/call", - RequestType::ResourcesUnsubscribe => "resources/unsubscribe", - RequestType::ResourcesSubscribe => "resources/subscribe", - RequestType::ResourcesRead => "resources/read", - RequestType::ResourcesList => "resources/list", - RequestType::LoggingSetLevel => "logging/setLevel", - RequestType::PromptsGet => "prompts/get", - RequestType::PromptsList => "prompts/list", - RequestType::CompletionComplete => "completion/complete", - RequestType::Ping => "ping", - RequestType::ListTools => "tools/list", - RequestType::ListResourceTemplates => "resources/templates/list", - RequestType::ListRoots => "roots/list", - } - } -} - -impl TryFrom<&str> for RequestType { - type Error = (); - - fn try_from(s: &str) -> Result { - match s { - "initialize" => Ok(RequestType::Initialize), - "tools/call" => Ok(RequestType::CallTool), - "resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe), - "resources/subscribe" => Ok(RequestType::ResourcesSubscribe), - "resources/read" => Ok(RequestType::ResourcesRead), - "resources/list" => Ok(RequestType::ResourcesList), - "logging/setLevel" => Ok(RequestType::LoggingSetLevel), - "prompts/get" => Ok(RequestType::PromptsGet), - "prompts/list" => Ok(RequestType::PromptsList), - "completion/complete" => Ok(RequestType::CompletionComplete), - "ping" => Ok(RequestType::Ping), - "tools/list" => Ok(RequestType::ListTools), - "resources/templates/list" => Ok(RequestType::ListResourceTemplates), - "roots/list" => Ok(RequestType::ListRoots), - _ => Err(()), - } - } +pub trait Request { + type Params: DeserializeOwned + Serialize + Send + Sync + 'static; + type Response: DeserializeOwned + Serialize + Send + Sync + 'static; + const METHOD: &'static str; } #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct ProtocolVersion(pub String); -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InitializeParams { pub protocol_version: ProtocolVersion, @@ -80,7 +96,7 @@ pub struct InitializeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CallToolParams { pub name: String, @@ -90,7 +106,7 @@ pub struct CallToolParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesUnsubscribeParams { pub uri: Url, @@ -98,7 +114,7 @@ pub struct ResourcesUnsubscribeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesSubscribeParams { pub uri: Url, @@ -106,7 +122,7 @@ pub struct ResourcesSubscribeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesReadParams { pub uri: Url, @@ -114,7 +130,7 @@ pub struct ResourcesReadParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct LoggingSetLevelParams { pub level: LoggingLevel, @@ -122,7 +138,7 @@ pub struct LoggingSetLevelParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsGetParams { pub name: String, @@ -132,37 +148,40 @@ pub struct PromptsGetParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionCompleteParams { - pub r#ref: CompletionReference, + #[serde(rename = "ref")] + pub reference: CompletionReference, pub argument: CompletionArgument, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum CompletionReference { Prompt(PromptReference), Resource(ResourceReference), } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptReference { - pub r#type: PromptReferenceType, + #[serde(rename = "type")] + pub ty: PromptReferenceType, pub name: String, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourceReference { - pub r#type: PromptReferenceType, + #[serde(rename = "type")] + pub ty: PromptReferenceType, pub uri: Url, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum PromptReferenceType { #[serde(rename = "ref/prompt")] @@ -171,7 +190,7 @@ pub enum PromptReferenceType { Resource, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionArgument { pub name: String, @@ -188,7 +207,7 @@ pub struct InitializeResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesReadResponse { pub contents: Vec, @@ -196,14 +215,14 @@ pub struct ResourcesReadResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum ResourceContentsType { Text(TextResourceContents), Blob(BlobResourceContents), } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesListResponse { pub resources: Vec, @@ -220,7 +239,7 @@ pub struct SamplingMessage { pub content: MessageContent, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CreateMessageRequest { pub messages: Vec, @@ -296,7 +315,7 @@ pub struct MessageAnnotations { pub priority: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsGetResponse { #[serde(skip_serializing_if = "Option::is_none")] @@ -306,7 +325,7 @@ pub struct PromptsGetResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsListResponse { pub prompts: Vec, @@ -316,7 +335,7 @@ pub struct PromptsListResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionCompleteResponse { pub completion: CompletionResult, @@ -324,7 +343,7 @@ pub struct CompletionCompleteResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionResult { pub values: Vec, @@ -336,7 +355,7 @@ pub struct CompletionResult { pub meta: Option>, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Prompt { pub name: String, @@ -346,7 +365,7 @@ pub struct Prompt { pub arguments: Option>, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptArgument { pub name: String, @@ -509,7 +528,7 @@ pub struct ModelHint { pub name: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum NotificationType { Initialized, @@ -589,7 +608,7 @@ pub struct Completion { pub total: CompletionTotal, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CallToolResponse { pub content: Vec, @@ -620,7 +639,7 @@ pub struct ListToolsResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListResourceTemplatesResponse { pub resource_templates: Vec, @@ -630,7 +649,7 @@ pub struct ListResourceTemplatesResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListRootsResponse { pub roots: Vec, diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 7e506d2184..f208af54d7 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -91,6 +91,7 @@ workspace-hack.workspace = true [dev-dependencies] client = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } +context_server = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] } dap = { workspace = true, features = ["test-support"] } dap_adapters = { workspace = true, features = ["test-support"] } diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index aac9d5d460..34d6abb96c 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -499,17 +499,10 @@ impl ContextServerStore { mod tests { use super::*; use crate::{FakeFs, Project, project_settings::ProjectSettings}; - use context_server::{ - transport::Transport, - types::{ - self, Implementation, InitializeResponse, ProtocolVersion, RequestType, - ServerCapabilities, - }, - }; - use futures::{Stream, StreamExt as _, lock::Mutex}; - use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _}; + use context_server::test::create_fake_transport; + use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use serde_json::json; - use std::{cell::RefCell, pin::Pin, rc::Rc}; + use std::{cell::RefCell, rc::Rc}; use util::path; #[gpui::test] @@ -532,33 +525,17 @@ mod tests { ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) }); - let server_1_id = ContextServerId("mcp-1".into()); - let server_2_id = ContextServerId("mcp-2".into()); + let server_1_id = ContextServerId(SERVER_1_ID.into()); + let server_2_id = ContextServerId(SERVER_2_ID.into()); - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-1".to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-2".to_string())) - } - _ => None, - }, - )); - - let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone())); - let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone())); + let server_1 = Arc::new(ContextServer::new( + server_1_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_2 = Arc::new(ContextServer::new( + server_2_id.clone(), + Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())), + )); store .update(cx, |store, cx| store.start_server(server_1, cx)) @@ -627,33 +604,17 @@ mod tests { ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) }); - let server_1_id = ContextServerId("mcp-1".into()); - let server_2_id = ContextServerId("mcp-2".into()); + let server_1_id = ContextServerId(SERVER_1_ID.into()); + let server_2_id = ContextServerId(SERVER_2_ID.into()); - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-1".to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-2".to_string())) - } - _ => None, - }, - )); - - let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone())); - let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone())); + let server_1 = Arc::new(ContextServer::new( + server_1_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_2 = Arc::new(ContextServer::new( + server_2_id.clone(), + Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())), + )); let _server_events = assert_server_events( &store, @@ -702,30 +663,14 @@ mod tests { let server_id = ContextServerId(SERVER_1_ID.into()); - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(SERVER_1_ID.to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(SERVER_1_ID.to_string())) - } - _ => None, - }, - )); - - let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1)); - let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2)); + let server_with_same_id_1 = Arc::new(ContextServer::new( + server_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_with_same_id_2 = Arc::new(ContextServer::new( + server_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); // If we start another server with the same id, we should report that we stopped the previous one let _server_events = assert_server_events( @@ -794,16 +739,10 @@ mod tests { let store = cx.new(|cx| { ContextServerStore::test_maintain_server_loop( Box::new(move |id, _| { - let transport = FakeTransport::new(executor.clone(), { - let id = id.0.clone(); - move |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(id.clone().to_string())) - } - _ => None, - } - }); - Arc::new(ContextServer::new(id.clone(), Arc::new(transport))) + Arc::new(ContextServer::new( + id.clone(), + Arc::new(create_fake_transport(id.0.to_string(), executor.clone())), + )) }), registry.clone(), project.read(cx).worktree_store(), @@ -1033,99 +972,4 @@ mod tests { (fs, project) } - - fn create_initialize_response(server_name: String) -> serde_json::Value { - serde_json::to_value(&InitializeResponse { - protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()), - server_info: Implementation { - name: server_name, - version: "1.0.0".to_string(), - }, - capabilities: ServerCapabilities::default(), - meta: None, - }) - .unwrap() - } - - struct FakeTransport { - on_request: Arc< - dyn Fn(u64, Option, serde_json::Value) -> Option - + Send - + Sync, - >, - tx: futures::channel::mpsc::UnboundedSender, - rx: Arc>>, - executor: BackgroundExecutor, - } - - impl FakeTransport { - fn new( - executor: BackgroundExecutor, - on_request: impl Fn( - u64, - Option, - serde_json::Value, - ) -> Option - + 'static - + Send - + Sync, - ) -> Self { - let (tx, rx) = futures::channel::mpsc::unbounded(); - Self { - on_request: Arc::new(on_request), - tx, - rx: Arc::new(Mutex::new(rx)), - executor, - } - } - } - - #[async_trait::async_trait] - impl Transport for FakeTransport { - async fn send(&self, message: String) -> Result<()> { - if let Ok(msg) = serde_json::from_str::(&message) { - let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); - - if let Some(method) = msg.get("method") { - let request_type = method - .as_str() - .and_then(|method| types::RequestType::try_from(method).ok()); - if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) { - let response = serde_json::json!({ - "jsonrpc": "2.0", - "id": id, - "result": payload - }); - - self.tx - .unbounded_send(response.to_string()) - .context("sending a message")?; - } - } - } - Ok(()) - } - - fn receive(&self) -> Pin + Send>> { - let rx = self.rx.clone(); - let executor = self.executor.clone(); - Box::pin(futures::stream::unfold(rx, move |rx| { - let executor = executor.clone(); - async move { - let mut rx_guard = rx.lock().await; - executor.simulate_random_delay().await; - if let Some(message) = rx_guard.next().await { - drop(rx_guard); - Some((message, rx)) - } else { - None - } - } - })) - } - - fn receive_err(&self) -> Pin + Send>> { - Box::pin(futures::stream::empty()) - } - } }