context server: Make requests type safe (#32254)

This changes the context server crate so that the input/output for a
request are encoded at the type level, similar to how it is done for LSP
requests.

This also makes it easier to write tests that mock context servers, e.g.
you can write something like this now when using the `test-support`
feature of the `context-server` crate:

```rust
create_fake_transport("mcp-1", cx.background_executor())
    .on_request::<context_server::types::request::PromptsList>(|_params| {
        PromptsListResponse {
            prompts: vec![/* some prompts */],
            ..
        }
    })
```

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-06-06 17:47:21 +02:00 committed by GitHub
parent 454adfacae
commit 95d78ff8d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 320 additions and 433 deletions

View file

@ -104,7 +104,15 @@ impl Tool for ContextServerTool {
tool_name, tool_name,
arguments arguments
); );
let response = protocol.run_tool(tool_name, arguments).await?; let response = protocol
.request::<context_server::types::request::CallTool>(
context_server::types::CallToolParams {
name: tool_name,
arguments,
meta: None,
},
)
.await?;
let mut result = String::new(); let mut result = String::new();
for content in response.content { for content in response.content {

View file

@ -566,10 +566,14 @@ impl ThreadStore {
}; };
if protocol.capable(context_server::protocol::ServerCapability::Tools) { if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() { if let Some(response) = protocol
.request::<context_server::types::request::ListTools>(())
.await
.log_err()
{
let tool_ids = tool_working_set let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| { .update(cx, |tool_working_set, _| {
tools response
.tools .tools
.into_iter() .into_iter()
.map(|tool| { .map(|tool| {

View file

@ -864,8 +864,13 @@ impl ContextStore {
}; };
if protocol.capable(context_server::protocol::ServerCapability::Prompts) { if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() { if let Some(response) = protocol
let slash_command_ids = prompts .request::<context_server::types::request::PromptsList>(())
.await
.log_err()
{
let slash_command_ids = response
.prompts
.into_iter() .into_iter()
.filter(assistant_slash_commands::acceptable_prompt) .filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| { .map(|prompt| {

View file

@ -86,20 +86,26 @@ impl SlashCommand for ContextServerSlashCommand {
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
let protocol = server.client().context("Context server not initialized")?; let protocol = server.client().context("Context server not initialized")?;
let completion_result = protocol let response = protocol
.completion( .request::<context_server::types::request::CompletionComplete>(
context_server::types::CompletionReference::Prompt( context_server::types::CompletionCompleteParams {
reference: context_server::types::CompletionReference::Prompt(
context_server::types::PromptReference { context_server::types::PromptReference {
r#type: context_server::types::PromptReferenceType::Prompt, ty: context_server::types::PromptReferenceType::Prompt,
name: prompt_name, name: prompt_name,
}, },
), ),
arg_name, argument: context_server::types::CompletionArgument {
arg_value, name: arg_name,
value: arg_value,
},
meta: None,
},
) )
.await?; .await?;
let completions = completion_result let completions = response
.completion
.values .values
.into_iter() .into_iter()
.map(|value| ArgumentCompletion { .map(|value| ArgumentCompletion {
@ -138,10 +144,18 @@ impl SlashCommand for ContextServerSlashCommand {
if let Some(server) = store.get_running_server(&server_id) { if let Some(server) = store.get_running_server(&server_id) {
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
let protocol = server.client().context("Context server not initialized")?; let protocol = server.client().context("Context server not initialized")?;
let result = protocol.run_prompt(&prompt_name, prompt_args).await?; let response = protocol
.request::<context_server::types::request::PromptsGet>(
context_server::types::PromptsGetParams {
name: prompt_name.clone(),
arguments: Some(prompt_args),
meta: None,
},
)
.await?;
anyhow::ensure!( anyhow::ensure!(
result response
.messages .messages
.iter() .iter()
.all(|msg| matches!(msg.role, context_server::types::Role::User)), .all(|msg| matches!(msg.role, context_server::types::Role::User)),
@ -149,7 +163,7 @@ impl SlashCommand for ContextServerSlashCommand {
); );
// Extract text from user messages into a single prompt string // Extract text from user messages into a single prompt string
let mut prompt = result let mut prompt = response
.messages .messages
.into_iter() .into_iter()
.filter_map(|msg| match msg.content { .filter_map(|msg| match msg.content {
@ -167,7 +181,7 @@ impl SlashCommand for ContextServerSlashCommand {
range: 0..(prompt.len()), range: 0..(prompt.len()),
icon: IconName::ZedAssistant, icon: IconName::ZedAssistant,
label: SharedString::from( label: SharedString::from(
result response
.description .description
.unwrap_or(format!("Result from {}", prompt_name)), .unwrap_or(format!("Result from {}", prompt_name)),
), ),

View file

@ -11,6 +11,9 @@ workspace = true
[lib] [lib]
path = "src/context_server.rs" path = "src/context_server.rs"
[features]
test-support = []
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
async-trait.workspace = true async-trait.workspace = true

View file

@ -1,5 +1,7 @@
pub mod client; pub mod client;
pub mod protocol; pub mod protocol;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
pub mod transport; pub mod transport;
pub mod types; pub mod types;

View file

@ -6,10 +6,9 @@
//! of messages. //! of messages.
use anyhow::Result; use anyhow::Result;
use collections::HashMap;
use crate::client::Client; use crate::client::Client;
use crate::types; use crate::types::{self, Request};
pub struct ModelContextProtocol { pub struct ModelContextProtocol {
inner: Client, inner: Client,
@ -43,7 +42,7 @@ impl ModelContextProtocol {
let response: types::InitializeResponse = self let response: types::InitializeResponse = self
.inner .inner
.request(types::RequestType::Initialize.as_str(), params) .request(types::request::Initialize::METHOD, params)
.await?; .await?;
anyhow::ensure!( anyhow::ensure!(
@ -94,137 +93,7 @@ impl InitializedContextServerProtocol {
} }
} }
fn check_capability(&self, capability: ServerCapability) -> Result<()> { pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
anyhow::ensure!( self.inner.request(T::METHOD, params).await
self.capable(capability),
"Server does not support {capability:?} capability"
);
Ok(())
}
/// List the MCP prompts.
pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
self.check_capability(ServerCapability::Prompts)?;
let response: types::PromptsListResponse = self
.inner
.request(
types::RequestType::PromptsList.as_str(),
serde_json::json!({}),
)
.await?;
Ok(response.prompts)
}
/// List the MCP resources.
pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
self.check_capability(ServerCapability::Resources)?;
let response: types::ResourcesListResponse = self
.inner
.request(
types::RequestType::ResourcesList.as_str(),
serde_json::json!({}),
)
.await?;
Ok(response)
}
/// Executes a prompt with the given arguments and returns the result.
pub async fn run_prompt<P: AsRef<str>>(
&self,
prompt: P,
arguments: HashMap<String, String>,
) -> Result<types::PromptsGetResponse> {
self.check_capability(ServerCapability::Prompts)?;
let params = types::PromptsGetParams {
name: prompt.as_ref().to_string(),
arguments: Some(arguments),
meta: None,
};
let response: types::PromptsGetResponse = self
.inner
.request(types::RequestType::PromptsGet.as_str(), params)
.await?;
Ok(response)
}
pub async fn completion<P: Into<String>>(
&self,
reference: types::CompletionReference,
argument: P,
value: P,
) -> Result<types::Completion> {
let params = types::CompletionCompleteParams {
r#ref: reference,
argument: types::CompletionArgument {
name: argument.into(),
value: value.into(),
},
meta: None,
};
let result: types::CompletionCompleteResponse = self
.inner
.request(types::RequestType::CompletionComplete.as_str(), params)
.await?;
let completion = types::Completion {
values: result.completion.values,
total: types::CompletionTotal::from_options(
result.completion.has_more,
result.completion.total,
),
};
Ok(completion)
}
/// List MCP tools.
pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
self.check_capability(ServerCapability::Tools)?;
let response = self
.inner
.request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
.await?;
Ok(response)
}
/// Executes a tool with the given arguments
pub async fn run_tool<P: AsRef<str>>(
&self,
tool: P,
arguments: Option<HashMap<String, serde_json::Value>>,
) -> Result<types::CallToolResponse> {
self.check_capability(ServerCapability::Tools)?;
let params = types::CallToolParams {
name: tool.as_ref().to_string(),
arguments,
meta: None,
};
let response: types::CallToolResponse = self
.inner
.request(types::RequestType::CallTool.as_str(), params)
.await?;
Ok(response)
}
}
impl InitializedContextServerProtocol {
pub async fn request<R: serde::de::DeserializeOwned>(
&self,
method: &str,
params: impl serde::Serialize,
) -> Result<R> {
self.inner.request(method, params).await
} }
} }

View file

@ -0,0 +1,118 @@
use anyhow::Context as _;
use collections::HashMap;
use futures::{Stream, StreamExt as _, lock::Mutex};
use gpui::BackgroundExecutor;
use std::{pin::Pin, sync::Arc};
use crate::{
transport::Transport,
types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities},
};
pub fn create_fake_transport(
name: impl Into<String>,
executor: BackgroundExecutor,
) -> FakeTransport {
let name = name.into();
FakeTransport::new(executor).on_request::<crate::types::request::Initialize>(move |_params| {
create_initialize_response(name.clone())
})
}
fn create_initialize_response(server_name: String) -> InitializeResponse {
InitializeResponse {
protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()),
server_info: Implementation {
name: server_name,
version: "1.0.0".to_string(),
},
capabilities: ServerCapabilities::default(),
meta: None,
}
}
pub struct FakeTransport {
request_handlers:
HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>,
tx: futures::channel::mpsc::UnboundedSender<String>,
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
executor: BackgroundExecutor,
}
impl FakeTransport {
pub fn new(executor: BackgroundExecutor) -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self {
request_handlers: Default::default(),
tx,
rx: Arc::new(Mutex::new(rx)),
executor,
}
}
pub fn on_request<T: crate::types::Request>(
mut self,
handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static,
) -> Self {
self.request_handlers.insert(
T::METHOD,
Arc::new(move |value| {
let params = value.get("params").expect("Missing parameters").clone();
let params: T::Params =
serde_json::from_value(params).expect("Invalid parameters received");
let response = handler(params);
serde_json::to_value(response).unwrap()
}),
);
self
}
}
#[async_trait::async_trait]
impl Transport for FakeTransport {
async fn send(&self, message: String) -> anyhow::Result<()> {
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
if let Some(method) = msg.get("method") {
let method = method.as_str().expect("Invalid method received");
if let Some(handler) = self.request_handlers.get(method) {
let payload = handler(msg);
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": payload
});
self.tx
.unbounded_send(response.to_string())
.context("sending a message")?;
} else {
log::debug!("No handler registered for MCP request '{method}'");
}
}
}
Ok(())
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
let rx = self.rx.clone();
let executor = self.executor.clone();
Box::pin(futures::stream::unfold(rx, move |rx| {
let executor = executor.clone();
async move {
let mut rx_guard = rx.lock().await;
executor.simulate_random_delay().await;
if let Some(message) = rx_guard.next().await {
drop(rx_guard);
Some((message, rx))
} else {
None
}
}
}))
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(futures::stream::empty())
}
}

View file

@ -1,76 +1,92 @@
use collections::HashMap; use collections::HashMap;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use url::Url; use url::Url;
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05"; pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
pub enum RequestType { pub mod request {
use super::*;
macro_rules! request {
($method:expr, $name:ident, $params:ty, $response:ty) => {
pub struct $name;
impl Request for $name {
type Params = $params;
type Response = $response;
const METHOD: &'static str = $method;
}
};
}
request!(
"initialize",
Initialize, Initialize,
CallTool, InitializeParams,
InitializeResponse
);
request!("tools/call", CallTool, CallToolParams, CallToolResponse);
request!(
"resources/unsubscribe",
ResourcesUnsubscribe, ResourcesUnsubscribe,
ResourcesUnsubscribeParams,
()
);
request!(
"resources/subscribe",
ResourcesSubscribe, ResourcesSubscribe,
ResourcesSubscribeParams,
()
);
request!(
"resources/read",
ResourcesRead, ResourcesRead,
ResourcesList, ResourcesReadParams,
ResourcesReadResponse
);
request!("resources/list", ResourcesList, (), ResourcesListResponse);
request!(
"logging/setLevel",
LoggingSetLevel, LoggingSetLevel,
LoggingSetLevelParams,
()
);
request!(
"prompts/get",
PromptsGet, PromptsGet,
PromptsList, PromptsGetParams,
PromptsGetResponse
);
request!("prompts/list", PromptsList, (), PromptsListResponse);
request!(
"completion/complete",
CompletionComplete, CompletionComplete,
Ping, CompletionCompleteParams,
ListTools, CompletionCompleteResponse
);
request!("ping", Ping, (), ());
request!("tools/list", ListTools, (), ListToolsResponse);
request!(
"resources/templates/list",
ListResourceTemplates, ListResourceTemplates,
ListRoots, (),
ListResourceTemplatesResponse
);
request!("roots/list", ListRoots, (), ListRootsResponse);
} }
impl RequestType { pub trait Request {
pub fn as_str(&self) -> &'static str { type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
match self { type Response: DeserializeOwned + Serialize + Send + Sync + 'static;
RequestType::Initialize => "initialize", const METHOD: &'static str;
RequestType::CallTool => "tools/call",
RequestType::ResourcesUnsubscribe => "resources/unsubscribe",
RequestType::ResourcesSubscribe => "resources/subscribe",
RequestType::ResourcesRead => "resources/read",
RequestType::ResourcesList => "resources/list",
RequestType::LoggingSetLevel => "logging/setLevel",
RequestType::PromptsGet => "prompts/get",
RequestType::PromptsList => "prompts/list",
RequestType::CompletionComplete => "completion/complete",
RequestType::Ping => "ping",
RequestType::ListTools => "tools/list",
RequestType::ListResourceTemplates => "resources/templates/list",
RequestType::ListRoots => "roots/list",
}
}
}
impl TryFrom<&str> for RequestType {
type Error = ();
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"initialize" => Ok(RequestType::Initialize),
"tools/call" => Ok(RequestType::CallTool),
"resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe),
"resources/subscribe" => Ok(RequestType::ResourcesSubscribe),
"resources/read" => Ok(RequestType::ResourcesRead),
"resources/list" => Ok(RequestType::ResourcesList),
"logging/setLevel" => Ok(RequestType::LoggingSetLevel),
"prompts/get" => Ok(RequestType::PromptsGet),
"prompts/list" => Ok(RequestType::PromptsList),
"completion/complete" => Ok(RequestType::CompletionComplete),
"ping" => Ok(RequestType::Ping),
"tools/list" => Ok(RequestType::ListTools),
"resources/templates/list" => Ok(RequestType::ListResourceTemplates),
"roots/list" => Ok(RequestType::ListRoots),
_ => Err(()),
}
}
} }
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)] #[serde(transparent)]
pub struct ProtocolVersion(pub String); pub struct ProtocolVersion(pub String);
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct InitializeParams { pub struct InitializeParams {
pub protocol_version: ProtocolVersion, pub protocol_version: ProtocolVersion,
@ -80,7 +96,7 @@ pub struct InitializeParams {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CallToolParams { pub struct CallToolParams {
pub name: String, pub name: String,
@ -90,7 +106,7 @@ pub struct CallToolParams {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesUnsubscribeParams { pub struct ResourcesUnsubscribeParams {
pub uri: Url, pub uri: Url,
@ -98,7 +114,7 @@ pub struct ResourcesUnsubscribeParams {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesSubscribeParams { pub struct ResourcesSubscribeParams {
pub uri: Url, pub uri: Url,
@ -106,7 +122,7 @@ pub struct ResourcesSubscribeParams {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesReadParams { pub struct ResourcesReadParams {
pub uri: Url, pub uri: Url,
@ -114,7 +130,7 @@ pub struct ResourcesReadParams {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct LoggingSetLevelParams { pub struct LoggingSetLevelParams {
pub level: LoggingLevel, pub level: LoggingLevel,
@ -122,7 +138,7 @@ pub struct LoggingSetLevelParams {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PromptsGetParams { pub struct PromptsGetParams {
pub name: String, pub name: String,
@ -132,37 +148,40 @@ pub struct PromptsGetParams {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CompletionCompleteParams { pub struct CompletionCompleteParams {
pub r#ref: CompletionReference, #[serde(rename = "ref")]
pub reference: CompletionReference,
pub argument: CompletionArgument, pub argument: CompletionArgument,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum CompletionReference { pub enum CompletionReference {
Prompt(PromptReference), Prompt(PromptReference),
Resource(ResourceReference), Resource(ResourceReference),
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PromptReference { pub struct PromptReference {
pub r#type: PromptReferenceType, #[serde(rename = "type")]
pub ty: PromptReferenceType,
pub name: String, pub name: String,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourceReference { pub struct ResourceReference {
pub r#type: PromptReferenceType, #[serde(rename = "type")]
pub ty: PromptReferenceType,
pub uri: Url, pub uri: Url,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum PromptReferenceType { pub enum PromptReferenceType {
#[serde(rename = "ref/prompt")] #[serde(rename = "ref/prompt")]
@ -171,7 +190,7 @@ pub enum PromptReferenceType {
Resource, Resource,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CompletionArgument { pub struct CompletionArgument {
pub name: String, pub name: String,
@ -188,7 +207,7 @@ pub struct InitializeResponse {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesReadResponse { pub struct ResourcesReadResponse {
pub contents: Vec<ResourceContentsType>, pub contents: Vec<ResourceContentsType>,
@ -196,14 +215,14 @@ pub struct ResourcesReadResponse {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum ResourceContentsType { pub enum ResourceContentsType {
Text(TextResourceContents), Text(TextResourceContents),
Blob(BlobResourceContents), Blob(BlobResourceContents),
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ResourcesListResponse { pub struct ResourcesListResponse {
pub resources: Vec<Resource>, pub resources: Vec<Resource>,
@ -220,7 +239,7 @@ pub struct SamplingMessage {
pub content: MessageContent, pub content: MessageContent,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CreateMessageRequest { pub struct CreateMessageRequest {
pub messages: Vec<SamplingMessage>, pub messages: Vec<SamplingMessage>,
@ -296,7 +315,7 @@ pub struct MessageAnnotations {
pub priority: Option<f64>, pub priority: Option<f64>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PromptsGetResponse { pub struct PromptsGetResponse {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@ -306,7 +325,7 @@ pub struct PromptsGetResponse {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PromptsListResponse { pub struct PromptsListResponse {
pub prompts: Vec<Prompt>, pub prompts: Vec<Prompt>,
@ -316,7 +335,7 @@ pub struct PromptsListResponse {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CompletionCompleteResponse { pub struct CompletionCompleteResponse {
pub completion: CompletionResult, pub completion: CompletionResult,
@ -324,7 +343,7 @@ pub struct CompletionCompleteResponse {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CompletionResult { pub struct CompletionResult {
pub values: Vec<String>, pub values: Vec<String>,
@ -336,7 +355,7 @@ pub struct CompletionResult {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct Prompt { pub struct Prompt {
pub name: String, pub name: String,
@ -346,7 +365,7 @@ pub struct Prompt {
pub arguments: Option<Vec<PromptArgument>>, pub arguments: Option<Vec<PromptArgument>>,
} }
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct PromptArgument { pub struct PromptArgument {
pub name: String, pub name: String,
@ -509,7 +528,7 @@ pub struct ModelHint {
pub name: Option<String>, pub name: Option<String>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub enum NotificationType { pub enum NotificationType {
Initialized, Initialized,
@ -589,7 +608,7 @@ pub struct Completion {
pub total: CompletionTotal, pub total: CompletionTotal,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct CallToolResponse { pub struct CallToolResponse {
pub content: Vec<ToolResponseContent>, pub content: Vec<ToolResponseContent>,
@ -620,7 +639,7 @@ pub struct ListToolsResponse {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ListResourceTemplatesResponse { pub struct ListResourceTemplatesResponse {
pub resource_templates: Vec<ResourceTemplate>, pub resource_templates: Vec<ResourceTemplate>,
@ -630,7 +649,7 @@ pub struct ListResourceTemplatesResponse {
pub meta: Option<HashMap<String, serde_json::Value>>, pub meta: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ListRootsResponse { pub struct ListRootsResponse {
pub roots: Vec<Root>, pub roots: Vec<Root>,

View file

@ -91,6 +91,7 @@ workspace-hack.workspace = true
[dev-dependencies] [dev-dependencies]
client = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] }
context_server = { workspace = true, features = ["test-support"] }
buffer_diff = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] }
dap = { workspace = true, features = ["test-support"] } dap = { workspace = true, features = ["test-support"] }
dap_adapters = { workspace = true, features = ["test-support"] } dap_adapters = { workspace = true, features = ["test-support"] }

View file

@ -499,17 +499,10 @@ impl ContextServerStore {
mod tests { mod tests {
use super::*; use super::*;
use crate::{FakeFs, Project, project_settings::ProjectSettings}; use crate::{FakeFs, Project, project_settings::ProjectSettings};
use context_server::{ use context_server::test::create_fake_transport;
transport::Transport, use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
types::{
self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
ServerCapabilities,
},
};
use futures::{Stream, StreamExt as _, lock::Mutex};
use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
use serde_json::json; use serde_json::json;
use std::{cell::RefCell, pin::Pin, rc::Rc}; use std::{cell::RefCell, rc::Rc};
use util::path; use util::path;
#[gpui::test] #[gpui::test]
@ -532,34 +525,18 @@ mod tests {
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
}); });
let server_1_id = ContextServerId("mcp-1".into()); let server_1_id = ContextServerId(SERVER_1_ID.into());
let server_2_id = ContextServerId("mcp-2".into()); let server_2_id = ContextServerId(SERVER_2_ID.into());
let transport_1 = let server_1 = Arc::new(ContextServer::new(
Arc::new(FakeTransport::new( server_1_id.clone(),
cx.executor(), Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-1".to_string()))
}
_ => None,
},
)); ));
let server_2 = Arc::new(ContextServer::new(
let transport_2 = server_2_id.clone(),
Arc::new(FakeTransport::new( Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-2".to_string()))
}
_ => None,
},
)); ));
let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
store store
.update(cx, |store, cx| store.start_server(server_1, cx)) .update(cx, |store, cx| store.start_server(server_1, cx))
.unwrap(); .unwrap();
@ -627,34 +604,18 @@ mod tests {
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
}); });
let server_1_id = ContextServerId("mcp-1".into()); let server_1_id = ContextServerId(SERVER_1_ID.into());
let server_2_id = ContextServerId("mcp-2".into()); let server_2_id = ContextServerId(SERVER_2_ID.into());
let transport_1 = let server_1 = Arc::new(ContextServer::new(
Arc::new(FakeTransport::new( server_1_id.clone(),
cx.executor(), Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-1".to_string()))
}
_ => None,
},
)); ));
let server_2 = Arc::new(ContextServer::new(
let transport_2 = server_2_id.clone(),
Arc::new(FakeTransport::new( Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-2".to_string()))
}
_ => None,
},
)); ));
let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
let _server_events = assert_server_events( let _server_events = assert_server_events(
&store, &store,
vec![ vec![
@ -702,31 +663,15 @@ mod tests {
let server_id = ContextServerId(SERVER_1_ID.into()); let server_id = ContextServerId(SERVER_1_ID.into());
let transport_1 = let server_with_same_id_1 = Arc::new(ContextServer::new(
Arc::new(FakeTransport::new( server_id.clone(),
cx.executor(), Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response(SERVER_1_ID.to_string()))
}
_ => None,
},
)); ));
let server_with_same_id_2 = Arc::new(ContextServer::new(
let transport_2 = server_id.clone(),
Arc::new(FakeTransport::new( Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response(SERVER_1_ID.to_string()))
}
_ => None,
},
)); ));
let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
// If we start another server with the same id, we should report that we stopped the previous one // If we start another server with the same id, we should report that we stopped the previous one
let _server_events = assert_server_events( let _server_events = assert_server_events(
&store, &store,
@ -794,16 +739,10 @@ mod tests {
let store = cx.new(|cx| { let store = cx.new(|cx| {
ContextServerStore::test_maintain_server_loop( ContextServerStore::test_maintain_server_loop(
Box::new(move |id, _| { Box::new(move |id, _| {
let transport = FakeTransport::new(executor.clone(), { Arc::new(ContextServer::new(
let id = id.0.clone(); id.clone(),
move |_, request_type, _| match request_type { Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
Some(RequestType::Initialize) => { ))
Some(create_initialize_response(id.clone().to_string()))
}
_ => None,
}
});
Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
}), }),
registry.clone(), registry.clone(),
project.read(cx).worktree_store(), project.read(cx).worktree_store(),
@ -1033,99 +972,4 @@ mod tests {
(fs, project) (fs, project)
} }
fn create_initialize_response(server_name: String) -> serde_json::Value {
serde_json::to_value(&InitializeResponse {
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
server_info: Implementation {
name: server_name,
version: "1.0.0".to_string(),
},
capabilities: ServerCapabilities::default(),
meta: None,
})
.unwrap()
}
struct FakeTransport {
on_request: Arc<
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
+ Send
+ Sync,
>,
tx: futures::channel::mpsc::UnboundedSender<String>,
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
executor: BackgroundExecutor,
}
impl FakeTransport {
fn new(
executor: BackgroundExecutor,
on_request: impl Fn(
u64,
Option<RequestType>,
serde_json::Value,
) -> Option<serde_json::Value>
+ 'static
+ Send
+ Sync,
) -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self {
on_request: Arc::new(on_request),
tx,
rx: Arc::new(Mutex::new(rx)),
executor,
}
}
}
#[async_trait::async_trait]
impl Transport for FakeTransport {
async fn send(&self, message: String) -> Result<()> {
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
if let Some(method) = msg.get("method") {
let request_type = method
.as_str()
.and_then(|method| types::RequestType::try_from(method).ok());
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": payload
});
self.tx
.unbounded_send(response.to_string())
.context("sending a message")?;
}
}
}
Ok(())
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
let rx = self.rx.clone();
let executor = self.executor.clone();
Box::pin(futures::stream::unfold(rx, move |rx| {
let executor = executor.clone();
async move {
let mut rx_guard = rx.lock().await;
executor.simulate_random_delay().await;
if let Some(message) = rx_guard.next().await {
drop(rx_guard);
Some((message, rx))
} else {
None
}
}
}))
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(futures::stream::empty())
}
}
} }