Scope slash commands, context servers, and tools to individual Assistant Panel instances (#20372)

This PR reworks how the Assistant Panel references slash commands,
context servers, and tools.

Previously we were always reading them from the global registries, but
now we store individual collections on each Assistant Panel instance so
that there can be different ones registered for each project.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Joseph <joseph@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Marshall Bowers 2024-11-07 18:23:25 -05:00 committed by GitHub
parent 176314bfd2
commit 7e7f25df6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 592 additions and 397 deletions

View file

@ -1,6 +1,8 @@
#[cfg(test)]
mod context_tests;
use crate::slash_command_working_set::SlashCommandWorkingSet;
use crate::ToolWorkingSet;
use crate::{
prompts::PromptBuilder,
slash_command::{file_command::FileCommandMetadata, SlashCommandLine},
@ -8,10 +10,8 @@ use crate::{
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandRegistry,
SlashCommandResult,
SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandResult,
};
use assistant_tool::ToolRegistry;
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
@ -95,13 +95,13 @@ pub enum ContextOperation {
version: clock::Global,
},
SlashCommandStarted {
id: SlashCommandId,
id: InvokedSlashCommandId,
output_range: Range<language::Anchor>,
name: String,
version: clock::Global,
},
SlashCommandFinished {
id: SlashCommandId,
id: InvokedSlashCommandId,
timestamp: clock::Lamport,
error_message: Option<String>,
version: clock::Global,
@ -167,7 +167,7 @@ impl ContextOperation {
}),
proto::context_operation::Variant::SlashCommandStarted(message) => {
Ok(Self::SlashCommandStarted {
id: SlashCommandId(language::proto::deserialize_timestamp(
id: InvokedSlashCommandId(language::proto::deserialize_timestamp(
message.id.context("invalid id")?,
)),
output_range: language::proto::deserialize_anchor_range(
@ -198,7 +198,7 @@ impl ContextOperation {
}
proto::context_operation::Variant::SlashCommandCompleted(message) => {
Ok(Self::SlashCommandFinished {
id: SlashCommandId(language::proto::deserialize_timestamp(
id: InvokedSlashCommandId(language::proto::deserialize_timestamp(
message.id.context("invalid id")?,
)),
timestamp: language::proto::deserialize_timestamp(
@ -372,7 +372,7 @@ pub enum ContextEvent {
updated: Vec<Range<language::Anchor>>,
},
InvokedSlashCommandChanged {
command_id: SlashCommandId,
command_id: InvokedSlashCommandId,
},
ParsedSlashCommandsUpdated {
removed: Vec<Range<language::Anchor>>,
@ -513,7 +513,7 @@ struct PendingCompletion {
}
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct SlashCommandId(clock::Lamport);
pub struct InvokedSlashCommandId(clock::Lamport);
#[derive(Clone, Debug)]
pub struct XmlTag {
@ -543,8 +543,10 @@ pub struct Context {
operations: Vec<ContextOperation>,
buffer: Model<Buffer>,
parsed_slash_commands: Vec<ParsedSlashCommand>,
invoked_slash_commands: HashMap<SlashCommandId, InvokedSlashCommand>,
invoked_slash_commands: HashMap<InvokedSlashCommandId, InvokedSlashCommand>,
edits_since_last_parse: language::Subscription,
pub(crate) slash_commands: Arc<SlashCommandWorkingSet>,
pub(crate) tools: Arc<ToolWorkingSet>,
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
message_anchors: Vec<MessageAnchor>,
@ -598,6 +600,8 @@ impl Context {
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
cx: &mut ModelContext<Self>,
) -> Self {
Self::new(
@ -606,6 +610,8 @@ impl Context {
language::Capability::ReadWrite,
language_registry,
prompt_builder,
slash_commands,
tools,
project,
telemetry,
cx,
@ -619,6 +625,8 @@ impl Context {
capability: language::Capability,
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
@ -663,6 +671,8 @@ impl Context {
telemetry,
project,
language_registry,
slash_commands,
tools,
patches: Vec::new(),
xml_tags: Vec::new(),
prompt_builder,
@ -738,6 +748,8 @@ impl Context {
path: PathBuf,
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
@ -749,6 +761,8 @@ impl Context {
language::Capability::ReadWrite,
language_registry,
prompt_builder,
slash_commands,
tools,
project,
telemetry,
cx,
@ -1086,7 +1100,7 @@ impl Context {
pub fn invoked_slash_command(
&self,
command_id: &SlashCommandId,
command_id: &InvokedSlashCommandId,
) -> Option<&InvokedSlashCommand> {
self.invoked_slash_commands.get(command_id)
}
@ -1455,7 +1469,7 @@ impl Context {
})
.map(ToOwned::to_owned)
.collect::<SmallVec<_>>();
if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
if let Some(command) = self.slash_commands.command(name, cx) {
if !command.requires_argument() || !arguments.is_empty() {
let start_ix = offset + command_line.name.start - 1;
let end_ix = offset
@ -1847,7 +1861,7 @@ impl Context {
cx: &mut ModelContext<Self>,
) {
let version = self.version.clone();
let command_id = SlashCommandId(self.next_timestamp());
let command_id = InvokedSlashCommandId(self.next_timestamp());
const PENDING_OUTPUT_END_MARKER: &str = "";
@ -2227,9 +2241,9 @@ impl Context {
let mut request = self.to_completion_request(request_type, cx);
if cx.has_flag::<ToolUseFeatureFlag>() {
let tool_registry = ToolRegistry::global(cx);
request.tools = tool_registry
.tools()
request.tools = self
.tools
.tools(cx)
.into_iter()
.map(|tool| LanguageModelRequestTool {
name: tool.name(),