context server: Make requests type safe (#32254)
This changes the context server crate so that the input/output for a request are encoded at the type level, similar to how it is done for LSP requests. This also makes it easier to write tests that mock context servers, e.g. you can write something like this now when using the `test-support` feature of the `context-server` crate: ```rust create_fake_transport("mcp-1", cx.background_executor()) .on_request::<context_server::types::request::PromptsList>(|_params| { PromptsListResponse { prompts: vec![/* some prompts */], .. } }) ``` Release Notes: - N/A
This commit is contained in:
parent
454adfacae
commit
95d78ff8d5
11 changed files with 320 additions and 433 deletions
|
@ -91,6 +91,7 @@ workspace-hack.workspace = true
|
|||
[dev-dependencies]
|
||||
client = { workspace = true, features = ["test-support"] }
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
context_server = { workspace = true, features = ["test-support"] }
|
||||
buffer_diff = { workspace = true, features = ["test-support"] }
|
||||
dap = { workspace = true, features = ["test-support"] }
|
||||
dap_adapters = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
@ -499,17 +499,10 @@ impl ContextServerStore {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{FakeFs, Project, project_settings::ProjectSettings};
|
||||
use context_server::{
|
||||
transport::Transport,
|
||||
types::{
|
||||
self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
|
||||
ServerCapabilities,
|
||||
},
|
||||
};
|
||||
use futures::{Stream, StreamExt as _, lock::Mutex};
|
||||
use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
|
||||
use context_server::test::create_fake_transport;
|
||||
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
|
||||
use serde_json::json;
|
||||
use std::{cell::RefCell, pin::Pin, rc::Rc};
|
||||
use std::{cell::RefCell, rc::Rc};
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -532,33 +525,17 @@ mod tests {
|
|||
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
|
||||
});
|
||||
|
||||
let server_1_id = ContextServerId("mcp-1".into());
|
||||
let server_2_id = ContextServerId("mcp-2".into());
|
||||
let server_1_id = ContextServerId(SERVER_1_ID.into());
|
||||
let server_2_id = ContextServerId(SERVER_2_ID.into());
|
||||
|
||||
let transport_1 =
|
||||
Arc::new(FakeTransport::new(
|
||||
cx.executor(),
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-1".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let transport_2 =
|
||||
Arc::new(FakeTransport::new(
|
||||
cx.executor(),
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-2".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
|
||||
let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
|
||||
let server_1 = Arc::new(ContextServer::new(
|
||||
server_1_id.clone(),
|
||||
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|
||||
));
|
||||
let server_2 = Arc::new(ContextServer::new(
|
||||
server_2_id.clone(),
|
||||
Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
|
||||
));
|
||||
|
||||
store
|
||||
.update(cx, |store, cx| store.start_server(server_1, cx))
|
||||
|
@ -627,33 +604,17 @@ mod tests {
|
|||
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
|
||||
});
|
||||
|
||||
let server_1_id = ContextServerId("mcp-1".into());
|
||||
let server_2_id = ContextServerId("mcp-2".into());
|
||||
let server_1_id = ContextServerId(SERVER_1_ID.into());
|
||||
let server_2_id = ContextServerId(SERVER_2_ID.into());
|
||||
|
||||
let transport_1 =
|
||||
Arc::new(FakeTransport::new(
|
||||
cx.executor(),
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-1".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let transport_2 =
|
||||
Arc::new(FakeTransport::new(
|
||||
cx.executor(),
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-2".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
|
||||
let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
|
||||
let server_1 = Arc::new(ContextServer::new(
|
||||
server_1_id.clone(),
|
||||
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|
||||
));
|
||||
let server_2 = Arc::new(ContextServer::new(
|
||||
server_2_id.clone(),
|
||||
Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
|
||||
));
|
||||
|
||||
let _server_events = assert_server_events(
|
||||
&store,
|
||||
|
@ -702,30 +663,14 @@ mod tests {
|
|||
|
||||
let server_id = ContextServerId(SERVER_1_ID.into());
|
||||
|
||||
let transport_1 =
|
||||
Arc::new(FakeTransport::new(
|
||||
cx.executor(),
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response(SERVER_1_ID.to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let transport_2 =
|
||||
Arc::new(FakeTransport::new(
|
||||
cx.executor(),
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response(SERVER_1_ID.to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
|
||||
let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
|
||||
let server_with_same_id_1 = Arc::new(ContextServer::new(
|
||||
server_id.clone(),
|
||||
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|
||||
));
|
||||
let server_with_same_id_2 = Arc::new(ContextServer::new(
|
||||
server_id.clone(),
|
||||
Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
|
||||
));
|
||||
|
||||
// If we start another server with the same id, we should report that we stopped the previous one
|
||||
let _server_events = assert_server_events(
|
||||
|
@ -794,16 +739,10 @@ mod tests {
|
|||
let store = cx.new(|cx| {
|
||||
ContextServerStore::test_maintain_server_loop(
|
||||
Box::new(move |id, _| {
|
||||
let transport = FakeTransport::new(executor.clone(), {
|
||||
let id = id.0.clone();
|
||||
move |_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response(id.clone().to_string()))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
|
||||
Arc::new(ContextServer::new(
|
||||
id.clone(),
|
||||
Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
|
||||
))
|
||||
}),
|
||||
registry.clone(),
|
||||
project.read(cx).worktree_store(),
|
||||
|
@ -1033,99 +972,4 @@ mod tests {
|
|||
|
||||
(fs, project)
|
||||
}
|
||||
|
||||
fn create_initialize_response(server_name: String) -> serde_json::Value {
|
||||
serde_json::to_value(&InitializeResponse {
|
||||
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
|
||||
server_info: Implementation {
|
||||
name: server_name,
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities::default(),
|
||||
meta: None,
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
struct FakeTransport {
|
||||
on_request: Arc<
|
||||
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>,
|
||||
tx: futures::channel::mpsc::UnboundedSender<String>,
|
||||
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
||||
executor: BackgroundExecutor,
|
||||
}
|
||||
|
||||
impl FakeTransport {
|
||||
fn new(
|
||||
executor: BackgroundExecutor,
|
||||
on_request: impl Fn(
|
||||
u64,
|
||||
Option<RequestType>,
|
||||
serde_json::Value,
|
||||
) -> Option<serde_json::Value>
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Self {
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
||||
Self {
|
||||
on_request: Arc::new(on_request),
|
||||
tx,
|
||||
rx: Arc::new(Mutex::new(rx)),
|
||||
executor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for FakeTransport {
|
||||
async fn send(&self, message: String) -> Result<()> {
|
||||
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
|
||||
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
|
||||
|
||||
if let Some(method) = msg.get("method") {
|
||||
let request_type = method
|
||||
.as_str()
|
||||
.and_then(|method| types::RequestType::try_from(method).ok());
|
||||
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
|
||||
let response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": payload
|
||||
});
|
||||
|
||||
self.tx
|
||||
.unbounded_send(response.to_string())
|
||||
.context("sending a message")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
let rx = self.rx.clone();
|
||||
let executor = self.executor.clone();
|
||||
Box::pin(futures::stream::unfold(rx, move |rx| {
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
let mut rx_guard = rx.lock().await;
|
||||
executor.simulate_random_delay().await;
|
||||
if let Some(message) = rx_guard.next().await {
|
||||
drop(rx_guard);
|
||||
Some((message, rx))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(futures::stream::empty())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue