ZIm/crates/context_server/src/context_server.rs
Ben Brandt 81c111510f
Refactor handling of ContextServer notifications
The notification handler registration is now more explicit, with
handlers set up before server initialization to avoid potential race
conditions.
2025-07-30 15:48:40 +02:00

165 lines
4.9 KiB
Rust

pub mod client;
pub mod listener;
pub mod protocol;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
pub mod transport;
pub mod types;
use std::path::Path;
use std::sync::Arc;
use std::{fmt::Display, path::PathBuf};
use anyhow::Result;
use client::Client;
use collections::HashMap;
use gpui::AsyncApp;
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::redact::should_redact;
/// Identifier of a context server, stored as a cheaply clonable shared string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the wrapped identifier string verbatim; width and precision
    /// flags of the surrounding format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
// Program, arguments and environment describing a stdio-configured context
// server. `new_client` copies these three fields into the
// `client::ModelContextServerBinary` it passes to `Client::stdio`.
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
pub struct ContextServerCommand {
// Executable path; (de)serialized under the key "command" instead of "path".
#[serde(rename = "command")]
pub path: PathBuf,
// Arguments for the executable.
pub args: Vec<String>,
// Optional environment variables. The hand-written `Debug` impl prints
// "[REDACTED]" instead of the value for every key that `should_redact` flags.
pub env: Option<HashMap<String, String>>,
}
impl std::fmt::Debug for ContextServerCommand {
    /// Formats the command like a derived `Debug` would, except that the value
    /// of every environment variable whose key is flagged by `should_redact`
    /// is replaced with `[REDACTED]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Build a printable view of the environment as (key, shown value)
        // pairs; `None` stays `None`.
        let printable_env = match &self.env {
            Some(env) => {
                let mut entries = Vec::with_capacity(env.len());
                for (key, value) in env {
                    let shown: &str = if should_redact(key) {
                        "[REDACTED]"
                    } else {
                        value
                    };
                    entries.push((key, shown));
                }
                Some(entries)
            }
            None => None,
        };

        let mut out = f.debug_struct("ContextServerCommand");
        out.field("path", &self.path);
        out.field("args", &self.args);
        out.field("env", &printable_env);
        out.finish()
    }
}
/// How a `ContextServer` obtains its `Client`; matched on in `new_client`.
enum ContextServerTransport {
/// A command plus an optional working directory, both handed to `Client::stdio`.
Stdio(ContextServerCommand, Option<PathBuf>),
/// A caller-supplied transport, handed to `Client::new`.
Custom(Arc<dyn crate::transport::Transport>),
}
/// A context server: its identifier, the transport configuration used to
/// create a client, and the initialized protocol once it has been started.
pub struct ContextServer {
/// Identifier of this server; included in log messages and passed to the client.
id: ContextServerId,
/// `None` until `initialize` succeeds, and reset to `None` by `stop`.
client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
/// Transport settings read by `new_client` when building a `Client`.
configuration: ContextServerTransport,
}
impl ContextServer {
    /// Creates a server that will be reached through `Client::stdio` using
    /// `command`, optionally with a working directory.
    ///
    /// No client exists yet; one is created by [`Self::start`] or
    /// [`Self::start_with_handlers`].
    pub fn stdio(
        id: ContextServerId,
        command: ContextServerCommand,
        working_directory: Option<Arc<Path>>,
    ) -> Self {
        let working_directory = working_directory.map(|dir| dir.to_path_buf());
        Self {
            id,
            client: RwLock::new(None),
            configuration: ContextServerTransport::Stdio(command, working_directory),
        }
    }

    /// Creates a server that talks over a caller-supplied `transport`.
    ///
    /// No client exists yet; one is created by [`Self::start`] or
    /// [`Self::start_with_handlers`].
    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
        let configuration = ContextServerTransport::Custom(transport);
        Self {
            id,
            client: RwLock::new(None),
            configuration,
        }
    }

    /// Returns a clone of this server's identifier.
    pub fn id(&self) -> ContextServerId {
        self.id.clone()
    }

    /// Returns the initialized protocol, or `None` if the server has not been
    /// successfully started or has been stopped.
    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
        let slot = self.client.read();
        slot.clone()
    }

    /// Creates a client for the configured transport and initializes it.
    ///
    /// # Errors
    /// Propagates any error from creating the client or from initialization.
    pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
        let client = self.new_client(cx)?;
        self.initialize(client).await
    }

    /// Like [`Self::start`], but registers every `(method, handler)` pair on
    /// the freshly created client before initialization begins, so the
    /// handlers are already in place when initialization runs.
    ///
    /// # Errors
    /// Propagates any error from creating the client or from initialization.
    pub async fn start_with_handlers(
        &self,
        notification_handlers: Vec<(
            &'static str,
            Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
        )>,
        cx: &AsyncApp,
    ) -> Result<()> {
        let client = self.new_client(cx)?;
        notification_handlers
            .into_iter()
            .for_each(|(method, handler)| {
                client.on_notification(method, handler);
            });
        self.initialize(client).await
    }

    /// Builds a `Client` matching the stored transport configuration.
    fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
        let client_id = client::ContextServerId(self.id.0.clone());
        let client = match &self.configuration {
            ContextServerTransport::Stdio(command, working_directory) => {
                let binary = client::ModelContextServerBinary {
                    executable: command.path.clone(),
                    args: command.args.clone(),
                    env: command.env.clone(),
                };
                Client::stdio(client_id, binary, working_directory, cx.clone())?
            }
            ContextServerTransport::Custom(transport) => {
                // The server id string is passed a second time as the
                // second argument of `Client::new`.
                let name = self.id.0.clone();
                Client::new(client_id, name, transport.clone(), cx.clone())?
            }
        };
        Ok(client)
    }

    /// Runs protocol initialization over `client`, identifying as "Zed" with
    /// this crate's package version, and stores the initialized protocol so
    /// that [`Self::client`] returns it.
    ///
    /// # Errors
    /// Returns the initialization error; in that case the stored client is
    /// left unchanged.
    async fn initialize(&self, client: Client) -> Result<()> {
        log::info!("starting context server {}", self.id);
        let client_info = types::Implementation {
            name: String::from("Zed"),
            version: String::from(env!("CARGO_PKG_VERSION")),
        };
        let initialized = crate::protocol::ModelContextProtocol::new(client)
            .initialize(client_info)
            .await?;
        log::debug!(
            "context server {} initialized: {:?}",
            self.id,
            initialized.initialize,
        );
        *self.client.write() = Some(Arc::new(initialized));
        Ok(())
    }

    /// Clears the stored protocol, dropping this server's reference to it.
    ///
    /// Clones previously handed out by [`Self::client`] are separate `Arc`
    /// references and are not affected. Currently always returns `Ok(())`.
    pub fn stop(&self) -> Result<()> {
        // The write guard is a temporary of this statement, so the taken
        // value is dropped while the lock is still held.
        drop(self.client.write().take());
        Ok(())
    }
}