Use Extension trait when registering extension context servers (#21070)

This PR updates the extension context server registration to go through
the `Extension` trait for interacting with extensions rather than going
through the `WasmHost` directly.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-11-22 13:21:30 -05:00 committed by GitHub
parent ca76948044
commit cb8028c092
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 59 additions and 46 deletions

1
Cargo.lock generated
View file

@ -4237,7 +4237,6 @@ dependencies = [
"ui", "ui",
"util", "util",
"vim_mode_setting", "vim_mode_setting",
"wasmtime-wasi",
"workspace", "workspace",
"zed_actions", "zed_actions",
] ]

View file

@ -25,6 +25,10 @@ pub trait WorktreeDelegate: Send + Sync + 'static {
async fn shell_env(&self) -> Vec<(String, String)>; async fn shell_env(&self) -> Vec<(String, String)>;
} }
pub trait ProjectDelegate: Send + Sync + 'static {
fn worktree_ids(&self) -> Vec<u64>;
}
pub trait KeyValueStoreDelegate: Send + Sync + 'static { pub trait KeyValueStoreDelegate: Send + Sync + 'static {
fn insert(&self, key: String, docs: String) -> Task<Result<()>>; fn insert(&self, key: String, docs: String) -> Task<Result<()>>;
} }
@ -87,6 +91,12 @@ pub trait Extension: Send + Sync + 'static {
worktree: Option<Arc<dyn WorktreeDelegate>>, worktree: Option<Arc<dyn WorktreeDelegate>>,
) -> Result<SlashCommandOutput>; ) -> Result<SlashCommandOutput>;
async fn context_server_command(
&self,
context_server_id: Arc<str>,
project: Arc<dyn ProjectDelegate>,
) -> Result<Command>;
async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>>; async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>>;
async fn index_docs( async fn index_docs(

View file

@ -10,6 +10,7 @@ pub use slash_command::*;
pub type EnvVars = Vec<(String, String)>; pub type EnvVars = Vec<(String, String)>;
/// A command. /// A command.
#[derive(Debug)]
pub struct Command { pub struct Command {
/// The command to execute. /// The command to execute.
pub command: String, pub command: String,

View file

@ -149,8 +149,8 @@ pub trait ExtensionRegistrationHooks: Send + Sync + 'static {
fn register_context_server( fn register_context_server(
&self, &self,
_extension: Arc<dyn Extension>,
_id: Arc<str>, _id: Arc<str>,
_extension: WasmExtension,
_cx: &mut AppContext, _cx: &mut AppContext,
) { ) {
} }
@ -1284,8 +1284,8 @@ impl ExtensionStore {
for (id, _context_server_entry) in &manifest.context_servers { for (id, _context_server_entry) in &manifest.context_servers {
this.registration_hooks.register_context_server( this.registration_hooks.register_context_server(
extension.clone(),
id.clone(), id.clone(),
wasm_extension.clone(),
cx, cx,
); );
} }

View file

@ -4,7 +4,7 @@ use crate::{ExtensionManifest, ExtensionRegistrationHooks};
use anyhow::{anyhow, bail, Context as _, Result}; use anyhow::{anyhow, bail, Context as _, Result};
use async_trait::async_trait; use async_trait::async_trait;
use extension::{ use extension::{
CodeLabel, Command, Completion, KeyValueStoreDelegate, SlashCommand, CodeLabel, Command, Completion, KeyValueStoreDelegate, ProjectDelegate, SlashCommand,
SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate, SlashCommandArgumentCompletion, SlashCommandOutput, Symbol, WorktreeDelegate,
}; };
use fs::{normalize_path, Fs}; use fs::{normalize_path, Fs};
@ -34,7 +34,6 @@ use wasmtime::{
}; };
use wasmtime_wasi::{self as wasi, WasiView}; use wasmtime_wasi::{self as wasi, WasiView};
use wit::Extension; use wit::Extension;
pub use wit::ExtensionProject;
pub struct WasmHost { pub struct WasmHost {
engine: Engine, engine: Engine,
@ -238,6 +237,25 @@ impl extension::Extension for WasmExtension {
.await .await
} }
async fn context_server_command(
&self,
context_server_id: Arc<str>,
project: Arc<dyn ProjectDelegate>,
) -> Result<Command> {
self.call(|extension, store| {
async move {
let project_resource = store.data_mut().table().push(project)?;
let command = extension
.call_context_server_command(store, context_server_id.clone(), project_resource)
.await?
.map_err(|err| anyhow!("{err}"))?;
anyhow::Ok(command.into())
}
.boxed()
})
.await
}
async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> { async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> {
self.call(|extension, store| { self.call(|extension, store| {
async move { async move {

View file

@ -8,7 +8,7 @@ use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive; use async_tar::Archive;
use async_trait::async_trait; use async_trait::async_trait;
use context_servers::manager::ContextServerSettings; use context_servers::manager::ContextServerSettings;
use extension::{KeyValueStoreDelegate, WorktreeDelegate}; use extension::{KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate};
use futures::{io::BufReader, FutureExt as _}; use futures::{io::BufReader, FutureExt as _};
use futures::{lock::Mutex, AsyncReadExt}; use futures::{lock::Mutex, AsyncReadExt};
use language::{language_settings::AllLanguageSettings, LanguageName, LanguageServerBinaryStatus}; use language::{language_settings::AllLanguageSettings, LanguageName, LanguageServerBinaryStatus};
@ -44,13 +44,10 @@ mod settings {
} }
pub type ExtensionWorktree = Arc<dyn WorktreeDelegate>; pub type ExtensionWorktree = Arc<dyn WorktreeDelegate>;
pub type ExtensionProject = Arc<dyn ProjectDelegate>;
pub type ExtensionKeyValueStore = Arc<dyn KeyValueStoreDelegate>; pub type ExtensionKeyValueStore = Arc<dyn KeyValueStoreDelegate>;
pub type ExtensionHttpResponseStream = Arc<Mutex<::http_client::Response<AsyncBody>>>; pub type ExtensionHttpResponseStream = Arc<Mutex<::http_client::Response<AsyncBody>>>;
pub struct ExtensionProject {
pub worktree_ids: Vec<u64>,
}
pub fn linker() -> &'static Linker<WasmState> { pub fn linker() -> &'static Linker<WasmState> {
static LINKER: OnceLock<Linker<WasmState>> = OnceLock::new(); static LINKER: OnceLock<Linker<WasmState>> = OnceLock::new();
LINKER.get_or_init(|| super::new_linker(Extension::add_to_linker)) LINKER.get_or_init(|| super::new_linker(Extension::add_to_linker))
@ -273,7 +270,7 @@ impl HostProject for WasmState {
project: Resource<ExtensionProject>, project: Resource<ExtensionProject>,
) -> wasmtime::Result<Vec<u64>> { ) -> wasmtime::Result<Vec<u64>> {
let project = self.table.get(&project)?; let project = self.table.get(&project)?;
Ok(project.worktree_ids.clone()) Ok(project.worktree_ids())
} }
fn drop(&mut self, _project: Resource<Project>) -> Result<()> { fn drop(&mut self, _project: Resource<Project>) -> Result<()> {

View file

@ -41,7 +41,6 @@ theme.workspace = true
ui.workspace = true ui.workspace = true
util.workspace = true util.workspace = true
vim_mode_setting.workspace = true vim_mode_setting.workspace = true
wasmtime-wasi.workspace = true
workspace.workspace = true workspace.workspace = true
zed_actions.workspace = true zed_actions.workspace = true

View file

@ -1,13 +1,11 @@
use std::{path::PathBuf, sync::Arc}; use std::{path::PathBuf, sync::Arc};
use anyhow::{anyhow, Result}; use anyhow::Result;
use assistant_slash_command::{ExtensionSlashCommand, SlashCommandRegistry}; use assistant_slash_command::{ExtensionSlashCommand, SlashCommandRegistry};
use context_servers::manager::ServerCommand; use context_servers::manager::ServerCommand;
use context_servers::ContextServerFactoryRegistry; use context_servers::ContextServerFactoryRegistry;
use db::smol::future::FutureExt as _; use extension::{Extension, ProjectDelegate};
use extension::Extension; use extension_host::extension_lsp_adapter::ExtensionLspAdapter;
use extension_host::wasm_host::ExtensionProject;
use extension_host::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host};
use fs::Fs; use fs::Fs;
use gpui::{AppContext, BackgroundExecutor, Model, Task}; use gpui::{AppContext, BackgroundExecutor, Model, Task};
use indexed_docs::{ExtensionIndexedDocsProvider, IndexedDocsRegistry, ProviderId}; use indexed_docs::{ExtensionIndexedDocsProvider, IndexedDocsRegistry, ProviderId};
@ -16,7 +14,16 @@ use lsp::LanguageServerName;
use snippet_provider::SnippetRegistry; use snippet_provider::SnippetRegistry;
use theme::{ThemeRegistry, ThemeSettings}; use theme::{ThemeRegistry, ThemeSettings};
use ui::SharedString; use ui::SharedString;
use wasmtime_wasi::WasiView as _;
struct ExtensionProject {
worktree_ids: Vec<u64>,
}
impl ProjectDelegate for ExtensionProject {
fn worktree_ids(&self) -> Vec<u64> {
self.worktree_ids.clone()
}
}
pub struct ConcreteExtensionRegistrationHooks { pub struct ConcreteExtensionRegistrationHooks {
slash_command_registry: Arc<SlashCommandRegistry>, slash_command_registry: Arc<SlashCommandRegistry>,
@ -72,8 +79,8 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
fn register_context_server( fn register_context_server(
&self, &self,
extension: Arc<dyn Extension>,
id: Arc<str>, id: Arc<str>,
extension: wasm_host::WasmExtension,
cx: &mut AppContext, cx: &mut AppContext,
) { ) {
self.context_server_factory_registry self.context_server_factory_registry
@ -84,42 +91,24 @@ impl extension_host::ExtensionRegistrationHooks for ConcreteExtensionRegistratio
move |project, cx| { move |project, cx| {
log::info!( log::info!(
"loading command for context server {id} from extension {}", "loading command for context server {id} from extension {}",
extension.manifest.id extension.manifest().id
); );
let id = id.clone(); let id = id.clone();
let extension = extension.clone(); let extension = extension.clone();
cx.spawn(|mut cx| async move { cx.spawn(|mut cx| async move {
let extension_project = let extension_project =
project.update(&mut cx, |project, cx| ExtensionProject { project.update(&mut cx, |project, cx| {
Arc::new(ExtensionProject {
worktree_ids: project worktree_ids: project
.visible_worktrees(cx) .visible_worktrees(cx)
.map(|worktree| worktree.read(cx).id().to_proto()) .map(|worktree| worktree.read(cx).id().to_proto())
.collect(), .collect(),
})
})?; })?;
let command = extension let command = extension
.call({ .context_server_command(id.clone(), extension_project)
let id = id.clone();
|extension, store| {
async move {
let project = store
.data_mut()
.table()
.push(extension_project)?;
let command = extension
.call_context_server_command(
store,
id.clone(),
project,
)
.await?
.map_err(|e| anyhow!("{}", e))?;
anyhow::Ok(command)
}
.boxed()
}
})
.await?; .await?;
log::info!("loaded command for context server {id}: {command:?}"); log::info!("loaded command for context server {id}: {command:?}");