context_store: Refactor state management (#29910)
Because we instantiated `ContextServerManager` both in `agent` and `assistant-context-editor`, and these two entities track the running MCP servers separately, we were effectively running every MCP server twice. This PR moves the `ContextServerManager` into the project crate (now called `ContextServerStore`). The store can be accessed via a project instance. This ensures that we only instantiate one `ContextServerStore` per project. Also, this PR adds a bunch of tests to ensure that the `ContextServerStore` behaves correctly (Previously there were none). Closes #28714 Closes #29530 Release Notes: - N/A
This commit is contained in:
parent
8199664a5a
commit
9cb5ffac25
43 changed files with 1570 additions and 1049 deletions
|
@ -13,28 +13,17 @@ path = "src/context_server.rs"
|
|||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
async-trait.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
context_server_settings.workspace = true
|
||||
extension.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
icons.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
@ -40,7 +40,7 @@ pub enum RequestId {
|
|||
Str(String),
|
||||
}
|
||||
|
||||
pub struct Client {
|
||||
pub(crate) struct Client {
|
||||
server_id: ContextServerId,
|
||||
next_id: AtomicI32,
|
||||
outbound_tx: channel::Sender<String>,
|
||||
|
@ -59,7 +59,7 @@ pub struct Client {
|
|||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[repr(transparent)]
|
||||
pub struct ContextServerId(pub Arc<str>);
|
||||
pub(crate) struct ContextServerId(pub Arc<str>);
|
||||
|
||||
fn is_null_value<T: Serialize>(value: &T) -> bool {
|
||||
if let Ok(Value::Null) = serde_json::to_value(value) {
|
||||
|
@ -367,6 +367,7 @@ impl Client {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
|
@ -375,14 +376,6 @@ impl Client {
|
|||
.lock()
|
||||
.insert(method, Box::new(f));
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn server_id(&self) -> ContextServerId {
|
||||
self.server_id.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ContextServerId {
|
||||
|
|
|
@ -1,30 +1,117 @@
|
|||
pub mod client;
|
||||
mod context_server_tool;
|
||||
mod extension_context_server;
|
||||
pub mod manager;
|
||||
pub mod protocol;
|
||||
mod registry;
|
||||
mod transport;
|
||||
pub mod transport;
|
||||
pub mod types;
|
||||
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
pub use context_server_settings::{ContextServerSettings, ServerCommand, ServerConfig};
|
||||
use gpui::{App, actions};
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use crate::context_server_tool::ContextServerTool;
|
||||
pub use crate::registry::ContextServerDescriptorRegistry;
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use collections::HashMap;
|
||||
use gpui::AsyncApp;
|
||||
use parking_lot::RwLock;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
actions!(context_servers, [Restart]);
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct ContextServerId(pub Arc<str>);
|
||||
|
||||
/// The namespace for the context servers actions.
|
||||
pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
context_server_settings::init(cx);
|
||||
ContextServerDescriptorRegistry::default_global(cx);
|
||||
extension_context_server::init(cx);
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
|
||||
});
|
||||
impl Display for ContextServerId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
|
||||
pub struct ContextServerCommand {
|
||||
pub path: String,
|
||||
pub args: Vec<String>,
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
enum ContextServerTransport {
|
||||
Stdio(ContextServerCommand),
|
||||
Custom(Arc<dyn crate::transport::Transport>),
|
||||
}
|
||||
|
||||
pub struct ContextServer {
|
||||
id: ContextServerId,
|
||||
client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
||||
configuration: ContextServerTransport,
|
||||
}
|
||||
|
||||
impl ContextServer {
|
||||
pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
|
||||
Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
configuration: ContextServerTransport::Stdio(command),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
configuration: ContextServerTransport::Custom(transport),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> ContextServerId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
|
||||
self.client.read().clone()
|
||||
}
|
||||
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
let client = match &self.configuration {
|
||||
ContextServerTransport::Stdio(command) => Client::stdio(
|
||||
client::ContextServerId(self.id.0.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
cx.clone(),
|
||||
)?,
|
||||
ContextServerTransport::Custom(transport) => Client::new(
|
||||
client::ContextServerId(self.id.0.clone()),
|
||||
self.id().0,
|
||||
transport.clone(),
|
||||
cx.clone(),
|
||||
)?,
|
||||
};
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
log::info!("starting context server {}", self.id);
|
||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||
let client_info = types::Implementation {
|
||||
name: "Zed".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
let initialized_protocol = protocol.initialize(client_info).await?;
|
||||
|
||||
log::debug!(
|
||||
"context server {} initialized: {:?}",
|
||||
self.id,
|
||||
initialized_protocol.initialize,
|
||||
);
|
||||
|
||||
*self.client.write() = Some(Arc::new(initialized_protocol));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn stop(&self) -> Result<()> {
|
||||
let mut client = self.client.write();
|
||||
if let Some(protocol) = client.take() {
|
||||
drop(protocol);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use icons::IconName;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
|
||||
use crate::manager::ContextServerManager;
|
||||
use crate::types;
|
||||
|
||||
pub struct ContextServerTool {
|
||||
server_manager: Entity<ContextServerManager>,
|
||||
server_id: Arc<str>,
|
||||
tool: types::Tool,
|
||||
}
|
||||
|
||||
impl ContextServerTool {
|
||||
pub fn new(
|
||||
server_manager: Entity<ContextServerManager>,
|
||||
server_id: impl Into<Arc<str>>,
|
||||
tool: types::Tool,
|
||||
) -> Self {
|
||||
Self {
|
||||
server_manager,
|
||||
server_id: server_id.into(),
|
||||
tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for ContextServerTool {
|
||||
fn name(&self) -> String {
|
||||
self.tool.name.clone()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
self.tool.description.clone().unwrap_or_default()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Cog
|
||||
}
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
ToolSource::ContextServer {
|
||||
id: self.server_id.clone().into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
let mut schema = self.tool.input_schema.clone();
|
||||
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
|
||||
Ok(match schema {
|
||||
serde_json::Value::Null => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
serde_json::Value::Object(map) if map.is_empty() => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
_ => schema,
|
||||
})
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
format!("Run MCP tool `{}`", self.tool.name)
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
|
||||
let tool_name = self.tool.name.clone();
|
||||
let server_clone = server.clone();
|
||||
let input_clone = input.clone();
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let Some(protocol) = server_clone.client() else {
|
||||
bail!("Context server not initialized");
|
||||
};
|
||||
|
||||
let arguments = if let serde_json::Value::Object(map) = input_clone {
|
||||
Some(map.into_iter().collect())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
log::trace!(
|
||||
"Running tool: {} with arguments: {:?}",
|
||||
tool_name,
|
||||
arguments
|
||||
);
|
||||
let response = protocol.run_tool(tool_name, arguments).await?;
|
||||
|
||||
let mut result = String::new();
|
||||
for content in response.content {
|
||||
match content {
|
||||
types::ToolResponseContent::Text { text } => {
|
||||
result.push_str(&text);
|
||||
}
|
||||
types::ToolResponseContent::Image { .. } => {
|
||||
log::warn!("Ignoring image content from tool response");
|
||||
}
|
||||
types::ToolResponseContent::Resource { .. } => {
|
||||
log::warn!("Ignoring resource content from tool response");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
})
|
||||
.into()
|
||||
} else {
|
||||
Task::ready(Err(anyhow!("Context server not found"))).into()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,105 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use extension::{
|
||||
ContextServerConfiguration, Extension, ExtensionContextServerProxy, ExtensionHostProxy,
|
||||
ProjectDelegate,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::{ContextServerDescriptorRegistry, ServerCommand, registry};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let proxy = ExtensionHostProxy::default_global(cx);
|
||||
proxy.register_context_server_proxy(ContextServerDescriptorRegistryProxy {
|
||||
context_server_factory_registry: ContextServerDescriptorRegistry::global(cx),
|
||||
});
|
||||
}
|
||||
|
||||
struct ExtensionProject {
|
||||
worktree_ids: Vec<u64>,
|
||||
}
|
||||
|
||||
impl ProjectDelegate for ExtensionProject {
|
||||
fn worktree_ids(&self) -> Vec<u64> {
|
||||
self.worktree_ids.clone()
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextServerDescriptor {
|
||||
id: Arc<str>,
|
||||
extension: Arc<dyn Extension>,
|
||||
}
|
||||
|
||||
fn extension_project(project: Entity<Project>, cx: &mut AsyncApp) -> Result<Arc<ExtensionProject>> {
|
||||
project.update(cx, |project, cx| {
|
||||
Arc::new(ExtensionProject {
|
||||
worktree_ids: project
|
||||
.visible_worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).id().to_proto())
|
||||
.collect(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
impl registry::ContextServerDescriptor for ContextServerDescriptor {
|
||||
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>> {
|
||||
let id = self.id.clone();
|
||||
let extension = self.extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = extension_project(project, cx)?;
|
||||
let mut command = extension
|
||||
.context_server_command(id.clone(), extension_project.clone())
|
||||
.await?;
|
||||
command.command = extension
|
||||
.path_from_extension(command.command.as_ref())
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
log::info!("loaded command for context server {id}: {command:?}");
|
||||
|
||||
Ok(ServerCommand {
|
||||
path: command.command,
|
||||
args: command.args,
|
||||
env: Some(command.env.into_iter().collect()),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn configuration(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<ContextServerConfiguration>>> {
|
||||
let id = self.id.clone();
|
||||
let extension = self.extension.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let extension_project = extension_project(project, cx)?;
|
||||
let configuration = extension
|
||||
.context_server_configuration(id.clone(), extension_project)
|
||||
.await?;
|
||||
|
||||
log::debug!("loaded configuration for context server {id}: {configuration:?}");
|
||||
|
||||
Ok(configuration)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextServerDescriptorRegistryProxy {
|
||||
context_server_factory_registry: Entity<ContextServerDescriptorRegistry>,
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ContextServerDescriptorRegistryProxy {
|
||||
fn register_context_server(&self, extension: Arc<dyn Extension>, id: Arc<str>, cx: &mut App) {
|
||||
self.context_server_factory_registry
|
||||
.update(cx, |registry, _| {
|
||||
registry.register_context_server_descriptor(
|
||||
id.clone(),
|
||||
Arc::new(ContextServerDescriptor { id, extension })
|
||||
as Arc<dyn registry::ContextServerDescriptor>,
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
|
@ -1,584 +0,0 @@
|
|||
//! This module implements a context server management system for Zed.
|
||||
//!
|
||||
//! It provides functionality to:
|
||||
//! - Define and load context server settings
|
||||
//! - Manage individual context servers (start, stop, restart)
|
||||
//! - Maintain a global manager for all context servers
|
||||
//!
|
||||
//! Key components:
|
||||
//! - `ContextServerSettings`: Defines the structure for server configurations
|
||||
//! - `ContextServer`: Represents an individual context server
|
||||
//! - `ContextServerManager`: Manages multiple context servers
|
||||
//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
|
||||
//!
|
||||
//! The module also includes initialization logic to set up the context server system
|
||||
//! and react to changes in settings.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
use collections::HashMap;
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
|
||||
use log;
|
||||
use parking_lot::RwLock;
|
||||
use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::transport::Transport;
|
||||
use crate::{ContextServerSettings, ServerConfig};
|
||||
|
||||
use crate::{
|
||||
CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
|
||||
client::{self, Client},
|
||||
types,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum ContextServerStatus {
|
||||
Starting,
|
||||
Running,
|
||||
Error(Arc<str>),
|
||||
}
|
||||
|
||||
pub struct ContextServer {
|
||||
pub id: Arc<str>,
|
||||
pub config: Arc<ServerConfig>,
|
||||
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
||||
transport: Option<Arc<dyn Transport>>,
|
||||
}
|
||||
|
||||
impl ContextServer {
|
||||
pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
|
||||
Self {
|
||||
id,
|
||||
config,
|
||||
client: RwLock::new(None),
|
||||
transport: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
config: Arc::new(ServerConfig::default()),
|
||||
transport: Some(transport),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn id(&self) -> Arc<str> {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
pub fn config(&self) -> Arc<ServerConfig> {
|
||||
self.config.clone()
|
||||
}
|
||||
|
||||
pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
|
||||
self.client.read().clone()
|
||||
}
|
||||
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
let client = if let Some(transport) = self.transport.clone() {
|
||||
Client::new(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
self.id(),
|
||||
transport,
|
||||
cx.clone(),
|
||||
)?
|
||||
} else {
|
||||
let Some(command) = &self.config.command else {
|
||||
bail!("no command specified for server {}", self.id);
|
||||
};
|
||||
Client::stdio(
|
||||
client::ContextServerId(self.id.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
cx.clone(),
|
||||
)?
|
||||
};
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
log::info!("starting context server {}", self.id);
|
||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||
let client_info = types::Implementation {
|
||||
name: "Zed".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
let initialized_protocol = protocol.initialize(client_info).await?;
|
||||
|
||||
log::debug!(
|
||||
"context server {} initialized: {:?}",
|
||||
self.id,
|
||||
initialized_protocol.initialize,
|
||||
);
|
||||
|
||||
*self.client.write() = Some(Arc::new(initialized_protocol));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn stop(&self) -> Result<()> {
|
||||
let mut client = self.client.write();
|
||||
if let Some(protocol) = client.take() {
|
||||
drop(protocol);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ContextServerManager {
|
||||
servers: HashMap<Arc<str>, Arc<ContextServer>>,
|
||||
server_status: HashMap<Arc<str>, ContextServerStatus>,
|
||||
project: Entity<Project>,
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
update_servers_task: Option<Task<Result<()>>>,
|
||||
needs_server_update: bool,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
ServerStatusChanged {
|
||||
server_id: Arc<str>,
|
||||
status: Option<ContextServerStatus>,
|
||||
},
|
||||
}
|
||||
|
||||
impl EventEmitter<Event> for ContextServerManager {}
|
||||
|
||||
impl ContextServerManager {
|
||||
pub fn new(
|
||||
registry: Entity<ContextServerDescriptorRegistry>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let mut this = Self {
|
||||
_subscriptions: vec![
|
||||
cx.observe(®istry, |this, _registry, cx| {
|
||||
this.available_context_servers_changed(cx);
|
||||
}),
|
||||
cx.observe_global::<SettingsStore>(|this, cx| {
|
||||
this.available_context_servers_changed(cx);
|
||||
}),
|
||||
],
|
||||
project,
|
||||
registry,
|
||||
needs_server_update: false,
|
||||
servers: HashMap::default(),
|
||||
server_status: HashMap::default(),
|
||||
update_servers_task: None,
|
||||
};
|
||||
this.available_context_servers_changed(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
|
||||
if self.update_servers_task.is_some() {
|
||||
self.needs_server_update = true;
|
||||
} else {
|
||||
self.update_servers_task = Some(cx.spawn(async move |this, cx| {
|
||||
this.update(cx, |this, _| {
|
||||
this.needs_server_update = false;
|
||||
})?;
|
||||
|
||||
if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
|
||||
log::error!("Error maintaining context servers: {}", err);
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let has_any_context_servers = !this.running_servers().is_empty();
|
||||
if has_any_context_servers {
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
|
||||
});
|
||||
}
|
||||
|
||||
this.update_servers_task.take();
|
||||
if this.needs_server_update {
|
||||
this.available_context_servers_changed(cx);
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
|
||||
self.servers
|
||||
.get(id)
|
||||
.filter(|server| server.client().is_some())
|
||||
.cloned()
|
||||
}
|
||||
|
||||
pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
|
||||
self.server_status.get(id).cloned()
|
||||
}
|
||||
|
||||
pub fn start_server(
|
||||
&self,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
|
||||
}
|
||||
|
||||
pub fn stop_server(
|
||||
&mut self,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
server.stop().log_err();
|
||||
self.update_server_status(server.id().clone(), None, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let id = id.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
|
||||
let config = server.config();
|
||||
|
||||
this.update(cx, |this, cx| this.stop_server(server, cx))??;
|
||||
let new_server = Arc::new(ContextServer::new(id.clone(), config));
|
||||
Self::run_server(this, new_server, cx).await?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
|
||||
self.servers.values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
|
||||
self.servers
|
||||
.values()
|
||||
.filter(|server| server.client().is_some())
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
|
||||
let mut desired_servers = HashMap::default();
|
||||
|
||||
let (registry, project) = this.update(cx, |this, cx| {
|
||||
let location = this
|
||||
.project
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.map(|worktree| settings::SettingsLocation {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: Path::new(""),
|
||||
});
|
||||
let settings = ContextServerSettings::get(location, cx);
|
||||
desired_servers = settings.context_servers.clone();
|
||||
|
||||
(this.registry.clone(), this.project.clone())
|
||||
})?;
|
||||
|
||||
for (id, descriptor) in
|
||||
registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
|
||||
{
|
||||
let config = desired_servers.entry(id).or_default();
|
||||
if config.command.is_none() {
|
||||
if let Some(extension_command) =
|
||||
descriptor.command(project.clone(), &cx).await.log_err()
|
||||
{
|
||||
config.command = Some(extension_command);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut servers_to_start = HashMap::default();
|
||||
let mut servers_to_stop = HashMap::default();
|
||||
|
||||
this.update(cx, |this, _cx| {
|
||||
this.servers.retain(|id, server| {
|
||||
if desired_servers.contains_key(id) {
|
||||
true
|
||||
} else {
|
||||
servers_to_stop.insert(id.clone(), server.clone());
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
for (id, config) in desired_servers {
|
||||
let existing_config = this.servers.get(&id).map(|server| server.config());
|
||||
if existing_config.as_deref() != Some(&config) {
|
||||
let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
|
||||
servers_to_start.insert(id.clone(), server.clone());
|
||||
if let Some(old_server) = this.servers.remove(&id) {
|
||||
servers_to_stop.insert(id, old_server);
|
||||
}
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
for (_, server) in servers_to_stop {
|
||||
this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
|
||||
}
|
||||
|
||||
for (_, server) in servers_to_start {
|
||||
Self::run_server(this.clone(), server, cx).await.ok();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_server(
|
||||
this: WeakEntity<Self>,
|
||||
server: Arc<ContextServer>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let id = server.id();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
|
||||
this.servers.insert(id.clone(), server.clone());
|
||||
})?;
|
||||
|
||||
match server.start(&cx).await {
|
||||
Ok(_) => {
|
||||
log::debug!("`{}` context server started", id);
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("`{}` context server failed to start\n{}", id, err);
|
||||
this.update(cx, |this, cx| {
|
||||
this.update_server_status(
|
||||
id.clone(),
|
||||
Some(ContextServerStatus::Error(err.to_string().into())),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_server_status(
|
||||
&mut self,
|
||||
id: Arc<str>,
|
||||
status: Option<ContextServerStatus>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(status) = status.clone() {
|
||||
self.server_status.insert(id.clone(), status);
|
||||
} else {
|
||||
self.server_status.remove(&id);
|
||||
}
|
||||
|
||||
cx.emit(Event::ServerStatusChanged {
|
||||
server_id: id,
|
||||
status,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::types::{
|
||||
Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use futures::{Stream, StreamExt as _, lock::Mutex};
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use project::FakeFs;
|
||||
use serde_json::json;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_context_server_status(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
let project = create_test_project(cx, json!({"code.rs": ""})).await;
|
||||
|
||||
let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
|
||||
let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));
|
||||
|
||||
let server_1_id: Arc<str> = "mcp-1".into();
|
||||
let server_2_id: Arc<str> = "mcp-2".into();
|
||||
|
||||
let transport_1 = Arc::new(FakeTransport::new(
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-1".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let transport_2 = Arc::new(FakeTransport::new(
|
||||
|_, request_type, _| match request_type {
|
||||
Some(RequestType::Initialize) => {
|
||||
Some(create_initialize_response("mcp-2".to_string()))
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
));
|
||||
|
||||
let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
|
||||
let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.start_server(server_1, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
|
||||
});
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_2_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
});
|
||||
|
||||
manager
|
||||
.update(cx, |manager, cx| manager.stop_server(server_2, cx))
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
manager.read(cx).status_for_server(&server_1_id),
|
||||
Some(ContextServerStatus::Running)
|
||||
);
|
||||
assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
|
||||
});
|
||||
}
|
||||
|
||||
async fn create_test_project(
|
||||
cx: &mut TestAppContext,
|
||||
files: serde_json::Value,
|
||||
) -> Entity<Project> {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), files).await;
|
||||
Project::test(fs, [path!("/test").as_ref()], cx).await
|
||||
}
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn create_initialize_response(server_name: String) -> serde_json::Value {
|
||||
serde_json::to_value(&InitializeResponse {
|
||||
protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
|
||||
server_info: Implementation {
|
||||
name: server_name,
|
||||
version: "1.0.0".to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities::default(),
|
||||
meta: None,
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
struct FakeTransport {
|
||||
on_request: Arc<
|
||||
dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>,
|
||||
tx: futures::channel::mpsc::UnboundedSender<String>,
|
||||
rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
|
||||
}
|
||||
|
||||
impl FakeTransport {
|
||||
fn new(
|
||||
on_request: impl Fn(
|
||||
u64,
|
||||
Option<RequestType>,
|
||||
serde_json::Value,
|
||||
) -> Option<serde_json::Value>
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync,
|
||||
) -> Self {
|
||||
let (tx, rx) = futures::channel::mpsc::unbounded();
|
||||
Self {
|
||||
on_request: Arc::new(on_request),
|
||||
tx,
|
||||
rx: Arc::new(Mutex::new(rx)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Transport for FakeTransport {
|
||||
async fn send(&self, message: String) -> Result<()> {
|
||||
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
|
||||
let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
|
||||
|
||||
if let Some(method) = msg.get("method") {
|
||||
let request_type = method
|
||||
.as_str()
|
||||
.and_then(|method| types::RequestType::try_from(method).ok());
|
||||
if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
|
||||
let response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": payload
|
||||
});
|
||||
|
||||
self.tx
|
||||
.unbounded_send(response.to_string())
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
let rx = self.rx.clone();
|
||||
Box::pin(futures::stream::unfold(rx, |rx| async move {
|
||||
let mut rx_guard = rx.lock().await;
|
||||
if let Some(message) = rx_guard.next().await {
|
||||
drop(rx_guard);
|
||||
Some((message, rx))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
|
||||
Box::pin(futures::stream::empty())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,7 +16,7 @@ pub struct ModelContextProtocol {
|
|||
}
|
||||
|
||||
impl ModelContextProtocol {
|
||||
pub fn new(inner: Client) -> Self {
|
||||
pub(crate) fn new(inner: Client) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use extension::ContextServerConfiguration;
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Global, ReadGlobal, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::ServerCommand;
|
||||
|
||||
pub trait ContextServerDescriptor {
|
||||
fn command(&self, project: Entity<Project>, cx: &AsyncApp) -> Task<Result<ServerCommand>>;
|
||||
fn configuration(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<ContextServerConfiguration>>>;
|
||||
}
|
||||
|
||||
struct GlobalContextServerDescriptorRegistry(Entity<ContextServerDescriptorRegistry>);
|
||||
|
||||
impl Global for GlobalContextServerDescriptorRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ContextServerDescriptorRegistry {
|
||||
context_servers: HashMap<Arc<str>, Arc<dyn ContextServerDescriptor>>,
|
||||
}
|
||||
|
||||
impl ContextServerDescriptorRegistry {
|
||||
/// Returns the global [`ContextServerDescriptorRegistry`].
|
||||
pub fn global(cx: &App) -> Entity<Self> {
|
||||
GlobalContextServerDescriptorRegistry::global(cx).0.clone()
|
||||
}
|
||||
|
||||
/// Returns the global [`ContextServerDescriptorRegistry`].
|
||||
///
|
||||
/// Inserts a default [`ContextServerDescriptorRegistry`] if one does not yet exist.
|
||||
pub fn default_global(cx: &mut App) -> Entity<Self> {
|
||||
if !cx.has_global::<GlobalContextServerDescriptorRegistry>() {
|
||||
let registry = cx.new(|_| Self::new());
|
||||
cx.set_global(GlobalContextServerDescriptorRegistry(registry));
|
||||
}
|
||||
cx.global::<GlobalContextServerDescriptorRegistry>()
|
||||
.0
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
context_servers: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_server_descriptors(&self) -> Vec<(Arc<str>, Arc<dyn ContextServerDescriptor>)> {
|
||||
self.context_servers
|
||||
.iter()
|
||||
.map(|(id, factory)| (id.clone(), factory.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn context_server_descriptor(&self, id: &str) -> Option<Arc<dyn ContextServerDescriptor>> {
|
||||
self.context_servers.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Registers the provided [`ContextServerDescriptor`].
|
||||
pub fn register_context_server_descriptor(
|
||||
&mut self,
|
||||
id: Arc<str>,
|
||||
descriptor: Arc<dyn ContextServerDescriptor>,
|
||||
) {
|
||||
self.context_servers.insert(id, descriptor);
|
||||
}
|
||||
|
||||
/// Unregisters the [`ContextServerDescriptor`] for the server with the given ID.
|
||||
pub fn unregister_context_server_descriptor_by_id(&mut self, server_id: &str) {
|
||||
self.context_servers.remove(server_id);
|
||||
}
|
||||
}
|
|
@ -610,7 +610,7 @@ pub enum ToolResponseContent {
|
|||
Resource { resource: ResourceContents },
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ListToolsResponse {
|
||||
pub tools: Vec<Tool>,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue