context server: Make requests type safe (#32254)
This changes the context server crate so that the input/output for a request are encoded at the type level, similar to how it is done for LSP requests. This also makes it easier to write tests that mock context servers, e.g. you can write something like this now when using the `test-support` feature of the `context-server` crate: ```rust create_fake_transport("mcp-1", cx.background_executor()) .on_request::<context_server::types::request::PromptsList>(|_params| { PromptsListResponse { prompts: vec![/* some prompts */], .. } }) ``` Release Notes: - N/A
This commit is contained in:
parent
454adfacae
commit
95d78ff8d5
11 changed files with 320 additions and 433 deletions
|
@ -104,7 +104,15 @@ impl Tool for ContextServerTool {
|
||||||
tool_name,
|
tool_name,
|
||||||
arguments
|
arguments
|
||||||
);
|
);
|
||||||
let response = protocol.run_tool(tool_name, arguments).await?;
|
let response = protocol
|
||||||
|
.request::<context_server::types::request::CallTool>(
|
||||||
|
context_server::types::CallToolParams {
|
||||||
|
name: tool_name,
|
||||||
|
arguments,
|
||||||
|
meta: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let mut result = String::new();
|
let mut result = String::new();
|
||||||
for content in response.content {
|
for content in response.content {
|
||||||
|
|
|
@ -566,10 +566,14 @@ impl ThreadStore {
|
||||||
};
|
};
|
||||||
|
|
||||||
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
|
||||||
if let Some(tools) = protocol.list_tools().await.log_err() {
|
if let Some(response) = protocol
|
||||||
|
.request::<context_server::types::request::ListTools>(())
|
||||||
|
.await
|
||||||
|
.log_err()
|
||||||
|
{
|
||||||
let tool_ids = tool_working_set
|
let tool_ids = tool_working_set
|
||||||
.update(cx, |tool_working_set, _| {
|
.update(cx, |tool_working_set, _| {
|
||||||
tools
|
response
|
||||||
.tools
|
.tools
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|tool| {
|
.map(|tool| {
|
||||||
|
|
|
@ -864,8 +864,13 @@ impl ContextStore {
|
||||||
};
|
};
|
||||||
|
|
||||||
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
|
||||||
if let Some(prompts) = protocol.list_prompts().await.log_err() {
|
if let Some(response) = protocol
|
||||||
let slash_command_ids = prompts
|
.request::<context_server::types::request::PromptsList>(())
|
||||||
|
.await
|
||||||
|
.log_err()
|
||||||
|
{
|
||||||
|
let slash_command_ids = response
|
||||||
|
.prompts
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(assistant_slash_commands::acceptable_prompt)
|
.filter(assistant_slash_commands::acceptable_prompt)
|
||||||
.map(|prompt| {
|
.map(|prompt| {
|
||||||
|
|
|
@ -86,20 +86,26 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
let protocol = server.client().context("Context server not initialized")?;
|
let protocol = server.client().context("Context server not initialized")?;
|
||||||
|
|
||||||
let completion_result = protocol
|
let response = protocol
|
||||||
.completion(
|
.request::<context_server::types::request::CompletionComplete>(
|
||||||
context_server::types::CompletionReference::Prompt(
|
context_server::types::CompletionCompleteParams {
|
||||||
|
reference: context_server::types::CompletionReference::Prompt(
|
||||||
context_server::types::PromptReference {
|
context_server::types::PromptReference {
|
||||||
r#type: context_server::types::PromptReferenceType::Prompt,
|
ty: context_server::types::PromptReferenceType::Prompt,
|
||||||
name: prompt_name,
|
name: prompt_name,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
arg_name,
|
argument: context_server::types::CompletionArgument {
|
||||||
arg_value,
|
name: arg_name,
|
||||||
|
value: arg_value,
|
||||||
|
},
|
||||||
|
meta: None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let completions = completion_result
|
let completions = response
|
||||||
|
.completion
|
||||||
.values
|
.values
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|value| ArgumentCompletion {
|
.map(|value| ArgumentCompletion {
|
||||||
|
@ -138,10 +144,18 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||||
if let Some(server) = store.get_running_server(&server_id) {
|
if let Some(server) = store.get_running_server(&server_id) {
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
let protocol = server.client().context("Context server not initialized")?;
|
let protocol = server.client().context("Context server not initialized")?;
|
||||||
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
|
let response = protocol
|
||||||
|
.request::<context_server::types::request::PromptsGet>(
|
||||||
|
context_server::types::PromptsGetParams {
|
||||||
|
name: prompt_name.clone(),
|
||||||
|
arguments: Some(prompt_args),
|
||||||
|
meta: None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
anyhow::ensure!(
|
anyhow::ensure!(
|
||||||
result
|
response
|
||||||
.messages
|
.messages
|
||||||
.iter()
|
.iter()
|
||||||
.all(|msg| matches!(msg.role, context_server::types::Role::User)),
|
.all(|msg| matches!(msg.role, context_server::types::Role::User)),
|
||||||
|
@ -149,7 +163,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||||
);
|
);
|
||||||
|
|
||||||
// Extract text from user messages into a single prompt string
|
// Extract text from user messages into a single prompt string
|
||||||
let mut prompt = result
|
let mut prompt = response
|
||||||
.messages
|
.messages
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|msg| match msg.content {
|
.filter_map(|msg| match msg.content {
|
||||||
|
@ -167,7 +181,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||||
range: 0..(prompt.len()),
|
range: 0..(prompt.len()),
|
||||||
icon: IconName::ZedAssistant,
|
icon: IconName::ZedAssistant,
|
||||||
label: SharedString::from(
|
label: SharedString::from(
|
||||||
result
|
response
|
||||||
.description
|
.description
|
||||||
.unwrap_or(format!("Result from {}", prompt_name)),
|
.unwrap_or(format!("Result from {}", prompt_name)),
|
||||||
),
|
),
|
||||||
|
|
|
@ -11,6 +11,9 @@ workspace = true
|
||||||
[lib]
|
[lib]
|
||||||
path = "src/context_server.rs"
|
path = "src/context_server.rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
test-support = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
|
pub mod test;
|
||||||
pub mod transport;
|
pub mod transport;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,9 @@
|
||||||
//! of messages.
|
//! of messages.
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use collections::HashMap;
|
|
||||||
|
|
||||||
use crate::client::Client;
|
use crate::client::Client;
|
||||||
use crate::types;
|
use crate::types::{self, Request};
|
||||||
|
|
||||||
pub struct ModelContextProtocol {
|
pub struct ModelContextProtocol {
|
||||||
inner: Client,
|
inner: Client,
|
||||||
|
@ -43,7 +42,7 @@ impl ModelContextProtocol {
|
||||||
|
|
||||||
let response: types::InitializeResponse = self
|
let response: types::InitializeResponse = self
|
||||||
.inner
|
.inner
|
||||||
.request(types::RequestType::Initialize.as_str(), params)
|
.request(types::request::Initialize::METHOD, params)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
anyhow::ensure!(
|
anyhow::ensure!(
|
||||||
|
@ -94,137 +93,7 @@ impl InitializedContextServerProtocol {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_capability(&self, capability: ServerCapability) -> Result<()> {
|
pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
|
||||||
anyhow::ensure!(
|
self.inner.request(T::METHOD, params).await
|
||||||
self.capable(capability),
|
|
||||||
"Server does not support {capability:?} capability"
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// List the MCP prompts.
|
|
||||||
pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
|
|
||||||
self.check_capability(ServerCapability::Prompts)?;
|
|
||||||
|
|
||||||
let response: types::PromptsListResponse = self
|
|
||||||
.inner
|
|
||||||
.request(
|
|
||||||
types::RequestType::PromptsList.as_str(),
|
|
||||||
serde_json::json!({}),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(response.prompts)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// List the MCP resources.
|
|
||||||
pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
|
|
||||||
self.check_capability(ServerCapability::Resources)?;
|
|
||||||
|
|
||||||
let response: types::ResourcesListResponse = self
|
|
||||||
.inner
|
|
||||||
.request(
|
|
||||||
types::RequestType::ResourcesList.as_str(),
|
|
||||||
serde_json::json!({}),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Executes a prompt with the given arguments and returns the result.
|
|
||||||
pub async fn run_prompt<P: AsRef<str>>(
|
|
||||||
&self,
|
|
||||||
prompt: P,
|
|
||||||
arguments: HashMap<String, String>,
|
|
||||||
) -> Result<types::PromptsGetResponse> {
|
|
||||||
self.check_capability(ServerCapability::Prompts)?;
|
|
||||||
|
|
||||||
let params = types::PromptsGetParams {
|
|
||||||
name: prompt.as_ref().to_string(),
|
|
||||||
arguments: Some(arguments),
|
|
||||||
meta: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response: types::PromptsGetResponse = self
|
|
||||||
.inner
|
|
||||||
.request(types::RequestType::PromptsGet.as_str(), params)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn completion<P: Into<String>>(
|
|
||||||
&self,
|
|
||||||
reference: types::CompletionReference,
|
|
||||||
argument: P,
|
|
||||||
value: P,
|
|
||||||
) -> Result<types::Completion> {
|
|
||||||
let params = types::CompletionCompleteParams {
|
|
||||||
r#ref: reference,
|
|
||||||
argument: types::CompletionArgument {
|
|
||||||
name: argument.into(),
|
|
||||||
value: value.into(),
|
|
||||||
},
|
|
||||||
meta: None,
|
|
||||||
};
|
|
||||||
let result: types::CompletionCompleteResponse = self
|
|
||||||
.inner
|
|
||||||
.request(types::RequestType::CompletionComplete.as_str(), params)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let completion = types::Completion {
|
|
||||||
values: result.completion.values,
|
|
||||||
total: types::CompletionTotal::from_options(
|
|
||||||
result.completion.has_more,
|
|
||||||
result.completion.total,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(completion)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// List MCP tools.
|
|
||||||
pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
|
|
||||||
self.check_capability(ServerCapability::Tools)?;
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.inner
|
|
||||||
.request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Executes a tool with the given arguments
|
|
||||||
pub async fn run_tool<P: AsRef<str>>(
|
|
||||||
&self,
|
|
||||||
tool: P,
|
|
||||||
arguments: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
) -> Result<types::CallToolResponse> {
|
|
||||||
self.check_capability(ServerCapability::Tools)?;
|
|
||||||
|
|
||||||
let params = types::CallToolParams {
|
|
||||||
name: tool.as_ref().to_string(),
|
|
||||||
arguments,
|
|
||||||
meta: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response: types::CallToolResponse = self
|
|
||||||
.inner
|
|
||||||
.request(types::RequestType::CallTool.as_str(), params)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl InitializedContextServerProtocol {
|
|
||||||
pub async fn request<R: serde::de::DeserializeOwned>(
|
|
||||||
&self,
|
|
||||||
method: &str,
|
|
||||||
params: impl serde::Serialize,
|
|
||||||
) -> Result<R> {
|
|
||||||
self.inner.request(method, params).await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
118
crates/context_server/src/test.rs
Normal file
118
crates/context_server/src/test.rs
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
use anyhow::Context as _;
|
||||||
|
use collections::HashMap;
|
||||||
|
use futures::{Stream, StreamExt as _, lock::Mutex};
|
||||||
|
use gpui::BackgroundExecutor;
|
||||||
|
use std::{pin::Pin, sync::Arc};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
transport::Transport,
|
||||||
|
types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn create_fake_transport(
|
||||||
|
name: impl Into<String>,
|
||||||
|
executor: BackgroundExecutor,
|
||||||
|
) -> FakeTransport {
|
||||||
|
let name = name.into();
|
||||||
|
FakeTransport::new(executor).on_request::<crate::types::request::Initialize>(move |_params| {
|
||||||
|
create_initialize_response(name.clone())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_initialize_response(server_name: String) -> InitializeResponse {
|
||||||
|
InitializeResponse {
|
||||||
|
protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()),
|
||||||
|
server_info: Implementation {
|
||||||
|
name: server_name,
|
||||||
|
version: "1.0.0".to_string(),
|
||||||
|
},
|
||||||
|
capabilities: ServerCapabilities::default(),
|
||||||
|
meta: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FakeTransport {
|
||||||
|
request_handlers:
|
||||||
|
HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>,
|
||||||
|
tx: futures::channel::mpsc::UnboundedSender<String>,
|
||||||
|
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
||||||
|
executor: BackgroundExecutor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeTransport {
|
||||||
|
pub fn new(executor: BackgroundExecutor) -> Self {
|
||||||
|
let (tx, rx) = futures::channel::mpsc::unbounded();
|
||||||
|
Self {
|
||||||
|
request_handlers: Default::default(),
|
||||||
|
tx,
|
||||||
|
rx: Arc::new(Mutex::new(rx)),
|
||||||
|
executor,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_request<T: crate::types::Request>(
|
||||||
|
mut self,
|
||||||
|
handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static,
|
||||||
|
) -> Self {
|
||||||
|
self.request_handlers.insert(
|
||||||
|
T::METHOD,
|
||||||
|
Arc::new(move |value| {
|
||||||
|
let params = value.get("params").expect("Missing parameters").clone();
|
||||||
|
let params: T::Params =
|
||||||
|
serde_json::from_value(params).expect("Invalid parameters received");
|
||||||
|
let response = handler(params);
|
||||||
|
serde_json::to_value(response).unwrap()
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Transport for FakeTransport {
|
||||||
|
async fn send(&self, message: String) -> anyhow::Result<()> {
|
||||||
|
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
|
||||||
|
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
|
||||||
|
|
||||||
|
if let Some(method) = msg.get("method") {
|
||||||
|
let method = method.as_str().expect("Invalid method received");
|
||||||
|
if let Some(handler) = self.request_handlers.get(method) {
|
||||||
|
let payload = handler(msg);
|
||||||
|
let response = serde_json::json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": id,
|
||||||
|
"result": payload
|
||||||
|
});
|
||||||
|
self.tx
|
||||||
|
.unbounded_send(response.to_string())
|
||||||
|
.context("sending a message")?;
|
||||||
|
} else {
|
||||||
|
log::debug!("No handler registered for MCP request '{method}'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||||
|
let rx = self.rx.clone();
|
||||||
|
let executor = self.executor.clone();
|
||||||
|
Box::pin(futures::stream::unfold(rx, move |rx| {
|
||||||
|
let executor = executor.clone();
|
||||||
|
async move {
|
||||||
|
let mut rx_guard = rx.lock().await;
|
||||||
|
executor.simulate_random_delay().await;
|
||||||
|
if let Some(message) = rx_guard.next().await {
|
||||||
|
drop(rx_guard);
|
||||||
|
Some((message, rx))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||||
|
Box::pin(futures::stream::empty())
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,76 +1,92 @@
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
|
use serde::de::DeserializeOwned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
|
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
|
||||||
|
|
||||||
pub enum RequestType {
|
pub mod request {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
macro_rules! request {
|
||||||
|
($method:expr, $name:ident, $params:ty, $response:ty) => {
|
||||||
|
pub struct $name;
|
||||||
|
|
||||||
|
impl Request for $name {
|
||||||
|
type Params = $params;
|
||||||
|
type Response = $response;
|
||||||
|
const METHOD: &'static str = $method;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
request!(
|
||||||
|
"initialize",
|
||||||
Initialize,
|
Initialize,
|
||||||
CallTool,
|
InitializeParams,
|
||||||
|
InitializeResponse
|
||||||
|
);
|
||||||
|
request!("tools/call", CallTool, CallToolParams, CallToolResponse);
|
||||||
|
request!(
|
||||||
|
"resources/unsubscribe",
|
||||||
ResourcesUnsubscribe,
|
ResourcesUnsubscribe,
|
||||||
|
ResourcesUnsubscribeParams,
|
||||||
|
()
|
||||||
|
);
|
||||||
|
request!(
|
||||||
|
"resources/subscribe",
|
||||||
ResourcesSubscribe,
|
ResourcesSubscribe,
|
||||||
|
ResourcesSubscribeParams,
|
||||||
|
()
|
||||||
|
);
|
||||||
|
request!(
|
||||||
|
"resources/read",
|
||||||
ResourcesRead,
|
ResourcesRead,
|
||||||
ResourcesList,
|
ResourcesReadParams,
|
||||||
|
ResourcesReadResponse
|
||||||
|
);
|
||||||
|
request!("resources/list", ResourcesList, (), ResourcesListResponse);
|
||||||
|
request!(
|
||||||
|
"logging/setLevel",
|
||||||
LoggingSetLevel,
|
LoggingSetLevel,
|
||||||
|
LoggingSetLevelParams,
|
||||||
|
()
|
||||||
|
);
|
||||||
|
request!(
|
||||||
|
"prompts/get",
|
||||||
PromptsGet,
|
PromptsGet,
|
||||||
PromptsList,
|
PromptsGetParams,
|
||||||
|
PromptsGetResponse
|
||||||
|
);
|
||||||
|
request!("prompts/list", PromptsList, (), PromptsListResponse);
|
||||||
|
request!(
|
||||||
|
"completion/complete",
|
||||||
CompletionComplete,
|
CompletionComplete,
|
||||||
Ping,
|
CompletionCompleteParams,
|
||||||
ListTools,
|
CompletionCompleteResponse
|
||||||
|
);
|
||||||
|
request!("ping", Ping, (), ());
|
||||||
|
request!("tools/list", ListTools, (), ListToolsResponse);
|
||||||
|
request!(
|
||||||
|
"resources/templates/list",
|
||||||
ListResourceTemplates,
|
ListResourceTemplates,
|
||||||
ListRoots,
|
(),
|
||||||
|
ListResourceTemplatesResponse
|
||||||
|
);
|
||||||
|
request!("roots/list", ListRoots, (), ListRootsResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RequestType {
|
pub trait Request {
|
||||||
pub fn as_str(&self) -> &'static str {
|
type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
|
||||||
match self {
|
type Response: DeserializeOwned + Serialize + Send + Sync + 'static;
|
||||||
RequestType::Initialize => "initialize",
|
const METHOD: &'static str;
|
||||||
RequestType::CallTool => "tools/call",
|
|
||||||
RequestType::ResourcesUnsubscribe => "resources/unsubscribe",
|
|
||||||
RequestType::ResourcesSubscribe => "resources/subscribe",
|
|
||||||
RequestType::ResourcesRead => "resources/read",
|
|
||||||
RequestType::ResourcesList => "resources/list",
|
|
||||||
RequestType::LoggingSetLevel => "logging/setLevel",
|
|
||||||
RequestType::PromptsGet => "prompts/get",
|
|
||||||
RequestType::PromptsList => "prompts/list",
|
|
||||||
RequestType::CompletionComplete => "completion/complete",
|
|
||||||
RequestType::Ping => "ping",
|
|
||||||
RequestType::ListTools => "tools/list",
|
|
||||||
RequestType::ListResourceTemplates => "resources/templates/list",
|
|
||||||
RequestType::ListRoots => "roots/list",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<&str> for RequestType {
|
|
||||||
type Error = ();
|
|
||||||
|
|
||||||
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
|
||||||
match s {
|
|
||||||
"initialize" => Ok(RequestType::Initialize),
|
|
||||||
"tools/call" => Ok(RequestType::CallTool),
|
|
||||||
"resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe),
|
|
||||||
"resources/subscribe" => Ok(RequestType::ResourcesSubscribe),
|
|
||||||
"resources/read" => Ok(RequestType::ResourcesRead),
|
|
||||||
"resources/list" => Ok(RequestType::ResourcesList),
|
|
||||||
"logging/setLevel" => Ok(RequestType::LoggingSetLevel),
|
|
||||||
"prompts/get" => Ok(RequestType::PromptsGet),
|
|
||||||
"prompts/list" => Ok(RequestType::PromptsList),
|
|
||||||
"completion/complete" => Ok(RequestType::CompletionComplete),
|
|
||||||
"ping" => Ok(RequestType::Ping),
|
|
||||||
"tools/list" => Ok(RequestType::ListTools),
|
|
||||||
"resources/templates/list" => Ok(RequestType::ListResourceTemplates),
|
|
||||||
"roots/list" => Ok(RequestType::ListRoots),
|
|
||||||
_ => Err(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
#[serde(transparent)]
|
#[serde(transparent)]
|
||||||
pub struct ProtocolVersion(pub String);
|
pub struct ProtocolVersion(pub String);
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct InitializeParams {
|
pub struct InitializeParams {
|
||||||
pub protocol_version: ProtocolVersion,
|
pub protocol_version: ProtocolVersion,
|
||||||
|
@ -80,7 +96,7 @@ pub struct InitializeParams {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct CallToolParams {
|
pub struct CallToolParams {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
@ -90,7 +106,7 @@ pub struct CallToolParams {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ResourcesUnsubscribeParams {
|
pub struct ResourcesUnsubscribeParams {
|
||||||
pub uri: Url,
|
pub uri: Url,
|
||||||
|
@ -98,7 +114,7 @@ pub struct ResourcesUnsubscribeParams {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ResourcesSubscribeParams {
|
pub struct ResourcesSubscribeParams {
|
||||||
pub uri: Url,
|
pub uri: Url,
|
||||||
|
@ -106,7 +122,7 @@ pub struct ResourcesSubscribeParams {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ResourcesReadParams {
|
pub struct ResourcesReadParams {
|
||||||
pub uri: Url,
|
pub uri: Url,
|
||||||
|
@ -114,7 +130,7 @@ pub struct ResourcesReadParams {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct LoggingSetLevelParams {
|
pub struct LoggingSetLevelParams {
|
||||||
pub level: LoggingLevel,
|
pub level: LoggingLevel,
|
||||||
|
@ -122,7 +138,7 @@ pub struct LoggingSetLevelParams {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct PromptsGetParams {
|
pub struct PromptsGetParams {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
@ -132,37 +148,40 @@ pub struct PromptsGetParams {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct CompletionCompleteParams {
|
pub struct CompletionCompleteParams {
|
||||||
pub r#ref: CompletionReference,
|
#[serde(rename = "ref")]
|
||||||
|
pub reference: CompletionReference,
|
||||||
pub argument: CompletionArgument,
|
pub argument: CompletionArgument,
|
||||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum CompletionReference {
|
pub enum CompletionReference {
|
||||||
Prompt(PromptReference),
|
Prompt(PromptReference),
|
||||||
Resource(ResourceReference),
|
Resource(ResourceReference),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct PromptReference {
|
pub struct PromptReference {
|
||||||
pub r#type: PromptReferenceType,
|
#[serde(rename = "type")]
|
||||||
|
pub ty: PromptReferenceType,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ResourceReference {
|
pub struct ResourceReference {
|
||||||
pub r#type: PromptReferenceType,
|
#[serde(rename = "type")]
|
||||||
|
pub ty: PromptReferenceType,
|
||||||
pub uri: Url,
|
pub uri: Url,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum PromptReferenceType {
|
pub enum PromptReferenceType {
|
||||||
#[serde(rename = "ref/prompt")]
|
#[serde(rename = "ref/prompt")]
|
||||||
|
@ -171,7 +190,7 @@ pub enum PromptReferenceType {
|
||||||
Resource,
|
Resource,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct CompletionArgument {
|
pub struct CompletionArgument {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
@ -188,7 +207,7 @@ pub struct InitializeResponse {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ResourcesReadResponse {
|
pub struct ResourcesReadResponse {
|
||||||
pub contents: Vec<ResourceContentsType>,
|
pub contents: Vec<ResourceContentsType>,
|
||||||
|
@ -196,14 +215,14 @@ pub struct ResourcesReadResponse {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum ResourceContentsType {
|
pub enum ResourceContentsType {
|
||||||
Text(TextResourceContents),
|
Text(TextResourceContents),
|
||||||
Blob(BlobResourceContents),
|
Blob(BlobResourceContents),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ResourcesListResponse {
|
pub struct ResourcesListResponse {
|
||||||
pub resources: Vec<Resource>,
|
pub resources: Vec<Resource>,
|
||||||
|
@ -220,7 +239,7 @@ pub struct SamplingMessage {
|
||||||
pub content: MessageContent,
|
pub content: MessageContent,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct CreateMessageRequest {
|
pub struct CreateMessageRequest {
|
||||||
pub messages: Vec<SamplingMessage>,
|
pub messages: Vec<SamplingMessage>,
|
||||||
|
@ -296,7 +315,7 @@ pub struct MessageAnnotations {
|
||||||
pub priority: Option<f64>,
|
pub priority: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct PromptsGetResponse {
|
pub struct PromptsGetResponse {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
@ -306,7 +325,7 @@ pub struct PromptsGetResponse {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct PromptsListResponse {
|
pub struct PromptsListResponse {
|
||||||
pub prompts: Vec<Prompt>,
|
pub prompts: Vec<Prompt>,
|
||||||
|
@ -316,7 +335,7 @@ pub struct PromptsListResponse {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct CompletionCompleteResponse {
|
pub struct CompletionCompleteResponse {
|
||||||
pub completion: CompletionResult,
|
pub completion: CompletionResult,
|
||||||
|
@ -324,7 +343,7 @@ pub struct CompletionCompleteResponse {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct CompletionResult {
|
pub struct CompletionResult {
|
||||||
pub values: Vec<String>,
|
pub values: Vec<String>,
|
||||||
|
@ -336,7 +355,7 @@ pub struct CompletionResult {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct Prompt {
|
pub struct Prompt {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
@ -346,7 +365,7 @@ pub struct Prompt {
|
||||||
pub arguments: Option<Vec<PromptArgument>>,
|
pub arguments: Option<Vec<PromptArgument>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct PromptArgument {
|
pub struct PromptArgument {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
@ -509,7 +528,7 @@ pub struct ModelHint {
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub enum NotificationType {
|
pub enum NotificationType {
|
||||||
Initialized,
|
Initialized,
|
||||||
|
@ -589,7 +608,7 @@ pub struct Completion {
|
||||||
pub total: CompletionTotal,
|
pub total: CompletionTotal,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct CallToolResponse {
|
pub struct CallToolResponse {
|
||||||
pub content: Vec<ToolResponseContent>,
|
pub content: Vec<ToolResponseContent>,
|
||||||
|
@ -620,7 +639,7 @@ pub struct ListToolsResponse {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ListResourceTemplatesResponse {
|
pub struct ListResourceTemplatesResponse {
|
||||||
pub resource_templates: Vec<ResourceTemplate>,
|
pub resource_templates: Vec<ResourceTemplate>,
|
||||||
|
@ -630,7 +649,7 @@ pub struct ListResourceTemplatesResponse {
|
||||||
pub meta: Option<HashMap<String, serde_json::Value>>,
|
pub meta: Option<HashMap<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ListRootsResponse {
|
pub struct ListRootsResponse {
|
||||||
pub roots: Vec<Root>,
|
pub roots: Vec<Root>,
|
||||||
|
|
|
@ -91,6 +91,7 @@ workspace-hack.workspace = true
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
client = { workspace = true, features = ["test-support"] }
|
client = { workspace = true, features = ["test-support"] }
|
||||||
collections = { workspace = true, features = ["test-support"] }
|
collections = { workspace = true, features = ["test-support"] }
|
||||||
|
context_server = { workspace = true, features = ["test-support"] }
|
||||||
buffer_diff = { workspace = true, features = ["test-support"] }
|
buffer_diff = { workspace = true, features = ["test-support"] }
|
||||||
dap = { workspace = true, features = ["test-support"] }
|
dap = { workspace = true, features = ["test-support"] }
|
||||||
dap_adapters = { workspace = true, features = ["test-support"] }
|
dap_adapters = { workspace = true, features = ["test-support"] }
|
||||||
|
|
|
@ -499,17 +499,10 @@ impl ContextServerStore {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::{FakeFs, Project, project_settings::ProjectSettings};
|
use crate::{FakeFs, Project, project_settings::ProjectSettings};
|
||||||
use context_server::{
|
use context_server::test::create_fake_transport;
|
||||||
transport::Transport,
|
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||||
types::{
|
|
||||||
self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
|
|
||||||
ServerCapabilities,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use futures::{Stream, StreamExt as _, lock::Mutex};
|
|
||||||
use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
|
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::{cell::RefCell, pin::Pin, rc::Rc};
|
use std::{cell::RefCell, rc::Rc};
|
||||||
use util::path;
|
use util::path;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
|
@ -532,34 +525,18 @@ mod tests {
|
||||||
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
|
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
let server_1_id = ContextServerId("mcp-1".into());
|
let server_1_id = ContextServerId(SERVER_1_ID.into());
|
||||||
let server_2_id = ContextServerId("mcp-2".into());
|
let server_2_id = ContextServerId(SERVER_2_ID.into());
|
||||||
|
|
||||||
let transport_1 =
|
let server_1 = Arc::new(ContextServer::new(
|
||||||
Arc::new(FakeTransport::new(
|
server_1_id.clone(),
|
||||||
cx.executor(),
|
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|
||||||
|_, request_type, _| match request_type {
|
|
||||||
Some(RequestType::Initialize) => {
|
|
||||||
Some(create_initialize_response("mcp-1".to_string()))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
));
|
));
|
||||||
|
let server_2 = Arc::new(ContextServer::new(
|
||||||
let transport_2 =
|
server_2_id.clone(),
|
||||||
Arc::new(FakeTransport::new(
|
Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
|
||||||
cx.executor(),
|
|
||||||
|_, request_type, _| match request_type {
|
|
||||||
Some(RequestType::Initialize) => {
|
|
||||||
Some(create_initialize_response("mcp-2".to_string()))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
));
|
));
|
||||||
|
|
||||||
let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
|
|
||||||
let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
|
|
||||||
|
|
||||||
store
|
store
|
||||||
.update(cx, |store, cx| store.start_server(server_1, cx))
|
.update(cx, |store, cx| store.start_server(server_1, cx))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
@ -627,34 +604,18 @@ mod tests {
|
||||||
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
|
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
|
||||||
});
|
});
|
||||||
|
|
||||||
let server_1_id = ContextServerId("mcp-1".into());
|
let server_1_id = ContextServerId(SERVER_1_ID.into());
|
||||||
let server_2_id = ContextServerId("mcp-2".into());
|
let server_2_id = ContextServerId(SERVER_2_ID.into());
|
||||||
|
|
||||||
let transport_1 =
|
let server_1 = Arc::new(ContextServer::new(
|
||||||
Arc::new(FakeTransport::new(
|
server_1_id.clone(),
|
||||||
cx.executor(),
|
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|
||||||
|_, request_type, _| match request_type {
|
|
||||||
Some(RequestType::Initialize) => {
|
|
||||||
Some(create_initialize_response("mcp-1".to_string()))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
));
|
));
|
||||||
|
let server_2 = Arc::new(ContextServer::new(
|
||||||
let transport_2 =
|
server_2_id.clone(),
|
||||||
Arc::new(FakeTransport::new(
|
Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
|
||||||
cx.executor(),
|
|
||||||
|_, request_type, _| match request_type {
|
|
||||||
Some(RequestType::Initialize) => {
|
|
||||||
Some(create_initialize_response("mcp-2".to_string()))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
));
|
));
|
||||||
|
|
||||||
let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
|
|
||||||
let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
|
|
||||||
|
|
||||||
let _server_events = assert_server_events(
|
let _server_events = assert_server_events(
|
||||||
&store,
|
&store,
|
||||||
vec![
|
vec![
|
||||||
|
@ -702,31 +663,15 @@ mod tests {
|
||||||
|
|
||||||
let server_id = ContextServerId(SERVER_1_ID.into());
|
let server_id = ContextServerId(SERVER_1_ID.into());
|
||||||
|
|
||||||
let transport_1 =
|
let server_with_same_id_1 = Arc::new(ContextServer::new(
|
||||||
Arc::new(FakeTransport::new(
|
server_id.clone(),
|
||||||
cx.executor(),
|
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|
||||||
|_, request_type, _| match request_type {
|
|
||||||
Some(RequestType::Initialize) => {
|
|
||||||
Some(create_initialize_response(SERVER_1_ID.to_string()))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
));
|
));
|
||||||
|
let server_with_same_id_2 = Arc::new(ContextServer::new(
|
||||||
let transport_2 =
|
server_id.clone(),
|
||||||
Arc::new(FakeTransport::new(
|
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|
||||||
cx.executor(),
|
|
||||||
|_, request_type, _| match request_type {
|
|
||||||
Some(RequestType::Initialize) => {
|
|
||||||
Some(create_initialize_response(SERVER_1_ID.to_string()))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
},
|
|
||||||
));
|
));
|
||||||
|
|
||||||
let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
|
|
||||||
let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
|
|
||||||
|
|
||||||
// If we start another server with the same id, we should report that we stopped the previous one
|
// If we start another server with the same id, we should report that we stopped the previous one
|
||||||
let _server_events = assert_server_events(
|
let _server_events = assert_server_events(
|
||||||
&store,
|
&store,
|
||||||
|
@ -794,16 +739,10 @@ mod tests {
|
||||||
let store = cx.new(|cx| {
|
let store = cx.new(|cx| {
|
||||||
ContextServerStore::test_maintain_server_loop(
|
ContextServerStore::test_maintain_server_loop(
|
||||||
Box::new(move |id, _| {
|
Box::new(move |id, _| {
|
||||||
let transport = FakeTransport::new(executor.clone(), {
|
Arc::new(ContextServer::new(
|
||||||
let id = id.0.clone();
|
id.clone(),
|
||||||
move |_, request_type, _| match request_type {
|
Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
|
||||||
Some(RequestType::Initialize) => {
|
))
|
||||||
Some(create_initialize_response(id.clone().to_string()))
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
});
|
|
||||||
Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
|
|
||||||
}),
|
}),
|
||||||
registry.clone(),
|
registry.clone(),
|
||||||
project.read(cx).worktree_store(),
|
project.read(cx).worktree_store(),
|
||||||
|
@ -1033,99 +972,4 @@ mod tests {
|
||||||
|
|
||||||
(fs, project)
|
(fs, project)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_initialize_response(server_name: String) -> serde_json::Value {
|
|
||||||
serde_json::to_value(&InitializeResponse {
|
|
||||||
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
|
|
||||||
server_info: Implementation {
|
|
||||||
name: server_name,
|
|
||||||
version: "1.0.0".to_string(),
|
|
||||||
},
|
|
||||||
capabilities: ServerCapabilities::default(),
|
|
||||||
meta: None,
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
struct FakeTransport {
|
|
||||||
on_request: Arc<
|
|
||||||
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
|
|
||||||
+ Send
|
|
||||||
+ Sync,
|
|
||||||
>,
|
|
||||||
tx: futures::channel::mpsc::UnboundedSender<String>,
|
|
||||||
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
|
||||||
executor: BackgroundExecutor,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FakeTransport {
|
|
||||||
fn new(
|
|
||||||
executor: BackgroundExecutor,
|
|
||||||
on_request: impl Fn(
|
|
||||||
u64,
|
|
||||||
Option<RequestType>,
|
|
||||||
serde_json::Value,
|
|
||||||
) -> Option<serde_json::Value>
|
|
||||||
+ 'static
|
|
||||||
+ Send
|
|
||||||
+ Sync,
|
|
||||||
) -> Self {
|
|
||||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
|
||||||
Self {
|
|
||||||
on_request: Arc::new(on_request),
|
|
||||||
tx,
|
|
||||||
rx: Arc::new(Mutex::new(rx)),
|
|
||||||
executor,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
|
||||||
impl Transport for FakeTransport {
|
|
||||||
async fn send(&self, message: String) -> Result<()> {
|
|
||||||
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
|
|
||||||
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
|
|
||||||
|
|
||||||
if let Some(method) = msg.get("method") {
|
|
||||||
let request_type = method
|
|
||||||
.as_str()
|
|
||||||
.and_then(|method| types::RequestType::try_from(method).ok());
|
|
||||||
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
|
|
||||||
let response = serde_json::json!({
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": id,
|
|
||||||
"result": payload
|
|
||||||
});
|
|
||||||
|
|
||||||
self.tx
|
|
||||||
.unbounded_send(response.to_string())
|
|
||||||
.context("sending a message")?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
|
||||||
let rx = self.rx.clone();
|
|
||||||
let executor = self.executor.clone();
|
|
||||||
Box::pin(futures::stream::unfold(rx, move |rx| {
|
|
||||||
let executor = executor.clone();
|
|
||||||
async move {
|
|
||||||
let mut rx_guard = rx.lock().await;
|
|
||||||
executor.simulate_random_delay().await;
|
|
||||||
if let Some(message) = rx_guard.next().await {
|
|
||||||
drop(rx_guard);
|
|
||||||
Some((message, rx))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
|
||||||
Box::pin(futures::stream::empty())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue