ZIm/crates/context_server/src/context_server.rs
Finn Evers e85c466632 Ensure context servers are spawned in the workspace directory (#35271)
This fixes an issue where we were not setting the context server working
directory at all.

Release Notes:

- Context servers will now be spawned in the currently active project
root.

---------

Co-authored-by: Danilo Leal <danilo@zed.dev>
2025-07-29 18:11:37 +02:00

146 lines
4.3 KiB
Rust

pub mod client;
pub mod listener;
pub mod protocol;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
pub mod transport;
pub mod types;
use std::path::Path;
use std::sync::Arc;
use std::{fmt::Display, path::PathBuf};
use anyhow::Result;
use client::Client;
use collections::HashMap;
use gpui::AsyncApp;
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::redact::should_redact;
/// Identifier of a context server.
///
/// A newtype over a shared string slice, so cloning an id only bumps a
/// reference count instead of copying the text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the id as its inner string.
    ///
    /// Formatting is delegated to the inner string's `Display` impl rather
    /// than going through `write!`, because `write!` formats its arguments
    /// with a fresh set of options and would silently discard any width,
    /// alignment, or precision the caller requested (e.g. `{:>20}`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.0, f)
    }
}
/// Describes the command used for a stdio-based context server: the
/// executable, its arguments, and optional environment variables.
///
/// `Debug` is deliberately not derived; the hand-written impl in this file
/// redacts environment values whose names `should_redact` flags.
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
pub struct ContextServerCommand {
/// Path of the executable. (De)serialized under the key `command`.
#[serde(rename = "command")]
pub path: PathBuf,
/// Arguments passed along with the executable.
pub args: Vec<String>,
/// Optional environment variables (name -> value) for the command.
pub env: Option<HashMap<String, String>>,
}
impl std::fmt::Debug for ContextServerCommand {
    /// Debug-formats the command while hiding sensitive environment values.
    ///
    /// Every environment variable whose name is flagged by `should_redact`
    /// has its value replaced with the literal `[REDACTED]`; `path` and
    /// `args` are printed unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Build a borrowed (name, displayed value) list; `None` stays `None`.
        let redacted_env: Option<Vec<(&String, &str)>> = match &self.env {
            Some(vars) => {
                let mut entries = Vec::new();
                for (name, value) in vars.iter() {
                    let shown = if should_redact(name) {
                        "[REDACTED]"
                    } else {
                        value.as_str()
                    };
                    entries.push((name, shown));
                }
                Some(entries)
            }
            None => None,
        };

        let mut out = f.debug_struct("ContextServerCommand");
        out.field("path", &self.path);
        out.field("args", &self.args);
        out.field("env", &redacted_env);
        out.finish()
    }
}
/// How a `ContextServer` obtains its client when it is started.
enum ContextServerTransport {
/// A command plus an optional working directory; both are handed to
/// `Client::stdio` by `ContextServer::start`.
Stdio(ContextServerCommand, Option<PathBuf>),
/// A caller-supplied transport; handed to `Client::new` by
/// `ContextServer::start`.
Custom(Arc<dyn crate::transport::Transport>),
}
/// A context server: its id, how to connect to it, and (once started) the
/// initialized protocol handle used to talk to it.
pub struct ContextServer {
/// Identifier of this server; also used in log messages.
id: ContextServerId,
/// `None` until `start` completes successfully; reset to `None` by `stop`.
client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
/// Transport configuration chosen at construction (`stdio` or `new`).
configuration: ContextServerTransport,
}
impl ContextServer {
    /// Creates a not-yet-started server configured with a stdio command.
    ///
    /// `working_directory`, when present, is stored as an owned path and
    /// forwarded to `Client::stdio` when the server is started.
    pub fn stdio(
        id: ContextServerId,
        command: ContextServerCommand,
        working_directory: Option<Arc<Path>>,
    ) -> Self {
        let working_directory = working_directory.map(|directory| directory.to_path_buf());
        let configuration = ContextServerTransport::Stdio(command, working_directory);
        Self {
            id,
            client: RwLock::new(None),
            configuration,
        }
    }

    /// Creates a not-yet-started server that uses a caller-supplied transport.
    pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
        let configuration = ContextServerTransport::Custom(transport);
        Self {
            id,
            client: RwLock::new(None),
            configuration,
        }
    }

    /// Returns a clone of this server's id.
    pub fn id(&self) -> ContextServerId {
        let Self { id, .. } = self;
        id.clone()
    }

    /// Returns the initialized protocol handle, or `None` if the server has
    /// not been started (or has been stopped).
    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
        let slot = self.client.read();
        (*slot).clone()
    }

    /// Builds a client for the configured transport and runs the protocol
    /// initialization handshake.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the client or from
    /// initializing the protocol.
    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
        // Both transports identify the client by the same underlying id string.
        let server_id = client::ContextServerId(self.id.0.clone());

        let client = match &self.configuration {
            ContextServerTransport::Stdio(command, working_directory) => {
                let binary = client::ModelContextServerBinary {
                    executable: command.path.clone(),
                    args: command.args.clone(),
                    env: command.env.clone(),
                };
                Client::stdio(server_id, binary, working_directory, cx.clone())?
            }
            ContextServerTransport::Custom(transport) => {
                let server_name = self.id().0;
                Client::new(server_id, server_name, transport.clone(), cx.clone())?
            }
        };

        self.initialize(client).await
    }

    /// Performs the protocol handshake over `client` and, on success, stores
    /// the initialized protocol so that `client()` returns it.
    async fn initialize(&self, client: Client) -> Result<()> {
        log::info!("starting context server {}", self.id);

        let protocol = crate::protocol::ModelContextProtocol::new(client);
        let client_info = types::Implementation {
            name: String::from("Zed"),
            version: String::from(env!("CARGO_PKG_VERSION")),
        };
        let initialized_protocol = protocol.initialize(client_info).await?;

        log::debug!(
            "context server {} initialized: {:?}",
            self.id,
            initialized_protocol.initialize,
        );

        // Build the value first so the write lock is held only for the store.
        let ready = Some(Arc::new(initialized_protocol));
        *self.client.write() = ready;
        Ok(())
    }

    /// Clears the stored protocol handle, dropping this server's reference
    /// to it. Currently always returns `Ok(())`.
    pub fn stop(&self) -> Result<()> {
        let mut slot = self.client.write();
        let previous = slot.take();
        // Released while the write lock is still held.
        drop(previous);
        Ok(())
    }
}