Put context servers behind a trait (#20432)
This PR puts context servers behind the `ContextServer` trait to allow us to provide context servers from an extension. Release Notes: - N/A
This commit is contained in:
parent
01503511ad
commit
09c599385a
6 changed files with 102 additions and 61 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2815,6 +2815,7 @@ name = "context_servers"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-trait",
|
||||||
"collections",
|
"collections",
|
||||||
"command_palette_hooks",
|
"command_palette_hooks",
|
||||||
"futures 0.3.30",
|
"futures 0.3.30",
|
||||||
|
|
|
@ -819,7 +819,7 @@ impl ContextStore {
|
||||||
|context_server_manager, cx| {
|
|context_server_manager, cx| {
|
||||||
for server in context_server_manager.servers() {
|
for server in context_server_manager.servers() {
|
||||||
context_server_manager
|
context_server_manager
|
||||||
.restart_server(&server.id, cx)
|
.restart_server(&server.id(), cx)
|
||||||
.detach_and_log_err(cx);
|
.detach_and_log_err(cx);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -850,7 +850,7 @@ impl ContextStore {
|
||||||
let server = server.clone();
|
let server = server.clone();
|
||||||
let server_id = server_id.clone();
|
let server_id = server_id.clone();
|
||||||
|this, mut cx| async move {
|
|this, mut cx| async move {
|
||||||
let Some(protocol) = server.client.read().clone() else {
|
let Some(protocol) = server.client() else {
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -889,7 +889,7 @@ impl ContextStore {
|
||||||
tool_working_set.insert(
|
tool_working_set.insert(
|
||||||
Arc::new(tools::context_server_tool::ContextServerTool::new(
|
Arc::new(tools::context_server_tool::ContextServerTool::new(
|
||||||
context_server_manager.clone(),
|
context_server_manager.clone(),
|
||||||
server.id.clone(),
|
server.id(),
|
||||||
tool,
|
tool,
|
||||||
)),
|
)),
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,18 +20,18 @@ use crate::slash_command::create_label_for_command;
|
||||||
|
|
||||||
pub struct ContextServerSlashCommand {
|
pub struct ContextServerSlashCommand {
|
||||||
server_manager: Model<ContextServerManager>,
|
server_manager: Model<ContextServerManager>,
|
||||||
server_id: String,
|
server_id: Arc<str>,
|
||||||
prompt: Prompt,
|
prompt: Prompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ContextServerSlashCommand {
|
impl ContextServerSlashCommand {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
server_manager: Model<ContextServerManager>,
|
server_manager: Model<ContextServerManager>,
|
||||||
server: &Arc<ContextServer>,
|
server: &Arc<dyn ContextServer>,
|
||||||
prompt: Prompt,
|
prompt: Prompt,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
server_id: server.id.clone(),
|
server_id: server.id(),
|
||||||
prompt,
|
prompt,
|
||||||
server_manager,
|
server_manager,
|
||||||
}
|
}
|
||||||
|
@ -89,7 +89,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||||
|
|
||||||
if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
|
if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
let Some(protocol) = server.client.read().clone() else {
|
let Some(protocol) = server.client() else {
|
||||||
return Err(anyhow!("Context server not initialized"));
|
return Err(anyhow!("Context server not initialized"));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
||||||
let manager = self.server_manager.read(cx);
|
let manager = self.server_manager.read(cx);
|
||||||
if let Some(server) = manager.get_server(&server_id) {
|
if let Some(server) = manager.get_server(&server_id) {
|
||||||
cx.foreground_executor().spawn(async move {
|
cx.foreground_executor().spawn(async move {
|
||||||
let Some(protocol) = server.client.read().clone() else {
|
let Some(protocol) = server.client() else {
|
||||||
return Err(anyhow!("Context server not initialized"));
|
return Err(anyhow!("Context server not initialized"));
|
||||||
};
|
};
|
||||||
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
|
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{anyhow, bail};
|
use anyhow::{anyhow, bail};
|
||||||
use assistant_tool::Tool;
|
use assistant_tool::Tool;
|
||||||
use context_servers::manager::ContextServerManager;
|
use context_servers::manager::ContextServerManager;
|
||||||
|
@ -6,14 +8,14 @@ use gpui::{Model, Task};
|
||||||
|
|
||||||
pub struct ContextServerTool {
|
pub struct ContextServerTool {
|
||||||
server_manager: Model<ContextServerManager>,
|
server_manager: Model<ContextServerManager>,
|
||||||
server_id: String,
|
server_id: Arc<str>,
|
||||||
tool: types::Tool,
|
tool: types::Tool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ContextServerTool {
|
impl ContextServerTool {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
server_manager: Model<ContextServerManager>,
|
server_manager: Model<ContextServerManager>,
|
||||||
server_id: impl Into<String>,
|
server_id: impl Into<Arc<str>>,
|
||||||
tool: types::Tool,
|
tool: types::Tool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -55,7 +57,7 @@ impl Tool for ContextServerTool {
|
||||||
cx.foreground_executor().spawn({
|
cx.foreground_executor().spawn({
|
||||||
let tool_name = self.tool.name.clone();
|
let tool_name = self.tool.name.clone();
|
||||||
async move {
|
async move {
|
||||||
let Some(protocol) = server.client.read().clone() else {
|
let Some(protocol) = server.client() else {
|
||||||
bail!("Context server not initialized");
|
bail!("Context server not initialized");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ path = "src/context_servers.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
collections.workspace = true
|
collections.workspace = true
|
||||||
command_palette_hooks.workspace = true
|
command_palette_hooks.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
|
|
|
@ -15,9 +15,13 @@
|
||||||
//! and react to changes in settings.
|
//! and react to changes in settings.
|
||||||
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
|
use futures::{Future, FutureExt};
|
||||||
use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
|
use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
|
||||||
use log;
|
use log;
|
||||||
use parking_lot::RwLock;
|
use parking_lot::RwLock;
|
||||||
|
@ -56,51 +60,84 @@ impl Settings for ContextServerSettings {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ContextServer {
|
#[async_trait(?Send)]
|
||||||
pub id: String,
|
pub trait ContextServer: Send + Sync + 'static {
|
||||||
pub config: ServerConfig,
|
fn id(&self) -> Arc<str>;
|
||||||
|
fn config(&self) -> Arc<ServerConfig>;
|
||||||
|
fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
|
||||||
|
fn start<'a>(
|
||||||
|
self: Arc<Self>,
|
||||||
|
cx: &'a AsyncAppContext,
|
||||||
|
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
|
||||||
|
fn stop(&self) -> Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct NativeContextServer {
|
||||||
|
pub id: Arc<str>,
|
||||||
|
pub config: Arc<ServerConfig>,
|
||||||
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ContextServer {
|
impl NativeContextServer {
|
||||||
fn new(config: ServerConfig) -> Self {
|
fn new(config: Arc<ServerConfig>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: config.id.clone(),
|
id: config.id.clone().into(),
|
||||||
config,
|
config,
|
||||||
client: RwLock::new(None),
|
client: RwLock::new(None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
|
#[async_trait(?Send)]
|
||||||
log::info!("starting context server {}", self.config.id,);
|
impl ContextServer for NativeContextServer {
|
||||||
let client = Client::new(
|
fn id(&self) -> Arc<str> {
|
||||||
client::ContextServerId(self.config.id.clone()),
|
self.id.clone()
|
||||||
client::ModelContextServerBinary {
|
|
||||||
executable: Path::new(&self.config.executable).to_path_buf(),
|
|
||||||
args: self.config.args.clone(),
|
|
||||||
env: self.config.env.clone(),
|
|
||||||
},
|
|
||||||
cx.clone(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
|
||||||
let client_info = types::Implementation {
|
|
||||||
name: "Zed".to_string(),
|
|
||||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
|
||||||
};
|
|
||||||
let initialized_protocol = protocol.initialize(client_info).await?;
|
|
||||||
|
|
||||||
log::debug!(
|
|
||||||
"context server {} initialized: {:?}",
|
|
||||||
self.config.id,
|
|
||||||
initialized_protocol.initialize,
|
|
||||||
);
|
|
||||||
|
|
||||||
*self.client.write() = Some(Arc::new(initialized_protocol));
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stop(&self) -> anyhow::Result<()> {
|
fn config(&self) -> Arc<ServerConfig> {
|
||||||
|
self.config.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
|
||||||
|
self.client.read().clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start<'a>(
|
||||||
|
self: Arc<Self>,
|
||||||
|
cx: &'a AsyncAppContext,
|
||||||
|
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
|
||||||
|
async move {
|
||||||
|
log::info!("starting context server {}", self.config.id,);
|
||||||
|
let client = Client::new(
|
||||||
|
client::ContextServerId(self.config.id.clone()),
|
||||||
|
client::ModelContextServerBinary {
|
||||||
|
executable: Path::new(&self.config.executable).to_path_buf(),
|
||||||
|
args: self.config.args.clone(),
|
||||||
|
env: self.config.env.clone(),
|
||||||
|
},
|
||||||
|
cx.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let protocol = crate::protocol::ModelContextProtocol::new(client);
|
||||||
|
let client_info = types::Implementation {
|
||||||
|
name: "Zed".to_string(),
|
||||||
|
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||||
|
};
|
||||||
|
let initialized_protocol = protocol.initialize(client_info).await?;
|
||||||
|
|
||||||
|
log::debug!(
|
||||||
|
"context server {} initialized: {:?}",
|
||||||
|
self.config.id,
|
||||||
|
initialized_protocol.initialize,
|
||||||
|
);
|
||||||
|
|
||||||
|
*self.client.write() = Some(Arc::new(initialized_protocol));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
.boxed_local()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stop(&self) -> Result<()> {
|
||||||
let mut client = self.client.write();
|
let mut client = self.client.write();
|
||||||
if let Some(protocol) = client.take() {
|
if let Some(protocol) = client.take() {
|
||||||
drop(protocol);
|
drop(protocol);
|
||||||
|
@ -114,7 +151,7 @@ impl ContextServer {
|
||||||
/// must go through the `GlobalContextServerManager` which holds
|
/// must go through the `GlobalContextServerManager` which holds
|
||||||
/// a model to the ContextServerManager.
|
/// a model to the ContextServerManager.
|
||||||
pub struct ContextServerManager {
|
pub struct ContextServerManager {
|
||||||
servers: HashMap<String, Arc<ContextServer>>,
|
servers: HashMap<String, Arc<dyn ContextServer>>,
|
||||||
pending_servers: HashSet<String>,
|
pending_servers: HashSet<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,7 +178,7 @@ impl ContextServerManager {
|
||||||
|
|
||||||
pub fn add_server(
|
pub fn add_server(
|
||||||
&mut self,
|
&mut self,
|
||||||
config: ServerConfig,
|
config: Arc<ServerConfig>,
|
||||||
cx: &ModelContext<Self>,
|
cx: &ModelContext<Self>,
|
||||||
) -> Task<anyhow::Result<()>> {
|
) -> Task<anyhow::Result<()>> {
|
||||||
let server_id = config.id.clone();
|
let server_id = config.id.clone();
|
||||||
|
@ -153,8 +190,8 @@ impl ContextServerManager {
|
||||||
let task = {
|
let task = {
|
||||||
let server_id = server_id.clone();
|
let server_id = server_id.clone();
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
let server = Arc::new(ContextServer::new(config));
|
let server = Arc::new(NativeContextServer::new(config));
|
||||||
server.start(&cx).await?;
|
server.clone().start(&cx).await?;
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.servers.insert(server_id.clone(), server);
|
this.servers.insert(server_id.clone(), server);
|
||||||
this.pending_servers.remove(&server_id);
|
this.pending_servers.remove(&server_id);
|
||||||
|
@ -170,7 +207,7 @@ impl ContextServerManager {
|
||||||
task
|
task
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
|
pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
|
||||||
self.servers.get(id).cloned()
|
self.servers.get(id).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,7 +215,7 @@ impl ContextServerManager {
|
||||||
let id = id.to_string();
|
let id = id.to_string();
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
|
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
|
||||||
server.stop().await?;
|
server.stop()?;
|
||||||
}
|
}
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.pending_servers.remove(&id);
|
this.pending_servers.remove(&id);
|
||||||
|
@ -192,16 +229,16 @@ impl ContextServerManager {
|
||||||
|
|
||||||
pub fn restart_server(
|
pub fn restart_server(
|
||||||
&mut self,
|
&mut self,
|
||||||
id: &str,
|
id: &Arc<str>,
|
||||||
cx: &mut ModelContext<Self>,
|
cx: &mut ModelContext<Self>,
|
||||||
) -> Task<anyhow::Result<()>> {
|
) -> Task<anyhow::Result<()>> {
|
||||||
let id = id.to_string();
|
let id = id.to_string();
|
||||||
cx.spawn(|this, mut cx| async move {
|
cx.spawn(|this, mut cx| async move {
|
||||||
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
|
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
|
||||||
server.stop().await?;
|
server.stop()?;
|
||||||
let config = server.config.clone();
|
let config = server.config();
|
||||||
let new_server = Arc::new(ContextServer::new(config));
|
let new_server = Arc::new(NativeContextServer::new(config));
|
||||||
new_server.start(&cx).await?;
|
new_server.clone().start(&cx).await?;
|
||||||
this.update(&mut cx, |this, cx| {
|
this.update(&mut cx, |this, cx| {
|
||||||
this.servers.insert(id.clone(), new_server);
|
this.servers.insert(id.clone(), new_server);
|
||||||
cx.emit(Event::ServerStopped {
|
cx.emit(Event::ServerStopped {
|
||||||
|
@ -216,7 +253,7 @@ impl ContextServerManager {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn servers(&self) -> Vec<Arc<ContextServer>> {
|
pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
|
||||||
self.servers.values().cloned().collect()
|
self.servers.values().cloned().collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,7 +261,7 @@ impl ContextServerManager {
|
||||||
let current_servers = self
|
let current_servers = self
|
||||||
.servers()
|
.servers()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|server| (server.id.clone(), server.config.clone()))
|
.map(|server| (server.id(), server.config()))
|
||||||
.collect::<HashMap<_, _>>();
|
.collect::<HashMap<_, _>>();
|
||||||
|
|
||||||
let new_servers = settings
|
let new_servers = settings
|
||||||
|
@ -235,19 +272,19 @@ impl ContextServerManager {
|
||||||
|
|
||||||
let servers_to_add = new_servers
|
let servers_to_add = new_servers
|
||||||
.values()
|
.values()
|
||||||
.filter(|config| !current_servers.contains_key(&config.id))
|
.filter(|config| !current_servers.contains_key(config.id.as_str()))
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
let servers_to_remove = current_servers
|
let servers_to_remove = current_servers
|
||||||
.keys()
|
.keys()
|
||||||
.filter(|id| !new_servers.contains_key(*id))
|
.filter(|id| !new_servers.contains_key(id.as_ref()))
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
log::trace!("servers_to_add={:?}", servers_to_add);
|
log::trace!("servers_to_add={:?}", servers_to_add);
|
||||||
for config in servers_to_add {
|
for config in servers_to_add {
|
||||||
self.add_server(config, cx).detach_and_log_err(cx);
|
self.add_server(Arc::new(config), cx).detach_and_log_err(cx);
|
||||||
}
|
}
|
||||||
|
|
||||||
for id in servers_to_remove {
|
for id in servers_to_remove {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue