use anyhow::Context as _; use collections::HashMap; use futures::{Stream, StreamExt as _, lock::Mutex}; use gpui::BackgroundExecutor; use std::{pin::Pin, sync::Arc}; use crate::{ transport::Transport, types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities}, }; pub fn create_fake_transport( name: impl Into, executor: BackgroundExecutor, ) -> FakeTransport { let name = name.into(); FakeTransport::new(executor).on_request::(move |_params| { create_initialize_response(name.clone()) }) } fn create_initialize_response(server_name: String) -> InitializeResponse { InitializeResponse { protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()), server_info: Implementation { name: server_name, version: "1.0.0".to_string(), }, capabilities: ServerCapabilities::default(), meta: None, } } pub struct FakeTransport { request_handlers: HashMap<&'static str, Arc serde_json::Value + Send + Sync>>, tx: futures::channel::mpsc::UnboundedSender, rx: Arc>>, executor: BackgroundExecutor, } impl FakeTransport { pub fn new(executor: BackgroundExecutor) -> Self { let (tx, rx) = futures::channel::mpsc::unbounded(); Self { request_handlers: Default::default(), tx, rx: Arc::new(Mutex::new(rx)), executor, } } pub fn on_request( mut self, handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static, ) -> Self { self.request_handlers.insert( T::METHOD, Arc::new(move |value| { let params = value.get("params").expect("Missing parameters").clone(); let params: T::Params = serde_json::from_value(params).expect("Invalid parameters received"); let response = handler(params); serde_json::to_value(response).unwrap() }), ); self } } #[async_trait::async_trait] impl Transport for FakeTransport { async fn send(&self, message: String) -> anyhow::Result<()> { if let Ok(msg) = serde_json::from_str::(&message) { let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); if let Some(method) = msg.get("method") { let method = method.as_str().expect("Invalid method received"); if let Some(handler) = self.request_handlers.get(method) { let payload = handler(msg); let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": payload }); self.tx .unbounded_send(response.to_string()) .context("sending a message")?; } else { log::debug!("No handler registered for MCP request '{method}'"); } } } Ok(()) } fn receive(&self) -> Pin + Send>> { let rx = self.rx.clone(); let executor = self.executor.clone(); Box::pin(futures::stream::unfold(rx, move |rx| { let executor = executor.clone(); async move { let mut rx_guard = rx.lock().await; executor.simulate_random_delay().await; if let Some(message) = rx_guard.next().await { drop(rx_guard); Some((message, rx)) } else { None } } })) } fn receive_err(&self) -> Pin + Send>> { Box::pin(futures::stream::empty()) } }