Improve context server lifecycle management (#20622)

This optimizes and fixes bugs in our logic for maintaining a set of
running context servers, based on the combination of the user's
`context_servers` settings and their installed extensions.

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
Max Brunsfeld 2024-11-13 13:55:06 -08:00 committed by GitHub
parent 6e477bbf56
commit d3d408d47d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 278 additions and 408 deletions

4
Cargo.lock generated
View file

@ -2572,6 +2572,7 @@ dependencies = [
"clock", "clock",
"collab_ui", "collab_ui",
"collections", "collections",
"context_servers",
"ctor", "ctor",
"dashmap 6.0.1", "dashmap 6.0.1",
"derive_more", "derive_more",
@ -2818,7 +2819,6 @@ name = "context_servers"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"collections", "collections",
"command_palette_hooks", "command_palette_hooks",
"futures 0.3.30", "futures 0.3.30",
@ -4205,7 +4205,6 @@ dependencies = [
"assistant_slash_command", "assistant_slash_command",
"async-compression", "async-compression",
"async-tar", "async-tar",
"async-trait",
"client", "client",
"collections", "collections",
"context_servers", "context_servers",
@ -4222,6 +4221,7 @@ dependencies = [
"http_client", "http_client",
"indexed_docs", "indexed_docs",
"language", "language",
"log",
"lsp", "lsp",
"node_runtime", "node_runtime",
"num-format", "num-format",

View file

@ -8,9 +8,8 @@ use anyhow::{anyhow, Context as _, Result};
use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
use clock::ReplicaId; use clock::ReplicaId;
use collections::HashMap; use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter; use context_servers::manager::ContextServerManager;
use context_servers::manager::{ContextServerManager, ContextServerSettings}; use context_servers::ContextServerFactoryRegistry;
use context_servers::{ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE};
use fs::Fs; use fs::Fs;
use futures::StreamExt; use futures::StreamExt;
use fuzzy::StringMatchCandidate; use fuzzy::StringMatchCandidate;
@ -22,7 +21,6 @@ use paths::contexts_dir;
use project::Project; use project::Project;
use regex::Regex; use regex::Regex;
use rpc::AnyProtoClient; use rpc::AnyProtoClient;
use settings::{Settings as _, SettingsStore};
use std::{ use std::{
cmp::Reverse, cmp::Reverse,
ffi::OsStr, ffi::OsStr,
@ -111,7 +109,11 @@ impl ContextStore {
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
let this = cx.new_model(|cx: &mut ModelContext<Self>| { let this = cx.new_model(|cx: &mut ModelContext<Self>| {
let context_server_manager = cx.new_model(|_cx| ContextServerManager::new()); let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new_model(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let mut this = Self { let mut this = Self {
contexts: Vec::new(), contexts: Vec::new(),
contexts_metadata: Vec::new(), contexts_metadata: Vec::new(),
@ -148,91 +150,16 @@ impl ContextStore {
this.handle_project_changed(project.clone(), cx); this.handle_project_changed(project.clone(), cx);
this.synchronize_contexts(cx); this.synchronize_contexts(cx);
this.register_context_server_handlers(cx); this.register_context_server_handlers(cx);
if project.read(cx).is_local() {
// TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions.
// In order to register the context servers when the extension is loaded, we're periodically looping to
// see if there are context servers to register.
//
// I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
//
// We should find a more elegant way to do this.
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
cx.spawn(|context_store, mut cx| async move {
loop {
let mut servers_to_register = Vec::new();
for (_id, factory) in
context_server_factory_registry.context_server_factories()
{
if let Some(server) = factory(project.clone(), &cx).await.log_err()
{
servers_to_register.push(server);
}
}
let Some(_) = context_store
.update(&mut cx, |this, cx| {
this.context_server_manager.update(cx, |this, cx| {
for server in servers_to_register {
this.add_server(server, cx).detach_and_log_err(cx);
}
})
})
.log_err()
else {
break;
};
smol::Timer::after(Duration::from_millis(100)).await;
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
this this
})?; })?;
this.update(&mut cx, |this, cx| this.reload(cx))? this.update(&mut cx, |this, cx| this.reload(cx))?
.await .await
.log_err(); .log_err();
this.update(&mut cx, |this, cx| {
this.watch_context_server_settings(cx);
})
.log_err();
Ok(this) Ok(this)
}) })
} }
fn watch_context_server_settings(&self, cx: &mut ModelContext<Self>) {
cx.observe_global::<SettingsStore>(move |this, cx| {
this.context_server_manager.update(cx, |manager, cx| {
let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
settings::SettingsLocation {
worktree_id: worktree.read(cx).id(),
path: Path::new(""),
}
});
let settings = ContextServerSettings::get(location, cx);
manager.maintain_servers(settings, cx);
let has_any_context_servers = !manager.servers().is_empty();
CommandPaletteFilter::update_global(cx, |filter, _cx| {
if has_any_context_servers {
filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
} else {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
}
});
})
})
.detach();
}
async fn handle_advertise_contexts( async fn handle_advertise_contexts(
this: Model<Self>, this: Model<Self>,
envelope: TypedEnvelope<proto::AdvertiseContexts>, envelope: TypedEnvelope<proto::AdvertiseContexts>,

View file

@ -27,7 +27,7 @@ pub struct ContextServerSlashCommand {
impl ContextServerSlashCommand { impl ContextServerSlashCommand {
pub fn new( pub fn new(
server_manager: Model<ContextServerManager>, server_manager: Model<ContextServerManager>,
server: &Arc<dyn ContextServer>, server: &Arc<ContextServer>,
prompt: Prompt, prompt: Prompt,
) -> Self { ) -> Self {
Self { Self {

View file

@ -78,6 +78,7 @@ uuid.workspace = true
[dev-dependencies] [dev-dependencies]
assistant = { workspace = true, features = ["test-support"] } assistant = { workspace = true, features = ["test-support"] }
context_servers.workspace = true
async-trait.workspace = true async-trait.workspace = true
audio.workspace = true audio.workspace = true
call = { workspace = true, features = ["test-support"] } call = { workspace = true, features = ["test-support"] }

View file

@ -6486,6 +6486,8 @@ async fn test_context_collaboration_with_reconnect(
assert_eq!(project.collaborators().len(), 1); assert_eq!(project.collaborators().len(), 1);
}); });
cx_a.update(context_servers::init);
cx_b.update(context_servers::init);
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context_store_a = cx_a let context_store_a = cx_a
.update(|cx| { .update(|cx| {

View file

@ -39,11 +39,13 @@ impl CommandPaletteFilter {
} }
/// Updates the global [`CommandPaletteFilter`] using the given closure. /// Updates the global [`CommandPaletteFilter`] using the given closure.
pub fn update_global<F, R>(cx: &mut AppContext, update: F) -> R pub fn update_global<F>(cx: &mut AppContext, update: F)
where where
F: FnOnce(&mut Self, &mut AppContext) -> R, F: FnOnce(&mut Self, &mut AppContext),
{ {
cx.update_global(|this: &mut GlobalCommandPaletteFilter, cx| update(&mut this.0, cx)) if cx.has_global::<GlobalCommandPaletteFilter>() {
cx.update_global(|this: &mut GlobalCommandPaletteFilter, cx| update(&mut this.0, cx))
}
} }
/// Returns whether the given [`Action`] is hidden by the filter. /// Returns whether the given [`Action`] is hidden by the filter.

View file

@ -13,7 +13,6 @@ path = "src/context_servers.rs"
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
async-trait.workspace = true
collections.workspace = true collections.workspace = true
command_palette_hooks.workspace = true command_palette_hooks.workspace = true
futures.workspace = true futures.workspace = true

View file

@ -8,7 +8,6 @@ use command_palette_hooks::CommandPaletteFilter;
use gpui::{actions, AppContext}; use gpui::{actions, AppContext};
use settings::Settings; use settings::Settings;
pub use crate::manager::ContextServer;
use crate::manager::ContextServerSettings; use crate::manager::ContextServerSettings;
pub use crate::registry::ContextServerFactoryRegistry; pub use crate::registry::ContextServerFactoryRegistry;

View file

@ -15,23 +15,23 @@
//! and react to changes in settings. //! and react to changes in settings.
use std::path::Path; use std::path::Path;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use async_trait::async_trait; use collections::HashMap;
use collections::{HashMap, HashSet}; use command_palette_hooks::CommandPaletteFilter;
use futures::{Future, FutureExt}; use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
use log; use log;
use parking_lot::RwLock; use parking_lot::RwLock;
use project::Project;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources}; use settings::{Settings, SettingsSources, SettingsStore};
use util::ResultExt as _;
use crate::{ use crate::{
client::{self, Client}, client::{self, Client},
types, types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
}; };
#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)] #[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
@ -66,25 +66,13 @@ impl Settings for ContextServerSettings {
} }
} }
#[async_trait(?Send)] pub struct ContextServer {
pub trait ContextServer: Send + Sync + 'static {
fn id(&self) -> Arc<str>;
fn config(&self) -> Arc<ServerConfig>;
fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
fn start<'a>(
self: Arc<Self>,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
fn stop(&self) -> Result<()>;
}
pub struct NativeContextServer {
pub id: Arc<str>, pub id: Arc<str>,
pub config: Arc<ServerConfig>, pub config: Arc<ServerConfig>,
pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>, pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
} }
impl NativeContextServer { impl ContextServer {
pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self { pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
Self { Self {
id, id,
@ -92,61 +80,52 @@ impl NativeContextServer {
client: RwLock::new(None), client: RwLock::new(None),
} }
} }
}
#[async_trait(?Send)] pub fn id(&self) -> Arc<str> {
impl ContextServer for NativeContextServer {
fn id(&self) -> Arc<str> {
self.id.clone() self.id.clone()
} }
fn config(&self) -> Arc<ServerConfig> { pub fn config(&self) -> Arc<ServerConfig> {
self.config.clone() self.config.clone()
} }
fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> { pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
self.client.read().clone() self.client.read().clone()
} }
fn start<'a>( pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
self: Arc<Self>, log::info!("starting context server {}", self.id);
cx: &'a AsyncAppContext, let Some(command) = &self.config.command else {
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> { bail!("no command specified for server {}", self.id);
async move { };
log::info!("starting context server {}", self.id); let client = Client::new(
let Some(command) = &self.config.command else { client::ContextServerId(self.id.clone()),
bail!("no command specified for server {}", self.id); client::ModelContextServerBinary {
}; executable: Path::new(&command.path).to_path_buf(),
let client = Client::new( args: command.args.clone(),
client::ContextServerId(self.id.clone()), env: command.env.clone(),
client::ModelContextServerBinary { },
executable: Path::new(&command.path).to_path_buf(), cx.clone(),
args: command.args.clone(), )?;
env: command.env.clone(),
},
cx.clone(),
)?;
let protocol = crate::protocol::ModelContextProtocol::new(client); let protocol = crate::protocol::ModelContextProtocol::new(client);
let client_info = types::Implementation { let client_info = types::Implementation {
name: "Zed".to_string(), name: "Zed".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(), version: env!("CARGO_PKG_VERSION").to_string(),
}; };
let initialized_protocol = protocol.initialize(client_info).await?; let initialized_protocol = protocol.initialize(client_info).await?;
log::debug!( log::debug!(
"context server {} initialized: {:?}", "context server {} initialized: {:?}",
self.id, self.id,
initialized_protocol.initialize, initialized_protocol.initialize,
); );
*self.client.write() = Some(Arc::new(initialized_protocol)); *self.client.write() = Some(Arc::new(initialized_protocol));
Ok(()) Ok(())
}
.boxed_local()
} }
fn stop(&self) -> Result<()> { pub fn stop(&self) -> Result<()> {
let mut client = self.client.write(); let mut client = self.client.write();
if let Some(protocol) = client.take() { if let Some(protocol) = client.take() {
drop(protocol); drop(protocol);
@ -155,13 +134,13 @@ impl ContextServer for NativeContextServer {
} }
} }
/// A Context server manager manages the starting and stopping
/// of all servers. To obtain a server to interact with, a crate
/// must go through the `GlobalContextServerManager` which holds
/// a model to the ContextServerManager.
pub struct ContextServerManager { pub struct ContextServerManager {
servers: HashMap<Arc<str>, Arc<dyn ContextServer>>, servers: HashMap<Arc<str>, Arc<ContextServer>>,
pending_servers: HashSet<Arc<str>>, project: Model<Project>,
registry: Model<ContextServerFactoryRegistry>,
update_servers_task: Option<Task<Result<()>>>,
needs_server_update: bool,
_subscriptions: Vec<Subscription>,
} }
pub enum Event { pub enum Event {
@ -171,74 +150,66 @@ pub enum Event {
impl EventEmitter<Event> for ContextServerManager {} impl EventEmitter<Event> for ContextServerManager {}
impl Default for ContextServerManager {
fn default() -> Self {
Self::new()
}
}
impl ContextServerManager { impl ContextServerManager {
pub fn new() -> Self { pub fn new(
Self { registry: Model<ContextServerFactoryRegistry>,
project: Model<Project>,
cx: &mut ModelContext<Self>,
) -> Self {
let mut this = Self {
_subscriptions: vec![
cx.observe(&registry, |this, _registry, cx| {
this.available_context_servers_changed(cx);
}),
cx.observe_global::<SettingsStore>(|this, cx| {
this.available_context_servers_changed(cx);
}),
],
project,
registry,
needs_server_update: false,
servers: HashMap::default(), servers: HashMap::default(),
pending_servers: HashSet::default(), update_servers_task: None,
}
}
pub fn add_server(
&mut self,
server: Arc<dyn ContextServer>,
cx: &ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let server_id = server.id();
if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
return Task::ready(Ok(()));
}
let task = {
let server_id = server_id.clone();
cx.spawn(|this, mut cx| async move {
server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(server_id.clone(), server);
this.pending_servers.remove(&server_id);
cx.emit(Event::ServerStarted {
server_id: server_id.clone(),
});
})?;
Ok(())
})
}; };
this.available_context_servers_changed(cx);
self.pending_servers.insert(server_id); this
task
} }
pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> { fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
self.servers.get(id).cloned() if self.update_servers_task.is_some() {
self.needs_server_update = true;
} else {
self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
this.update(&mut cx, |this, _| {
this.needs_server_update = false;
})?;
Self::maintain_servers(this.clone(), cx.clone()).await?;
this.update(&mut cx, |this, cx| {
let has_any_context_servers = !this.servers().is_empty();
if has_any_context_servers {
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
});
}
this.update_servers_task.take();
if this.needs_server_update {
this.available_context_servers_changed(cx);
}
})?;
Ok(())
}));
}
} }
pub fn remove_server( pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
&mut self, self.servers
id: &Arc<str>, .get(id)
cx: &ModelContext<Self>, .filter(|server| server.client().is_some())
) -> Task<anyhow::Result<()>> { .cloned()
let id = id.clone();
cx.spawn(|this, mut cx| async move {
if let Some(server) =
this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
{
server.stop()?;
}
this.update(&mut cx, |this, cx| {
this.pending_servers.remove(id.as_ref());
cx.emit(Event::ServerStopped {
server_id: id.clone(),
})
})?;
Ok(())
})
} }
pub fn restart_server( pub fn restart_server(
@ -251,7 +222,7 @@ impl ContextServerManager {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?; server.stop()?;
let config = server.config(); let config = server.config();
let new_server = Arc::new(NativeContextServer::new(id.clone(), config)); let new_server = Arc::new(ContextServer::new(id.clone(), config));
new_server.clone().start(&cx).await?; new_server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| { this.update(&mut cx, |this, cx| {
this.servers.insert(id.clone(), new_server); this.servers.insert(id.clone(), new_server);
@ -267,45 +238,83 @@ impl ContextServerManager {
}) })
} }
pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> { pub fn servers(&self) -> Vec<Arc<ContextServer>> {
self.servers.values().cloned().collect() self.servers
.values()
.filter(|server| server.client().is_some())
.cloned()
.collect()
} }
pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) { async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
let current_servers = self let mut desired_servers = HashMap::default();
.servers()
.into_iter()
.map(|server| (server.id(), server.config()))
.collect::<HashMap<_, _>>();
let new_servers = settings let (registry, project) = this.update(&mut cx, |this, cx| {
.context_servers let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
.iter() settings::SettingsLocation {
.map(|(id, config)| (id.clone(), config.clone())) worktree_id: worktree.read(cx).id(),
.collect::<HashMap<_, _>>(); path: Path::new(""),
}
});
let settings = ContextServerSettings::get(location, cx);
desired_servers = settings.context_servers.clone();
let servers_to_add = new_servers (this.registry.clone(), this.project.clone())
.iter() })?;
.filter(|(id, _)| !current_servers.contains_key(id.as_ref()))
.map(|(id, config)| (id.clone(), config.clone()))
.collect::<Vec<_>>();
let servers_to_remove = current_servers for (id, factory) in
.keys() registry.read_with(&cx, |registry, _| registry.context_server_factories())?
.filter(|id| !new_servers.contains_key(id.as_ref())) {
.cloned() let config = desired_servers.entry(id).or_default();
.collect::<Vec<_>>(); if config.command.is_none() {
if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
log::trace!("servers_to_add={:?}", servers_to_add); config.command = Some(extension_command);
for (id, config) in servers_to_add { }
if config.command.is_some() {
let server = Arc::new(NativeContextServer::new(id, Arc::new(config)));
self.add_server(server, cx).detach_and_log_err(cx);
} }
} }
for id in servers_to_remove { let mut servers_to_start = HashMap::default();
self.remove_server(&id, cx).detach_and_log_err(cx); let mut servers_to_stop = HashMap::default();
this.update(&mut cx, |this, _cx| {
this.servers.retain(|id, server| {
if desired_servers.contains_key(id) {
true
} else {
servers_to_stop.insert(id.clone(), server.clone());
false
}
});
for (id, config) in desired_servers {
let existing_config = this.servers.get(&id).map(|server| server.config());
if existing_config.as_deref() != Some(&config) {
let config = Arc::new(config);
let server = Arc::new(ContextServer::new(id.clone(), config));
servers_to_start.insert(id.clone(), server.clone());
let old_server = this.servers.insert(id.clone(), server);
if let Some(old_server) = old_server {
servers_to_stop.insert(id, old_server);
}
}
}
})?;
for (id, server) in servers_to_stop {
server.stop().log_err();
this.update(&mut cx, |_, cx| {
cx.emit(Event::ServerStopped { server_id: id })
})?;
} }
for (id, server) in servers_to_start {
if server.start(&cx).await.log_err().is_some() {
this.update(&mut cx, |_, cx| {
cx.emit(Event::ServerStarted { server_id: id })
})?;
}
}
Ok(())
} }
} }

View file

@ -2,75 +2,61 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use collections::HashMap; use collections::HashMap;
use gpui::{AppContext, AsyncAppContext, Global, Model, ReadGlobal, Task}; use gpui::{AppContext, AsyncAppContext, Context, Global, Model, ReadGlobal, Task};
use parking_lot::RwLock;
use project::Project; use project::Project;
use crate::ContextServer; use crate::manager::ServerCommand;
pub type ContextServerFactory = Arc< pub type ContextServerFactory = Arc<
dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<Arc<dyn ContextServer>>> dyn Fn(Model<Project>, &AsyncAppContext) -> Task<Result<ServerCommand>> + Send + Sync + 'static,
+ Send
+ Sync
+ 'static,
>; >;
#[derive(Default)] struct GlobalContextServerFactoryRegistry(Model<ContextServerFactoryRegistry>);
struct GlobalContextServerFactoryRegistry(Arc<ContextServerFactoryRegistry>);
impl Global for GlobalContextServerFactoryRegistry {} impl Global for GlobalContextServerFactoryRegistry {}
#[derive(Default)]
struct ContextServerFactoryRegistryState {
context_servers: HashMap<Arc<str>, ContextServerFactory>,
}
#[derive(Default)] #[derive(Default)]
pub struct ContextServerFactoryRegistry { pub struct ContextServerFactoryRegistry {
state: RwLock<ContextServerFactoryRegistryState>, context_servers: HashMap<Arc<str>, ContextServerFactory>,
} }
impl ContextServerFactoryRegistry { impl ContextServerFactoryRegistry {
/// Returns the global [`ContextServerFactoryRegistry`]. /// Returns the global [`ContextServerFactoryRegistry`].
pub fn global(cx: &AppContext) -> Arc<Self> { pub fn global(cx: &AppContext) -> Model<Self> {
GlobalContextServerFactoryRegistry::global(cx).0.clone() GlobalContextServerFactoryRegistry::global(cx).0.clone()
} }
/// Returns the global [`ContextServerFactoryRegistry`]. /// Returns the global [`ContextServerFactoryRegistry`].
/// ///
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist. /// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut AppContext) -> Arc<Self> { pub fn default_global(cx: &mut AppContext) -> Model<Self> {
cx.default_global::<GlobalContextServerFactoryRegistry>() if !cx.has_global::<GlobalContextServerFactoryRegistry>() {
.0 let registry = cx.new_model(|_| Self::new());
.clone() cx.set_global(GlobalContextServerFactoryRegistry(registry));
}
cx.global::<GlobalContextServerFactoryRegistry>().0.clone()
} }
pub fn new() -> Arc<Self> { pub fn new() -> Self {
Arc::new(Self { Self {
state: RwLock::new(ContextServerFactoryRegistryState { context_servers: HashMap::default(),
context_servers: HashMap::default(), }
}),
})
} }
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> { pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
self.state self.context_servers
.read()
.context_servers
.iter() .iter()
.map(|(id, factory)| (id.clone(), factory.clone())) .map(|(id, factory)| (id.clone(), factory.clone()))
.collect() .collect()
} }
/// Registers the provided [`ContextServerFactory`]. /// Registers the provided [`ContextServerFactory`].
pub fn register_server_factory(&self, id: Arc<str>, factory: ContextServerFactory) { pub fn register_server_factory(&mut self, id: Arc<str>, factory: ContextServerFactory) {
let mut state = self.state.write(); self.context_servers.insert(id, factory);
state.context_servers.insert(id, factory);
} }
/// Unregisters the [`ContextServerFactory`] for the server with the given ID. /// Unregisters the [`ContextServerFactory`] for the server with the given ID.
pub fn unregister_server_factory_by_id(&self, server_id: &str) { pub fn unregister_server_factory_by_id(&mut self, server_id: &str) {
let mut state = self.state.write(); self.context_servers.remove(server_id);
state.context_servers.remove(server_id);
} }
} }

View file

@ -141,7 +141,7 @@ pub trait ExtensionRegistrationHooks: Send + Sync + 'static {
&self, &self,
_id: Arc<str>, _id: Arc<str>,
_extension: WasmExtension, _extension: WasmExtension,
_host: Arc<WasmHost>, _cx: &mut AppContext,
) { ) {
} }
@ -1266,7 +1266,7 @@ impl ExtensionStore {
this.registration_hooks.register_context_server( this.registration_hooks.register_context_server(
id.clone(), id.clone(),
wasm_extension.clone(), wasm_extension.clone(),
this.wasm_host.clone(), cx,
); );
} }

View file

@ -17,7 +17,6 @@ test-support = []
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
assistant_slash_command.workspace = true assistant_slash_command.workspace = true
async-trait.workspace = true
client.workspace = true client.workspace = true
collections.workspace = true collections.workspace = true
context_servers.workspace = true context_servers.workspace = true
@ -31,6 +30,7 @@ fuzzy.workspace = true
gpui.workspace = true gpui.workspace = true
indexed_docs.workspace = true indexed_docs.workspace = true
language.workspace = true language.workspace = true
log.workspace = true
lsp.workspace = true lsp.workspace = true
num-format.workspace = true num-format.workspace = true
picker.workspace = true picker.workspace = true

View file

@ -1,97 +0,0 @@
use std::pin::Pin;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use context_servers::manager::{NativeContextServer, ServerCommand, ServerConfig};
use context_servers::protocol::InitializedContextServerProtocol;
use context_servers::ContextServer;
use extension_host::wasm_host::{ExtensionProject, WasmExtension, WasmHost};
use futures::{Future, FutureExt};
use gpui::{AsyncAppContext, Model};
use project::Project;
use wasmtime_wasi::WasiView as _;
pub struct ExtensionContextServer {
#[allow(unused)]
pub(crate) extension: WasmExtension,
#[allow(unused)]
pub(crate) host: Arc<WasmHost>,
id: Arc<str>,
context_server: Arc<NativeContextServer>,
}
impl ExtensionContextServer {
pub async fn new(
extension: WasmExtension,
host: Arc<WasmHost>,
id: Arc<str>,
project: Model<Project>,
mut cx: AsyncAppContext,
) -> Result<Self> {
let extension_project = project.update(&mut cx, |project, cx| ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})?;
let command = extension
.call({
let id = id.clone();
|extension, store| {
async move {
let project = store.data_mut().table().push(extension_project)?;
let command = extension
.call_context_server_command(store, id.clone(), project)
.await?
.map_err(|e| anyhow!("{}", e))?;
anyhow::Ok(command)
}
.boxed()
}
})
.await?;
let config = Arc::new(ServerConfig {
settings: None,
command: Some(ServerCommand {
path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
}),
});
anyhow::Ok(Self {
extension,
host,
id: id.clone(),
context_server: Arc::new(NativeContextServer::new(id, config)),
})
}
}
#[async_trait(?Send)]
impl ContextServer for ExtensionContextServer {
fn id(&self) -> Arc<str> {
self.id.clone()
}
fn config(&self) -> Arc<ServerConfig> {
self.context_server.config()
}
fn client(&self) -> Option<Arc<InitializedContextServerProtocol>> {
self.context_server.client()
}
fn start<'a>(
self: Arc<Self>,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
self.context_server.clone().start(cx)
}
fn stop(&self) -> Result<()> {
self.context_server.stop()
}
}

View file

@ -1,19 +1,21 @@
use std::{path::PathBuf, sync::Arc}; use std::{path::PathBuf, sync::Arc};
use anyhow::Result; use anyhow::{anyhow, Result};
use assistant_slash_command::{ExtensionSlashCommand, SlashCommandRegistry}; use assistant_slash_command::{ExtensionSlashCommand, SlashCommandRegistry};
use context_servers::manager::ServerCommand;
use context_servers::ContextServerFactoryRegistry; use context_servers::ContextServerFactoryRegistry;
use db::smol::future::FutureExt as _;
use extension::Extension; use extension::Extension;
use extension_host::wasm_host::ExtensionProject;
use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host}; use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host};
use fs::Fs; use fs::Fs;
use gpui::{AppContext, BackgroundExecutor, Task}; use gpui::{AppContext, BackgroundExecutor, Model, Task};
use indexed_docs::{ExtensionIndexedDocsProvider, IndexedDocsRegistry, ProviderId}; use indexed_docs::{ExtensionIndexedDocsProvider, IndexedDocsRegistry, ProviderId};
use language::{LanguageRegistry, LanguageServerBinaryStatus, LoadedLanguage}; use language::{LanguageRegistry, LanguageServerBinaryStatus, LoadedLanguage};
use snippet_provider::SnippetRegistry; use snippet_provider::SnippetRegistry;
use theme::{ThemeRegistry, ThemeSettings}; use theme::{ThemeRegistry, ThemeSettings};
use ui::SharedString; use ui::SharedString;
use wasmtime_wasi::WasiView as _;
use crate::extension_context_server::ExtensionContextServer;
pub struct ConcreteExtensionRegistrationHooks { pub struct ConcreteExtensionRegistrationHooks {
slash_command_registry: Arc<SlashCommandRegistry>, slash_command_registry: Arc<SlashCommandRegistry>,
@ -21,7 +23,7 @@ pub struct ConcreteExtensionRegistrationHooks {
indexed_docs_registry: Arc<IndexedDocsRegistry>, indexed_docs_registry: Arc<IndexedDocsRegistry>,
snippet_registry: Arc<SnippetRegistry>, snippet_registry: Arc<SnippetRegistry>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
context_server_factory_registry: Arc<ContextServerFactoryRegistry>, context_server_factory_registry: Model<ContextServerFactoryRegistry>,
executor: BackgroundExecutor, executor: BackgroundExecutor,
} }
@ -32,7 +34,7 @@ impl ConcreteExtensionRegistrationHooks {
indexed_docs_registry: Arc<IndexedDocsRegistry>, indexed_docs_registry: Arc<IndexedDocsRegistry>,
snippet_registry: Arc<SnippetRegistry>, snippet_registry: Arc<SnippetRegistry>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
context_server_factory_registry: Arc<ContextServerFactoryRegistry>, context_server_factory_registry: Model<ContextServerFactoryRegistry>,
cx: &AppContext, cx: &AppContext,
) -> Arc<dyn extension_host::ExtensionRegistrationHooks> { ) -> Arc<dyn extension_host::ExtensionRegistrationHooks> {
Arc::new(Self { Arc::new(Self {
@ -71,25 +73,66 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
&self, &self,
id: Arc<str>, id: Arc<str>,
extension: wasm_host::WasmExtension, extension: wasm_host::WasmExtension,
host: Arc<wasm_host::WasmHost>, cx: &mut AppContext,
) { ) {
self.context_server_factory_registry self.context_server_factory_registry
.register_server_factory( .update(cx, |registry, _| {
id.clone(), registry.register_server_factory(
Arc::new({ id.clone(),
move |project, cx| { Arc::new({
let id = id.clone(); move |project, cx| {
let extension = extension.clone(); log::info!(
let host = host.clone(); "loading command for context server {id} from extension {}",
cx.spawn(|cx| async move { extension.manifest.id
let context_server = );
ExtensionContextServer::new(extension, host, id, project, cx)
let id = id.clone();
let extension = extension.clone();
cx.spawn(|mut cx| async move {
let extension_project =
project.update(&mut cx, |project, cx| ExtensionProject {
worktree_ids: project
.visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto())
.collect(),
})?;
let command = extension
.call({
let id = id.clone();
|extension, store| {
async move {
let project = store
.data_mut()
.table()
.push(extension_project)?;
let command = extension
.call_context_server_command(
store,
id.clone(),
project,
)
.await?
.map_err(|e| anyhow!("{}", e))?;
anyhow::Ok(command)
}
.boxed()
}
})
.await?; .await?;
anyhow::Ok(Arc::new(context_server) as _)
}) log::info!("loaded command for context server {id}: {command:?}");
}
}), Ok(ServerCommand {
); path: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
})
})
}
}),
)
});
} }
fn register_docs_provider(&self, extension: Arc<dyn Extension>, provider_id: Arc<str>) { fn register_docs_provider(&self, extension: Arc<dyn Extension>, provider_id: Arc<str>) {

View file

@ -268,7 +268,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
let slash_command_registry = SlashCommandRegistry::new(); let slash_command_registry = SlashCommandRegistry::new();
let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor())); let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
let snippet_registry = Arc::new(SnippetRegistry::new()); let snippet_registry = Arc::new(SnippetRegistry::new());
let context_server_factory_registry = ContextServerFactoryRegistry::new(); let context_server_factory_registry = cx.new_model(|_| ContextServerFactoryRegistry::new());
let node_runtime = NodeRuntime::unavailable(); let node_runtime = NodeRuntime::unavailable();
let store = cx.new_model(|cx| { let store = cx.new_model(|cx| {
@ -508,7 +508,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
let slash_command_registry = SlashCommandRegistry::new(); let slash_command_registry = SlashCommandRegistry::new();
let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor())); let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
let snippet_registry = Arc::new(SnippetRegistry::new()); let snippet_registry = Arc::new(SnippetRegistry::new());
let context_server_factory_registry = ContextServerFactoryRegistry::new(); let context_server_factory_registry = cx.new_model(|_| ContextServerFactoryRegistry::new());
let node_runtime = NodeRuntime::unavailable(); let node_runtime = NodeRuntime::unavailable();
let mut status_updates = language_registry.language_server_binary_statuses(); let mut status_updates = language_registry.language_server_binary_statuses();

View file

@ -1,5 +1,4 @@
mod components; mod components;
mod extension_context_server;
mod extension_registration_hooks; mod extension_registration_hooks;
mod extension_suggest; mod extension_suggest;
mod extension_version_selector; mod extension_version_selector;