Put context servers behind a trait (#20432)

This PR puts context servers behind the `ContextServer` trait to allow
us to provide context servers from an extension.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-11-08 13:36:41 -05:00 committed by GitHub
parent 01503511ad
commit 09c599385a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 102 additions and 61 deletions

1
Cargo.lock generated
View file

@ -2815,6 +2815,7 @@ name = "context_servers"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"collections", "collections",
"command_palette_hooks", "command_palette_hooks",
"futures 0.3.30", "futures 0.3.30",

View file

@ -819,7 +819,7 @@ impl ContextStore {
|context_server_manager, cx| { |context_server_manager, cx| {
for server in context_server_manager.servers() { for server in context_server_manager.servers() {
context_server_manager context_server_manager
.restart_server(&server.id, cx) .restart_server(&server.id(), cx)
.detach_and_log_err(cx); .detach_and_log_err(cx);
} }
}, },
@ -850,7 +850,7 @@ impl ContextStore {
let server = server.clone(); let server = server.clone();
let server_id = server_id.clone(); let server_id = server_id.clone();
|this, mut cx| async move { |this, mut cx| async move {
let Some(protocol) = server.client.read().clone() else { let Some(protocol) = server.client() else {
return; return;
}; };
@ -889,7 +889,7 @@ impl ContextStore {
tool_working_set.insert( tool_working_set.insert(
Arc::new(tools::context_server_tool::ContextServerTool::new( Arc::new(tools::context_server_tool::ContextServerTool::new(
context_server_manager.clone(), context_server_manager.clone(),
server.id.clone(), server.id(),
tool, tool,
)), )),
) )

View file

@ -20,18 +20,18 @@ use crate::slash_command::create_label_for_command;
pub struct ContextServerSlashCommand { pub struct ContextServerSlashCommand {
server_manager: Model<ContextServerManager>, server_manager: Model<ContextServerManager>,
server_id: String, server_id: Arc<str>,
prompt: Prompt, prompt: Prompt,
} }
impl ContextServerSlashCommand { impl ContextServerSlashCommand {
pub fn new( pub fn new(
server_manager: Model<ContextServerManager>, server_manager: Model<ContextServerManager>,
server: &Arc<ContextServer>, server: &Arc<dyn ContextServer>,
prompt: Prompt, prompt: Prompt,
) -> Self { ) -> Self {
Self { Self {
server_id: server.id.clone(), server_id: server.id(),
prompt, prompt,
server_manager, server_manager,
} }
@ -89,7 +89,7 @@ impl SlashCommand for ContextServerSlashCommand {
if let Some(server) = self.server_manager.read(cx).get_server(&server_id) { if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
let Some(protocol) = server.client.read().clone() else { let Some(protocol) = server.client() else {
return Err(anyhow!("Context server not initialized")); return Err(anyhow!("Context server not initialized"));
}; };
@ -143,7 +143,7 @@ impl SlashCommand for ContextServerSlashCommand {
let manager = self.server_manager.read(cx); let manager = self.server_manager.read(cx);
if let Some(server) = manager.get_server(&server_id) { if let Some(server) = manager.get_server(&server_id) {
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
let Some(protocol) = server.client.read().clone() else { let Some(protocol) = server.client() else {
return Err(anyhow!("Context server not initialized")); return Err(anyhow!("Context server not initialized"));
}; };
let result = protocol.run_prompt(&prompt_name, prompt_args).await?; let result = protocol.run_prompt(&prompt_name, prompt_args).await?;

View file

@ -1,3 +1,5 @@
use std::sync::Arc;
use anyhow::{anyhow, bail}; use anyhow::{anyhow, bail};
use assistant_tool::Tool; use assistant_tool::Tool;
use context_servers::manager::ContextServerManager; use context_servers::manager::ContextServerManager;
@ -6,14 +8,14 @@ use gpui::{Model, Task};
pub struct ContextServerTool { pub struct ContextServerTool {
server_manager: Model<ContextServerManager>, server_manager: Model<ContextServerManager>,
server_id: String, server_id: Arc<str>,
tool: types::Tool, tool: types::Tool,
} }
impl ContextServerTool { impl ContextServerTool {
pub fn new( pub fn new(
server_manager: Model<ContextServerManager>, server_manager: Model<ContextServerManager>,
server_id: impl Into<String>, server_id: impl Into<Arc<str>>,
tool: types::Tool, tool: types::Tool,
) -> Self { ) -> Self {
Self { Self {
@ -55,7 +57,7 @@ impl Tool for ContextServerTool {
cx.foreground_executor().spawn({ cx.foreground_executor().spawn({
let tool_name = self.tool.name.clone(); let tool_name = self.tool.name.clone();
async move { async move {
let Some(protocol) = server.client.read().clone() else { let Some(protocol) = server.client() else {
bail!("Context server not initialized"); bail!("Context server not initialized");
}; };

View file

@ -13,6 +13,7 @@ path = "src/context_servers.rs"
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
async-trait.workspace = true
collections.workspace = true collections.workspace = true
command_palette_hooks.workspace = true command_palette_hooks.workspace = true
futures.workspace = true futures.workspace = true

View file

@ -15,9 +15,13 @@
//! and react to changes in settings. //! and react to changes in settings.
use std::path::Path; use std::path::Path;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use futures::{Future, FutureExt};
use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task}; use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
use log; use log;
use parking_lot::RwLock; use parking_lot::RwLock;
@ -56,51 +60,84 @@ impl Settings for ContextServerSettings {
} }
} }
pub struct ContextServer { #[async_trait(?Send)]
pub id: String, pub trait ContextServer: Send + Sync + 'static {
pub config: ServerConfig, fn id(&self) -> Arc<str>;
fn config(&self) -> Arc<ServerConfig>;
fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
fn start<'a>(
self: Arc<Self>,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
fn stop(&self) -> Result<()>;
}
pub struct NativeContextServer {
pub id: Arc<str>,
pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>, pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
} }
impl ContextServer { impl NativeContextServer {
fn new(config: ServerConfig) -> Self { fn new(config: Arc<ServerConfig>) -> Self {
Self { Self {
id: config.id.clone(), id: config.id.clone().into(),
config, config,
client: RwLock::new(None), client: RwLock::new(None),
} }
} }
}
async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> { #[async_trait(?Send)]
log::info!("starting context server {}", self.config.id,); impl ContextServer for NativeContextServer {
let client = Client::new( fn id(&self) -> Arc<str> {
client::ContextServerId(self.config.id.clone()), self.id.clone()
client::ModelContextServerBinary {
executable: Path::new(&self.config.executable).to_path_buf(),
args: self.config.args.clone(),
env: self.config.env.clone(),
},
cx.clone(),
)?;
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let initialized_protocol = protocol.initialize(client_info).await?;
log::debug!(
"context server {} initialized: {:?}",
self.config.id,
initialized_protocol.initialize,
);
*self.client.write() = Some(Arc::new(initialized_protocol));
Ok(())
} }
async fn stop(&self) -> anyhow::Result<()> { fn config(&self) -> Arc<ServerConfig> {
self.config.clone()
}
fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
self.client.read().clone()
}
fn start<'a>(
self: Arc<Self>,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
async move {
log::info!("starting context server {}", self.config.id,);
let client = Client::new(
client::ContextServerId(self.config.id.clone()),
client::ModelContextServerBinary {
executable: Path::new(&self.config.executable).to_path_buf(),
args: self.config.args.clone(),
env: self.config.env.clone(),
},
cx.clone(),
)?;
let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation {
name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let initialized_protocol = protocol.initialize(client_info).await?;
log::debug!(
"context server {} initialized: {:?}",
self.config.id,
initialized_protocol.initialize,
);
*self.client.write() = Some(Arc::new(initialized_protocol));
Ok(())
}
.boxed_local()
}
fn stop(&self) -> Result<()> {
let mut client = self.client.write(); let mut client = self.client.write();
if let Some(protocol) = client.take() { if let Some(protocol) = client.take() {
drop(protocol); drop(protocol);
@ -114,7 +151,7 @@ impl ContextServer {
/// must go through the `GlobalContextServerManager` which holds /// must go through the `GlobalContextServerManager` which holds
/// a model to the ContextServerManager. /// a model to the ContextServerManager.
pub struct ContextServerManager { pub struct ContextServerManager {
servers: HashMap<String, Arc<ContextServer>>, servers: HashMap<String, Arc<dyn ContextServer>>,
pending_servers: HashSet<String>, pending_servers: HashSet<String>,
} }
@ -141,7 +178,7 @@ impl ContextServerManager {
pub fn add_server( pub fn add_server(
&mut self, &mut self,
config: ServerConfig, config: Arc<ServerConfig>,
cx: &ModelContext<Self>, cx: &ModelContext<Self>,
) -> Task<anyhow::Result<()>> { ) -> Task<anyhow::Result<()>> {
let server_id = config.id.clone(); let server_id = config.id.clone();
@ -153,8 +190,8 @@ impl ContextServerManager {
let task = { let task = {
let server_id = server_id.clone(); let server_id = server_id.clone();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
let server = Arc::new(ContextServer::new(config)); let server = Arc::new(NativeContextServer::new(config));
server.start(&cx).await?; server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.servers.insert(server_id.clone(), server); this.servers.insert(server_id.clone(), server);
this.pending_servers.remove(&server_id); this.pending_servers.remove(&server_id);
@ -170,7 +207,7 @@ impl ContextServerManager {
task task
} }
pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> { pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
self.servers.get(id).cloned() self.servers.get(id).cloned()
} }
@ -178,7 +215,7 @@ impl ContextServerManager {
let id = id.to_string(); let id = id.to_string();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop().await?; server.stop()?;
} }
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.pending_servers.remove(&id); this.pending_servers.remove(&id);
@ -192,16 +229,16 @@ impl ContextServerManager {
pub fn restart_server( pub fn restart_server(
&mut self, &mut self,
id: &str, id: &Arc<str>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> { ) -> Task<anyhow::Result<()>> {
let id = id.to_string(); let id = id.to_string();
cx.spawn(|this, mut cx| async move { cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop().await?; server.stop()?;
let config = server.config.clone(); let config = server.config();
let new_server = Arc::new(ContextServer::new(config)); let new_server = Arc::new(NativeContextServer::new(config));
new_server.start(&cx).await?; new_server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.servers.insert(id.clone(), new_server); this.servers.insert(id.clone(), new_server);
cx.emit(Event::ServerStopped { cx.emit(Event::ServerStopped {
@ -216,7 +253,7 @@ impl ContextServerManager {
}) })
} }
pub fn servers(&self) -> Vec<Arc<ContextServer>> { pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
self.servers.values().cloned().collect() self.servers.values().cloned().collect()
} }
@ -224,7 +261,7 @@ impl ContextServerManager {
let current_servers = self let current_servers = self
.servers() .servers()
.into_iter() .into_iter()
.map(|server| (server.id.clone(), server.config.clone())) .map(|server| (server.id(), server.config()))
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
let new_servers = settings let new_servers = settings
@ -235,19 +272,19 @@ impl ContextServerManager {
let servers_to_add = new_servers let servers_to_add = new_servers
.values() .values()
.filter(|config| !current_servers.contains_key(&config.id)) .filter(|config| !current_servers.contains_key(config.id.as_str()))
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let servers_to_remove = current_servers let servers_to_remove = current_servers
.keys() .keys()
.filter(|id| !new_servers.contains_key(*id)) .filter(|id| !new_servers.contains_key(id.as_ref()))
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
log::trace!("servers_to_add={:?}", servers_to_add); log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add { for config in servers_to_add {
self.add_server(config, cx).detach_and_log_err(cx); self.add_server(Arc::new(config), cx).detach_and_log_err(cx);
} }
for id in servers_to_remove { for id in servers_to_remove {