pub mod client; pub mod protocol; #[cfg(any(test, feature = "test-support"))] pub mod test; pub mod transport; pub mod types; use std::fmt::Display; use std::path::Path; use std::sync::Arc; use anyhow::Result; use client::Client; use collections::HashMap; use gpui::AsyncApp; use parking_lot::RwLock; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use util::redact::should_redact; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ContextServerId(pub Arc); impl Display for ContextServerId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } #[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)] pub struct ContextServerCommand { #[serde(rename = "command")] pub path: String, pub args: Vec, pub env: Option>, } impl std::fmt::Debug for ContextServerCommand { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let filtered_env = self.env.as_ref().map(|env| { env.iter() .map(|(k, v)| (k, if should_redact(k) { "[REDACTED]" } else { v })) .collect::>() }); f.debug_struct("ContextServerCommand") .field("path", &self.path) .field("args", &self.args) .field("env", &filtered_env) .finish() } } enum ContextServerTransport { Stdio(ContextServerCommand), Custom(Arc), } pub struct ContextServer { id: ContextServerId, client: RwLock>>, configuration: ContextServerTransport, } impl ContextServer { pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self { Self { id, client: RwLock::new(None), configuration: ContextServerTransport::Stdio(command), } } pub fn new(id: ContextServerId, transport: Arc) -> Self { Self { id, client: RwLock::new(None), configuration: ContextServerTransport::Custom(transport), } } pub fn id(&self) -> ContextServerId { self.id.clone() } pub fn client(&self) -> Option> { self.client.read().clone() } pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { let client = match &self.configuration { ContextServerTransport::Stdio(command) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { executable: Path::new(&command.path).to_path_buf(), args: command.args.clone(), env: command.env.clone(), }, cx.clone(), )?, ContextServerTransport::Custom(transport) => Client::new( client::ContextServerId(self.id.0.clone()), self.id().0, transport.clone(), cx.clone(), )?, }; self.initialize(client).await } async fn initialize(&self, client: Client) -> Result<()> { log::info!("starting context server {}", self.id); let protocol = crate::protocol::ModelContextProtocol::new(client); let client_info = types::Implementation { name: "Zed".to_string(), version: env!("CARGO_PKG_VERSION").to_string(), }; let initialized_protocol = protocol.initialize(client_info).await?; log::debug!( "context server {} initialized: {:?}", self.id, initialized_protocol.initialize, ); *self.client.write() = Some(Arc::new(initialized_protocol)); Ok(()) } pub fn stop(&self) -> Result<()> { let mut client = self.client.write(); if let Some(protocol) = client.take() { drop(protocol); } Ok(()) } }