context_store: Refactor state management (#29910)

Because we instantiated `ContextServerManager` both in `agent` and
`assistant-context-editor`, and these two entities track the running MCP
servers separately, we were effectively running every MCP server twice.

This PR moves the `ContextServerManager` into the project crate (now
called `ContextServerStore`). The store can be accessed via a project
instance. This ensures that we only instantiate one `ContextServerStore`
per project.

Also, this PR adds a bunch of tests to ensure that the
`ContextServerStore` behaves correctly (Previously there were none).

Closes #28714
Closes #29530

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-05-05 21:36:12 +02:00 committed by GitHub
parent 8199664a5a
commit 9cb5ffac25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 1570 additions and 1049 deletions

View file

@ -1,30 +1,117 @@
pub mod client;
mod context_server_tool;
mod extension_context_server;
pub mod manager;
pub mod protocol;
mod registry;
mod transport;
pub mod transport;
pub mod types;
use command_palette_hooks::CommandPaletteFilter;
pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerConfig};
use gpui::{App, actions};
use std::fmt::Display;
use std::path::Path;
use std::sync::Arc;
pub use crate::context_server_tool::ContextServerTool;
pub use crate::registry::ContextServerDescriptorRegistry;
use anyhow::Result;
use client::Client;
use collections::HashMap;
use gpui::AsyncApp;
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
actions!(context_servers, [Restart]);
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);
/// The namespace for the context servers actions.
pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut App) {
context_server_settings::init(cx);
ContextServerDescriptorRegistry::default_global(cx);
extension_context_server::init(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
});
impl Display for ContextServerId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
pub struct ContextServerCommand {
pub path: String,
pub args: Vec<String>,
pub env: Option<HashMap<String, String>>,
}
enum ContextServerTransport {
Stdio(ContextServerCommand),
Custom(Arc<dyn crate::transport::Transport>),
}
pub struct ContextServer {
id: ContextServerId,
client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
configuration: ContextServerTransport,
}
impl ContextServer {
pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
Self {
id,
client: RwLock::new(None),
configuration: ContextServerTransport::Stdio(command),
}
}
pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
Self {
id,
client: RwLock::new(None),
configuration: ContextServerTransport::Custom(transport),
}
}
pub fn id(&self) -> ContextServerId {
self.id.clone()
}
pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
self.client.read().clone()
}
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
let client = match &self.configuration {
ContextServerTransport::Stdio(command) => Client::stdio(
client::ContextServerId(self.id.0.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
cx.clone(),
)?,
ContextServerTransport::Custom(transport) => Client::new(
client::ContextServerId(self.id.0.clone()),
self.id().0,
transport.clone(),
cx.clone(),
)?,
};
self.initialize(client).await
}
async fn initialize(&self, client: Client) -> Result<()> {
log::info!("starting context server {}", self.id);
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let initialized_protocol = protocol.initialize(client_info).await?;
log::debug!(
"context server {} initialized: {:?}",
self.id,
initialized_protocol.initialize,
);
*self.client.write() = Some(Arc::new(initialized_protocol));
Ok(())
}
pub fn stop(&self) -> Result<()> {
let mut client = self.client.write();
if let Some(protocol) = client.take() {
drop(protocol);
}
Ok(())
}
}