Add support for context server extensions (#20250)

This PR adds support for context servers provided by extensions.

To provide a context server from an extension, you need to list the
context servers in your `extension.toml`:

```toml
[context_servers.my-context-server]
```

And then implement the `context_server_command` method to return the
command that will be used to start the context server:

```rs
use zed_extension_api::{self as zed, Command, ContextServerId, Result};

struct ExampleContextServerExtension;

impl zed::Extension for ExampleContextServerExtension {
    fn new() -> Self {
        ExampleContextServerExtension
    }

    fn context_server_command(&mut self, _context_server_id: &ContextServerId) -> Result<Command> {
        Ok(Command {
            command: "node".to_string(),
            args: vec!["/path/to/example-context-server/index.js".to_string()],
            env: Vec::new(),
        })
    }
}

zed::register_extension!(ExampleContextServerExtension);
```

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-11-08 16:39:21 -05:00 committed by GitHub
parent ff4f67993b
commit f92e6e9a95
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 340 additions and 22 deletions

2
Cargo.lock generated
View file

@ -4212,6 +4212,7 @@ dependencies = [
"async-trait",
"client",
"collections",
"context_servers",
"ctor",
"db",
"editor",
@ -15353,6 +15354,7 @@ dependencies = [
"collections",
"command_palette",
"command_palette_hooks",
"context_servers",
"copilot",
"db",
"diagnostics",

View file

@ -10,7 +10,7 @@ use clock::ReplicaId;
use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter;
use context_servers::manager::{ContextServerManager, ContextServerSettings};
use context_servers::CONTEXT_SERVERS_NAMESPACE;
use context_servers::{ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE};
use fs::Fs;
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
@ -51,8 +51,8 @@ pub struct ContextStore {
contexts: Vec<ContextHandle>,
contexts_metadata: Vec<SavedContextMetadata>,
context_server_manager: Model<ContextServerManager>,
context_server_slash_command_ids: HashMap<String, Vec<SlashCommandId>>,
context_server_tool_ids: HashMap<String, Vec<ToolId>>,
context_server_slash_command_ids: HashMap<Arc<str>, Vec<SlashCommandId>>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
host_contexts: Vec<RemoteContextMetadata>,
fs: Arc<dyn Fs>,
languages: Arc<LanguageRegistry>,
@ -148,6 +148,47 @@ impl ContextStore {
this.handle_project_changed(project, cx);
this.synchronize_contexts(cx);
this.register_context_server_handlers(cx);
// TODO: At the time when we construct the `ContextStore` we may not have yet initialized the extensions.
// In order to register the context servers when the extension is loaded, we're periodically looping to
// see if there are context servers to register.
//
// I tried doing this in a subscription on the `ExtensionStore`, but it never seemed to fire.
//
// We should find a more elegant way to do this.
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
cx.spawn(|context_store, mut cx| async move {
loop {
let mut servers_to_register = Vec::new();
for (_id, factory) in
context_server_factory_registry.context_server_factories()
{
if let Some(server) = factory(&cx).await.log_err() {
servers_to_register.push(server);
}
}
let Some(_) = context_store
.update(&mut cx, |this, cx| {
this.context_server_manager.update(cx, |this, cx| {
for server in servers_to_register {
this.add_server(server, cx).detach_and_log_err(cx);
}
})
})
.log_err()
else {
break;
};
smol::Timer::after(Duration::from_millis(100)).await;
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
this
})?;
this.update(&mut cx, |this, cx| this.reload(cx))?

View file

@ -1,13 +1,16 @@
pub mod client;
pub mod manager;
pub mod protocol;
mod registry;
pub mod types;
use command_palette_hooks::CommandPaletteFilter;
use gpui::{actions, AppContext};
use settings::Settings;
pub use crate::manager::ContextServer;
use crate::manager::ContextServerSettings;
pub mod client;
pub mod manager;
pub mod protocol;
pub mod types;
pub use crate::registry::ContextServerFactoryRegistry;
actions!(context_servers, [Restart]);
@ -16,6 +19,7 @@ pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut AppContext) {
ContextServerSettings::register(cx);
ContextServerFactoryRegistry::default_global(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);

View file

@ -79,7 +79,7 @@ pub struct NativeContextServer {
}
impl NativeContextServer {
fn new(config: Arc<ServerConfig>) -> Self {
pub fn new(config: Arc<ServerConfig>) -> Self {
Self {
id: config.id.clone().into(),
config,
@ -151,13 +151,13 @@ impl ContextServer for NativeContextServer {
/// must go through the `GlobalContextServerManager` which holds
/// a model to the ContextServerManager.
pub struct ContextServerManager {
servers: HashMap<String, Arc<dyn ContextServer>>,
pending_servers: HashSet<String>,
servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
pending_servers: HashSet<Arc<str>>,
}
pub enum Event {
ServerStarted { server_id: String },
ServerStopped { server_id: String },
ServerStarted { server_id: Arc<str> },
ServerStopped { server_id: Arc<str> },
}
impl EventEmitter<Event> for ContextServerManager {}
@ -178,10 +178,10 @@ impl ContextServerManager {
pub fn add_server(
&mut self,
config: Arc<ServerConfig>,
server: Arc<dyn ContextServer>,
cx: &ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let server_id = config.id.clone();
let server_id = server.id();
if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
return Task::ready(Ok(()));
@ -190,7 +190,6 @@ impl ContextServerManager {
let task = {
let server_id = server_id.clone();
cx.spawn(|this, mut cx| async move {
let server = Arc::new(NativeContextServer::new(config));
server.clone().start(&cx).await?;
this.update(&mut cx, |this, cx| {
this.servers.insert(server_id.clone(), server);
@ -211,14 +210,20 @@ impl ContextServerManager {
self.servers.get(id).cloned()
}
pub fn remove_server(&mut self, id: &str, cx: &ModelContext<Self>) -> Task<anyhow::Result<()>> {
let id = id.to_string();
pub fn remove_server(
&mut self,
id: &Arc<str>,
cx: &ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let id = id.clone();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
if let Some(server) =
this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
{
server.stop()?;
}
this.update(&mut cx, |this, cx| {
this.pending_servers.remove(&id);
this.pending_servers.remove(id.as_ref());
cx.emit(Event::ServerStopped {
server_id: id.clone(),
})
@ -232,7 +237,7 @@ impl ContextServerManager {
id: &Arc<str>,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let id = id.to_string();
let id = id.clone();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
server.stop()?;
@ -284,7 +289,8 @@ impl ContextServerManager {
log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
self.add_server(Arc::new(config), cx).detach_and_log_err(cx);
let server = Arc::new(NativeContextServer::new(Arc::new(config)));
self.add_server(server, cx).detach_and_log_err(cx);
}
for id in servers_to_remove {

View file

@ -0,0 +1,72 @@
use std::sync::Arc;
use anyhow::Result;
use collections::HashMap;
use gpui::{AppContext, AsyncAppContext, ReadGlobal};
use gpui::{Global, Task};
use parking_lot::RwLock;
use crate::ContextServer;
pub type ContextServerFactory =
Arc<dyn Fn(&AsyncAppContext) -> Task<Result<Arc<dyn ContextServer>>> + Send + Sync + 'static>;
#[derive(Default)]
struct GlobalContextServerFactoryRegistry(Arc<ContextServerFactoryRegistry>);
impl Global for GlobalContextServerFactoryRegistry {}
#[derive(Default)]
struct ContextServerFactoryRegistryState {
context_servers: HashMap<Arc<str>, ContextServerFactory>,
}
#[derive(Default)]
pub struct ContextServerFactoryRegistry {
state: RwLock<ContextServerFactoryRegistryState>,
}
impl ContextServerFactoryRegistry {
/// Returns the global [`ContextServerFactoryRegistry`].
pub fn global(cx: &AppContext) -> Arc<Self> {
GlobalContextServerFactoryRegistry::global(cx).0.clone()
}
/// Returns the global [`ContextServerFactoryRegistry`].
///
/// Inserts a default [`ContextServerFactoryRegistry`] if one does not yet exist.
pub fn default_global(cx: &mut AppContext) -> Arc<Self> {
cx.default_global::<GlobalContextServerFactoryRegistry>()
.0
.clone()
}
pub fn new() -> Arc<Self> {
Arc::new(Self {
state: RwLock::new(ContextServerFactoryRegistryState {
context_servers: HashMap::default(),
}),
})
}
pub fn context_server_factories(&self) -> Vec<(Arc<str>, ContextServerFactory)> {
self.state
.read()
.context_servers
.iter()
.map(|(id, factory)| (id.clone(), factory.clone()))
.collect()
}
/// Registers the provided [`ContextServerFactory`].
pub fn register_server_factory(&self, id: Arc<str>, factory: ContextServerFactory) {
let mut state = self.state.write();
state.context_servers.insert(id, factory);
}
/// Unregisters the [`ContextServerFactory`] for the server with the given ID.
pub fn unregister_server_factory_by_id(&self, server_id: &str) {
let mut state = self.state.write();
state.context_servers.remove(server_id);
}
}

View file

@ -75,6 +75,8 @@ pub struct ExtensionManifest {
#[serde(default)]
pub language_servers: BTreeMap<LanguageServerName, LanguageServerManifestEntry>,
#[serde(default)]
pub context_servers: BTreeMap<Arc<str>, ContextServerManifestEntry>,
#[serde(default)]
pub slash_commands: BTreeMap<Arc<str>, SlashCommandManifestEntry>,
#[serde(default)]
pub indexed_docs_providers: BTreeMap<Arc<str>, IndexedDocsProviderEntry>,
@ -134,6 +136,9 @@ impl LanguageServerManifestEntry {
}
}
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct ContextServerManifestEntry {}
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct SlashCommandManifestEntry {
pub description: String,
@ -205,6 +210,7 @@ fn manifest_from_old_manifest(
.map(|grammar_name| (grammar_name, Default::default()))
.collect(),
language_servers: Default::default(),
context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,

View file

@ -129,6 +129,11 @@ pub trait Extension: Send + Sync {
Err("`run_slash_command` not implemented".to_string())
}
/// Returns the command used to start a context server.
fn context_server_command(&mut self, _context_server_id: &ContextServerId) -> Result<Command> {
Err("`context_server_command` not implemented".to_string())
}
/// Returns a list of package names as suggestions to be included in the
/// search results of the `/docs` slash command.
///
@ -270,6 +275,11 @@ impl wit::Guest for Component {
extension().run_slash_command(command, args, worktree)
}
fn context_server_command(context_server_id: String) -> Result<wit::Command> {
let context_server_id = ContextServerId(context_server_id);
extension().context_server_command(&context_server_id)
}
fn suggest_docs_packages(provider: String) -> Result<Vec<String>, String> {
extension().suggest_docs_packages(provider)
}
@ -299,6 +309,22 @@ impl fmt::Display for LanguageServerId {
}
}
/// The ID of a context server.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub struct ContextServerId(String);
impl AsRef<str> for ContextServerId {
fn as_ref(&self) -> &str {
&self.0
}
}
impl fmt::Display for ContextServerId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl CodeLabelSpan {
/// Returns a [`CodeLabelSpan::CodeRange`].
pub fn code_range(range: impl Into<wit::Range>) -> Self {

View file

@ -135,6 +135,9 @@ world extension {
/// Returns the output from running the provided slash command.
export run-slash-command: func(command: slash-command, args: list<string>, worktree: option<borrow<worktree>>) -> result<slash-command-output, string>;
/// Returns the command used to start up a context server.
export context-server-command: func(context-server-id: string) -> result<command, string>;
/// Returns a list of packages as suggestions to be included in the `/docs`
/// search results.
///

View file

@ -145,6 +145,14 @@ pub trait ExtensionRegistrationHooks: Send + Sync + 'static {
) {
}
fn register_context_server(
&self,
_id: Arc<str>,
_extension: WasmExtension,
_host: Arc<WasmHost>,
) {
}
fn register_docs_provider(
&self,
_extension: WasmExtension,
@ -1267,6 +1275,14 @@ impl ExtensionStore {
);
}
for (id, _context_server_entry) in &manifest.context_servers {
this.registration_hooks.register_context_server(
id.clone(),
wasm_extension.clone(),
this.wasm_host.clone(),
);
}
for (provider_id, _provider) in &manifest.indexed_docs_providers {
this.registration_hooks.register_docs_provider(
wasm_extension.clone(),

View file

@ -384,6 +384,24 @@ impl Extension {
}
}
pub async fn call_context_server_command(
&self,
store: &mut Store<WasmState>,
context_server_id: Arc<str>,
) -> Result<Result<Command, String>> {
match self {
Extension::V020(ext) => {
ext.call_context_server_command(store, &context_server_id)
.await
}
Extension::V001(_) | Extension::V004(_) | Extension::V006(_) | Extension::V010(_) => {
Err(anyhow!(
"`context_server_command` not available prior to v0.2.0"
))
}
}
}
pub async fn call_suggest_docs_packages(
&self,
store: &mut Store<WasmState>,

View file

@ -20,6 +20,7 @@ assistant_slash_command.workspace = true
async-trait.workspace = true
client.workspace = true
collections.workspace = true
context_servers.workspace = true
db.workspace = true
editor.workspace = true
extension_host.workspace = true

View file

@ -0,0 +1,80 @@
use std::pin::Pin;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use context_servers::manager::{NativeContextServer, ServerConfig};
use context_servers::protocol::InitializedContextServerProtocol;
use context_servers::ContextServer;
use extension_host::wasm_host::{WasmExtension, WasmHost};
use futures::{Future, FutureExt};
use gpui::AsyncAppContext;
pub struct ExtensionContextServer {
#[allow(unused)]
pub(crate) extension: WasmExtension,
#[allow(unused)]
pub(crate) host: Arc<WasmHost>,
id: Arc<str>,
context_server: Arc<NativeContextServer>,
}
impl ExtensionContextServer {
pub async fn new(extension: WasmExtension, host: Arc<WasmHost>, id: Arc<str>) -> Result<Self> {
let command = extension
.call({
let id = id.clone();
|extension, store| {
async move {
let command = extension
.call_context_server_command(store, id.clone())
.await?
.map_err(|e| anyhow!("{}", e))?;
anyhow::Ok(command)
}
.boxed()
}
})
.await?;
let config = Arc::new(ServerConfig {
id: id.to_string(),
executable: command.command,
args: command.args,
env: Some(command.env.into_iter().collect()),
});
anyhow::Ok(Self {
extension,
host,
id,
context_server: Arc::new(NativeContextServer::new(config)),
})
}
}
#[async_trait(?Send)]
impl ContextServer for ExtensionContextServer {
fn id(&self) -> Arc<str> {
self.id.clone()
}
fn config(&self) -> Arc<ServerConfig> {
self.context_server.config()
}
fn client(&self) -> Option<Arc<InitializedContextServerProtocol>> {
self.context_server.client()
}
fn start<'a>(
self: Arc<Self>,
cx: &'a AsyncAppContext,
) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
self.context_server.clone().start(cx)
}
fn stop(&self) -> Result<()> {
self.context_server.stop()
}
}

View file

@ -2,6 +2,7 @@ use std::{path::PathBuf, sync::Arc};
use anyhow::Result;
use assistant_slash_command::SlashCommandRegistry;
use context_servers::ContextServerFactoryRegistry;
use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host};
use fs::Fs;
use gpui::{AppContext, BackgroundExecutor, Task};
@ -11,6 +12,7 @@ use snippet_provider::SnippetRegistry;
use theme::{ThemeRegistry, ThemeSettings};
use ui::SharedString;
use crate::extension_context_server::ExtensionContextServer;
use crate::{extension_indexed_docs_provider, extension_slash_command::ExtensionSlashCommand};
pub struct ConcreteExtensionRegistrationHooks {
@ -19,6 +21,7 @@ pub struct ConcreteExtensionRegistrationHooks {
indexed_docs_registry: Arc<IndexedDocsRegistry>,
snippet_registry: Arc<SnippetRegistry>,
language_registry: Arc<LanguageRegistry>,
context_server_factory_registry: Arc<ContextServerFactoryRegistry>,
executor: BackgroundExecutor,
}
@ -29,6 +32,7 @@ impl ConcreteExtensionRegistrationHooks {
indexed_docs_registry: Arc<IndexedDocsRegistry>,
snippet_registry: Arc<SnippetRegistry>,
language_registry: Arc<LanguageRegistry>,
context_server_factory_registry: Arc<ContextServerFactoryRegistry>,
cx: &AppContext,
) -> Arc<dyn extension_host::ExtensionRegistrationHooks> {
Arc::new(Self {
@ -37,6 +41,7 @@ impl ConcreteExtensionRegistrationHooks {
indexed_docs_registry,
snippet_registry,
language_registry,
context_server_factory_registry,
executor: cx.background_executor().clone(),
})
}
@ -69,6 +74,31 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
)
}
fn register_context_server(
&self,
id: Arc<str>,
extension: wasm_host::WasmExtension,
host: Arc<wasm_host::WasmHost>,
) {
self.context_server_factory_registry
.register_server_factory(
id.clone(),
Arc::new({
move |cx| {
let id = id.clone();
let extension = extension.clone();
let host = host.clone();
cx.spawn(|_cx| async move {
let context_server =
ExtensionContextServer::new(extension, host, id).await?;
anyhow::Ok(Arc::new(context_server) as _)
})
}
}),
);
}
fn register_docs_provider(
&self,
extension: wasm_host::WasmExtension,

View file

@ -1,6 +1,7 @@
use assistant_slash_command::SlashCommandRegistry;
use async_compression::futures::bufread::GzipEncoder;
use collections::BTreeMap;
use context_servers::ContextServerFactoryRegistry;
use extension_host::ExtensionSettings;
use extension_host::SchemaVersion;
use extension_host::{
@ -161,6 +162,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
.into_iter()
.collect(),
language_servers: BTreeMap::default(),
context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,
@ -187,6 +189,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
languages: Default::default(),
grammars: BTreeMap::default(),
language_servers: BTreeMap::default(),
context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,
@ -264,6 +267,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
let slash_command_registry = SlashCommandRegistry::new();
let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
let snippet_registry = Arc::new(SnippetRegistry::new());
let context_server_factory_registry = ContextServerFactoryRegistry::new();
let node_runtime = NodeRuntime::unavailable();
let store = cx.new_model(|cx| {
@ -273,6 +277,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
indexed_docs_registry.clone(),
snippet_registry.clone(),
language_registry.clone(),
context_server_factory_registry.clone(),
cx,
);
@ -356,6 +361,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
languages: Default::default(),
grammars: BTreeMap::default(),
language_servers: BTreeMap::default(),
context_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
indexed_docs_providers: BTreeMap::default(),
snippets: None,
@ -406,6 +412,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
indexed_docs_registry,
snippet_registry,
language_registry.clone(),
context_server_factory_registry.clone(),
cx,
);
@ -500,6 +507,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
let slash_command_registry = SlashCommandRegistry::new();
let indexed_docs_registry = Arc::new(IndexedDocsRegistry::new(cx.executor()));
let snippet_registry = Arc::new(SnippetRegistry::new());
let context_server_factory_registry = ContextServerFactoryRegistry::new();
let node_runtime = NodeRuntime::unavailable();
let mut status_updates = language_registry.language_server_binary_statuses();
@ -596,6 +604,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
indexed_docs_registry,
snippet_registry,
language_registry.clone(),
context_server_factory_registry.clone(),
cx,
);
ExtensionStore::new(

View file

@ -1,4 +1,5 @@
mod components;
mod extension_context_server;
mod extension_indexed_docs_provider;
mod extension_registration_hooks;
mod extension_slash_command;

View file

@ -35,6 +35,7 @@ collab_ui.workspace = true
collections.workspace = true
command_palette.workspace = true
command_palette_hooks.workspace = true
context_servers.workspace = true
copilot.workspace = true
db.workspace = true
diagnostics.workspace = true

View file

@ -13,6 +13,7 @@ use clap::{command, Parser};
use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
use client::{parse_zed_link, Client, ProxySettings, UserStore};
use collab_ui::channel_view::ChannelView;
use context_servers::ContextServerFactoryRegistry;
use db::kvp::{GLOBAL_KEY_VALUE_STORE, KEY_VALUE_STORE};
use editor::Editor;
use env_logger::Builder;
@ -411,6 +412,7 @@ fn main() {
IndexedDocsRegistry::global(cx),
SnippetRegistry::global(cx),
app_state.languages.clone(),
ContextServerFactoryRegistry::global(cx),
cx,
);
extension_host::init(