Allow defining slash commands in extensions (#12255)

This PR adds initial support for defining slash commands for the
Assistant from extensions.

Slash commands are defined in an extension's `extension.toml`:

```toml
[slash_commands.gleam-project]
description = "Returns information about the current Gleam project."
requires_argument = false
```

and then executed via the `run_slash_command` method on the `Extension`
trait:

```rs
impl Extension for GleamExtension {
    // ...

    fn run_slash_command(
        &self,
        command: SlashCommand,
        _argument: Option<String>,
        worktree: &zed::Worktree,
    ) -> Result<Option<String>, String> {
        match command.name.as_str() {
            "gleam-project" => Ok(Some("Yayyy".to_string())),
            command => Err(format!("unknown slash command: \"{command}\"")),
        }
    }
}
```

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-05-24 15:44:32 -04:00 committed by GitHub
parent 055a13a9b6
commit 82f5f36422
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 310 additions and 14 deletions

View file

@ -14,6 +14,7 @@ doctest = false
[dependencies]
anyhow.workspace = true
assistant_slash_command.workspace = true
async-compression.workspace = true
async-tar.workspace = true
async-trait.workspace = true

View file

@ -74,6 +74,8 @@ pub struct ExtensionManifest {
pub grammars: BTreeMap<Arc<str>, GrammarManifestEntry>,
#[serde(default)]
pub language_servers: BTreeMap<LanguageServerName, LanguageServerManifestEntry>,
#[serde(default)]
pub slash_commands: BTreeMap<Arc<str>, SlashCommandManifestEntry>,
}
#[derive(Clone, Default, PartialEq, Eq, Debug, Deserialize, Serialize)]
@ -128,6 +130,12 @@ impl LanguageServerManifestEntry {
}
}
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct SlashCommandManifestEntry {
pub description: String,
pub requires_argument: bool,
}
impl ExtensionManifest {
pub async fn load(fs: Arc<dyn Fs>, extension_dir: &Path) -> Result<Self> {
let extension_name = extension_dir
@ -190,5 +198,6 @@ fn manifest_from_old_manifest(
.map(|grammar_name| (grammar_name, Default::default()))
.collect(),
language_servers: Default::default(),
slash_commands: BTreeMap::default(),
}
}

View file

@ -0,0 +1,85 @@
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandCleanup, SlashCommandInvocation};
use futures::channel::oneshot;
use futures::FutureExt;
use gpui::{AppContext, Task};
use language::LspAdapterDelegate;
use wasmtime_wasi::WasiView;
use crate::wasm_host::{WasmExtension, WasmHost};
pub struct ExtensionSlashCommand {
pub(crate) extension: WasmExtension,
#[allow(unused)]
pub(crate) host: Arc<WasmHost>,
pub(crate) command: crate::wit::SlashCommand,
}
impl SlashCommand for ExtensionSlashCommand {
fn name(&self) -> String {
self.command.name.clone()
}
fn description(&self) -> String {
self.command.description.clone()
}
fn requires_argument(&self) -> bool {
self.command.requires_argument
}
fn complete_argument(
&self,
_query: String,
_cancel: Arc<AtomicBool>,
_cx: &mut AppContext,
) -> Task<Result<Vec<String>>> {
Task::ready(Ok(Vec::new()))
}
fn run(
self: Arc<Self>,
argument: Option<&str>,
delegate: Arc<dyn LspAdapterDelegate>,
cx: &mut AppContext,
) -> SlashCommandInvocation {
let argument = argument.map(|arg| arg.to_string());
let output = cx.background_executor().spawn(async move {
let output = self
.extension
.call({
let this = self.clone();
move |extension, store| {
async move {
let resource = store.data_mut().table().push(delegate)?;
let output = extension
.call_run_slash_command(
store,
&this.command,
argument.as_deref(),
resource,
)
.await?
.map_err(|e| anyhow!("{}", e))?;
anyhow::Ok(output)
}
.boxed()
}
})
.await?;
output.ok_or_else(|| anyhow!("no output from command: {}", self.command.name))
});
SlashCommandInvocation {
output,
invalidated: oneshot::channel().1,
cleanup: SlashCommandCleanup::default(),
}
}
}

View file

@ -2,14 +2,17 @@ pub mod extension_builder;
mod extension_lsp_adapter;
mod extension_manifest;
mod extension_settings;
mod extension_slash_command;
mod wasm_host;
#[cfg(test)]
mod extension_store_test;
use crate::extension_manifest::SchemaVersion;
use crate::extension_slash_command::ExtensionSlashCommand;
use crate::{extension_lsp_adapter::ExtensionLspAdapter, wasm_host::wit};
use anyhow::{anyhow, bail, Context as _, Result};
use assistant_slash_command::SlashCommandRegistry;
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use client::{telemetry::Telemetry, Client, ExtensionMetadata, GetExtensionsResponse};
@ -107,6 +110,7 @@ pub struct ExtensionStore {
index_path: PathBuf,
language_registry: Arc<LanguageRegistry>,
theme_registry: Arc<ThemeRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>,
modified_extensions: HashSet<Arc<str>>,
wasm_host: Arc<WasmHost>,
wasm_extensions: Vec<(Arc<ExtensionManifest>, WasmExtension)>,
@ -183,6 +187,7 @@ pub fn init(
node_runtime,
language_registry,
theme_registry,
SlashCommandRegistry::global(cx),
cx,
)
});
@ -215,6 +220,7 @@ impl ExtensionStore {
node_runtime: Arc<dyn NodeRuntime>,
language_registry: Arc<LanguageRegistry>,
theme_registry: Arc<ThemeRegistry>,
slash_command_registry: Arc<SlashCommandRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
let work_dir = extensions_dir.join("work");
@ -245,6 +251,7 @@ impl ExtensionStore {
telemetry,
language_registry,
theme_registry,
slash_command_registry,
reload_tx,
tasks: Vec::new(),
};
@ -1169,6 +1176,19 @@ impl ExtensionStore {
);
}
}
for (slash_command_name, slash_command) in &manifest.slash_commands {
this.slash_command_registry
.register_command(ExtensionSlashCommand {
command: crate::wit::SlashCommand {
name: slash_command_name.to_string(),
description: slash_command.description.to_string(),
requires_argument: slash_command.requires_argument,
},
extension: wasm_extension.clone(),
host: this.wasm_host.clone(),
});
}
}
this.wasm_extensions.extend(wasm_extensions);
ThemeSettings::reload_current_theme(cx)

View file

@ -5,6 +5,7 @@ use crate::{
ExtensionIndexThemeEntry, ExtensionManifest, ExtensionStore, GrammarManifestEntry,
RELOAD_DEBOUNCE_DURATION,
};
use assistant_slash_command::SlashCommandRegistry;
use async_compression::futures::bufread::GzipEncoder;
use collections::BTreeMap;
use fs::{FakeFs, Fs, RealFs};
@ -156,6 +157,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
.into_iter()
.collect(),
language_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
}),
dev: false,
},
@ -179,6 +181,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
languages: Default::default(),
grammars: BTreeMap::default(),
language_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
}),
dev: false,
},
@ -250,6 +253,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
let theme_registry = Arc::new(ThemeRegistry::new(Box::new(())));
let slash_command_registry = SlashCommandRegistry::new();
let node_runtime = FakeNodeRuntime::new();
let store = cx.new_model(|cx| {
@ -262,6 +266,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
node_runtime.clone(),
language_registry.clone(),
theme_registry.clone(),
slash_command_registry.clone(),
cx,
)
});
@ -333,6 +338,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
languages: Default::default(),
grammars: BTreeMap::default(),
language_servers: BTreeMap::default(),
slash_commands: BTreeMap::default(),
}),
dev: false,
},
@ -382,6 +388,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
node_runtime.clone(),
language_registry.clone(),
theme_registry.clone(),
slash_command_registry,
cx,
)
});
@ -460,6 +467,7 @@ async fn test_extension_store_with_gleam_extension(cx: &mut TestAppContext) {
let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
let theme_registry = Arc::new(ThemeRegistry::new(Box::new(())));
let slash_command_registry = SlashCommandRegistry::new();
let node_runtime = FakeNodeRuntime::new();
let mut status_updates = language_registry.language_server_binary_statuses();
@ -541,6 +549,7 @@ async fn test_extension_store_with_gleam_extension(cx: &mut TestAppContext) {
node_runtime,
language_registry.clone(),
theme_registry.clone(),
slash_command_registry,
cx,
)
});

View file

@ -19,7 +19,7 @@ use wasmtime::{
pub use latest::CodeLabelSpanLiteral;
pub use latest::{
zed::extension::lsp::{Completion, CompletionKind, InsertTextFormat, Symbol, SymbolKind},
CodeLabel, CodeLabelSpan, Command, Range,
CodeLabel, CodeLabelSpan, Command, Range, SlashCommand,
};
pub use since_v0_0_4::LanguageServerConfig;
@ -255,6 +255,22 @@ impl Extension {
Extension::V001(_) | Extension::V004(_) => Ok(Ok(Vec::new())),
}
}
pub async fn call_run_slash_command(
&self,
store: &mut Store<WasmState>,
command: &SlashCommand,
argument: Option<&str>,
resource: Resource<Arc<dyn LspAdapterDelegate>>,
) -> Result<Result<Option<String>, String>> {
match self {
Extension::V007(ext) => {
ext.call_run_slash_command(store, command, argument, resource)
.await
}
Extension::V001(_) | Extension::V004(_) | Extension::V006(_) => Ok(Ok(None)),
}
}
}
trait ToWasmtimeResult<T> {

View file

@ -222,6 +222,9 @@ impl platform::Host for WasmState {
}
}
#[async_trait]
impl slash_command::Host for WasmState {}
#[async_trait]
impl ExtensionImports for WasmState {
async fn get_settings(