context server: Make requests type safe (#32254)

This changes the context server crate so that the input/output for a
request are encoded at the type level, similar to how it is done for LSP
requests.

This also makes it easier to write tests that mock context servers, e.g.
you can write something like this now when using the `test-support`
feature of the `context-server` crate:

```rust
create_fake_transport("mcp-1", cx.background_executor())
    .on_request::<context_server::types::request::PromptsList>(|_params| {
        PromptsListResponse {
            prompts: vec![/* some prompts */],
            ..
        }
    })
```

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-06-06 17:47:21 +02:00 committed by GitHub
parent 454adfacae
commit 95d78ff8d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 320 additions and 433 deletions

View file

@ -91,6 +91,7 @@ workspace-hack.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
context_server = { workspace = true, features = ["test-support"] }
buffer_diff = { workspace = true, features = ["test-support"] }
dap = { workspace = true, features = ["test-support"] }
dap_adapters = { workspace = true, features = ["test-support"] }

View file

@ -499,17 +499,10 @@ impl ContextServerStore {
mod tests {
use super::*;
use crate::{FakeFs, Project, project_settings::ProjectSettings};
use context_server::{
transport::Transport,
types::{
self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
ServerCapabilities,
},
};
use futures::{Stream, StreamExt as _, lock::Mutex};
use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
use context_server::test::create_fake_transport;
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use serde_json::json;
use std::{cell::RefCell, pin::Pin, rc::Rc};
use std::{cell::RefCell, rc::Rc};
use util::path;
#[gpui::test]
@ -532,33 +525,17 @@ mod tests {
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
});
let server_1_id = ContextServerId("mcp-1".into());
let server_2_id = ContextServerId("mcp-2".into());
let server_1_id = ContextServerId(SERVER_1_ID.into());
let server_2_id = ContextServerId(SERVER_2_ID.into());
let transport_1 =
Arc::new(FakeTransport::new(
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-1".to_string()))
}
_ => None,
},
));
let transport_2 =
Arc::new(FakeTransport::new(
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-2".to_string()))
}
_ => None,
},
));
let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
let server_1 = Arc::new(ContextServer::new(
server_1_id.clone(),
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
));
let server_2 = Arc::new(ContextServer::new(
server_2_id.clone(),
Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
));
store
.update(cx, |store, cx| store.start_server(server_1, cx))
@ -627,33 +604,17 @@ mod tests {
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
});
let server_1_id = ContextServerId("mcp-1".into());
let server_2_id = ContextServerId("mcp-2".into());
let server_1_id = ContextServerId(SERVER_1_ID.into());
let server_2_id = ContextServerId(SERVER_2_ID.into());
let transport_1 =
Arc::new(FakeTransport::new(
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-1".to_string()))
}
_ => None,
},
));
let transport_2 =
Arc::new(FakeTransport::new(
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response("mcp-2".to_string()))
}
_ => None,
},
));
let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
let server_1 = Arc::new(ContextServer::new(
server_1_id.clone(),
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
));
let server_2 = Arc::new(ContextServer::new(
server_2_id.clone(),
Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
));
let _server_events = assert_server_events(
&store,
@ -702,30 +663,14 @@ mod tests {
let server_id = ContextServerId(SERVER_1_ID.into());
let transport_1 =
Arc::new(FakeTransport::new(
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response(SERVER_1_ID.to_string()))
}
_ => None,
},
));
let transport_2 =
Arc::new(FakeTransport::new(
cx.executor(),
|_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response(SERVER_1_ID.to_string()))
}
_ => None,
},
));
let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
let server_with_same_id_1 = Arc::new(ContextServer::new(
server_id.clone(),
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
));
let server_with_same_id_2 = Arc::new(ContextServer::new(
server_id.clone(),
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
));
// If we start another server with the same id, we should report that we stopped the previous one
let _server_events = assert_server_events(
@ -794,16 +739,10 @@ mod tests {
let store = cx.new(|cx| {
ContextServerStore::test_maintain_server_loop(
Box::new(move |id, _| {
let transport = FakeTransport::new(executor.clone(), {
let id = id.0.clone();
move |_, request_type, _| match request_type {
Some(RequestType::Initialize) => {
Some(create_initialize_response(id.clone().to_string()))
}
_ => None,
}
});
Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
Arc::new(ContextServer::new(
id.clone(),
Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
))
}),
registry.clone(),
project.read(cx).worktree_store(),
@ -1033,99 +972,4 @@ mod tests {
(fs, project)
}
fn create_initialize_response(server_name: String) -> serde_json::Value {
serde_json::to_value(&InitializeResponse {
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
server_info: Implementation {
name: server_name,
version: "1.0.0".to_string(),
},
capabilities: ServerCapabilities::default(),
meta: None,
})
.unwrap()
}
struct FakeTransport {
on_request: Arc<
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
+ Send
+ Sync,
>,
tx: futures::channel::mpsc::UnboundedSender<String>,
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
executor: BackgroundExecutor,
}
impl FakeTransport {
fn new(
executor: BackgroundExecutor,
on_request: impl Fn(
u64,
Option<RequestType>,
serde_json::Value,
) -> Option<serde_json::Value>
+ 'static
+ Send
+ Sync,
) -> Self {
let (tx, rx) = futures::channel::mpsc::unbounded();
Self {
on_request: Arc::new(on_request),
tx,
rx: Arc::new(Mutex::new(rx)),
executor,
}
}
}
#[async_trait::async_trait]
impl Transport for FakeTransport {
async fn send(&self, message: String) -> Result<()> {
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
if let Some(method) = msg.get("method") {
let request_type = method
.as_str()
.and_then(|method| types::RequestType::try_from(method).ok());
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
let response = serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"result": payload
});
self.tx
.unbounded_send(response.to_string())
.context("sending a message")?;
}
}
}
Ok(())
}
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
let rx = self.rx.clone();
let executor = self.executor.clone();
Box::pin(futures::stream::unfold(rx, move |rx| {
let executor = executor.clone();
async move {
let mut rx_guard = rx.lock().await;
executor.simulate_random_delay().await;
if let Some(message) = rx_guard.next().await {
drop(rx_guard);
Some((message, rx))
} else {
None
}
}
}))
}
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
Box::pin(futures::stream::empty())
}
}
}