Scope slash commands, context servers, and tools to individual Assistant Panel instances (#20372)

This PR reworks how the Assistant Panel references slash commands,
context servers, and tools.

Previously we were always reading them from the global registries, but
now we store individual collections on each Assistant Panel instance so
that there can be different ones registered for each project.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Joseph <joseph@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Marshall Bowers 2024-11-07 18:23:25 -05:00 committed by GitHub
parent 176314bfd2
commit 7e7f25df6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 592 additions and 397 deletions

1
Cargo.lock generated
View file

@ -2829,7 +2829,6 @@ dependencies = [
"smol",
"url",
"util",
"workspace",
]
[[package]]

View file

@ -12,10 +12,14 @@ mod prompts;
mod slash_command;
pub(crate) mod slash_command_picker;
pub mod slash_command_settings;
mod slash_command_working_set;
mod streaming_diff;
mod terminal_inline_assistant;
mod tool_working_set;
mod tools;
pub use crate::slash_command_working_set::{SlashCommandId, SlashCommandWorkingSet};
pub use crate::tool_working_set::{ToolId, ToolWorkingSet};
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
@ -23,12 +27,11 @@ use assistant_tool::ToolRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub use context::*;
use context_servers::ContextServerRegistry;
pub use context_store::*;
use feature_flags::FeatureFlagAppExt;
use fs::Fs;
use gpui::impl_actions;
use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
use gpui::{impl_actions, Context as _};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
use language_model::{
@ -43,10 +46,9 @@ use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::search_command::SearchSlashCommandFeatureFlag;
use slash_command::{
auto_command, cargo_workspace_command, context_server_command, default_command, delta_command,
diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command,
prompt_command, search_command, selection_command, symbols_command, tab_command,
terminal_command,
auto_command, cargo_workspace_command, default_command, delta_command, diagnostics_command,
docs_command, fetch_command, file_command, now_command, project_command, prompt_command,
search_command, selection_command, symbols_command, tab_command, terminal_command,
};
use std::path::PathBuf;
use std::sync::Arc;
@ -281,116 +283,9 @@ pub fn init(
})
.detach();
register_context_server_handlers(cx);
prompt_builder
}
fn register_context_server_handlers(cx: &mut AppContext) {
cx.subscribe(
&context_servers::manager::ContextServerManager::global(cx),
|manager, event, cx| match event {
context_servers::manager::Event::ServerStarted { server_id } => {
cx.update_model(
&manager,
|manager: &mut context_servers::manager::ContextServerManager, cx| {
let slash_command_registry = SlashCommandRegistry::global(cx);
let context_server_registry = ContextServerRegistry::global(cx);
if let Some(server) = manager.get_server(server_id) {
cx.spawn(|_, _| async move {
let Some(protocol) = server.client.read().clone() else {
return;
};
if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
for prompt in prompts
.into_iter()
.filter(context_server_command::acceptable_prompt)
{
log::info!(
"registering context server command: {:?}",
prompt.name
);
context_server_registry.register_command(
server.id.clone(),
prompt.name.as_str(),
);
slash_command_registry.register_command(
context_server_command::ContextServerSlashCommand::new(
&server, prompt,
),
true,
);
}
}
}
})
.detach();
}
},
);
cx.update_model(
&manager,
|manager: &mut context_servers::manager::ContextServerManager, cx| {
let tool_registry = ToolRegistry::global(cx);
let context_server_registry = ContextServerRegistry::global(cx);
if let Some(server) = manager.get_server(server_id) {
cx.spawn(|_, _| async move {
let Some(protocol) = server.client.read().clone() else {
return;
};
if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
for tool in tools.tools {
log::info!(
"registering context server tool: {:?}",
tool.name
);
context_server_registry.register_tool(
server.id.clone(),
tool.name.as_str(),
);
tool_registry.register_tool(
tools::context_server_tool::ContextServerTool::new(
server.id.clone(),
tool
),
);
}
}
}
})
.detach();
}
},
);
}
context_servers::manager::Event::ServerStopped { server_id } => {
let slash_command_registry = SlashCommandRegistry::global(cx);
let context_server_registry = ContextServerRegistry::global(cx);
if let Some(commands) = context_server_registry.get_commands(server_id) {
for command_name in commands {
slash_command_registry.unregister_command_by_name(&command_name);
context_server_registry.unregister_command(&server_id, &command_name);
}
}
if let Some(tools) = context_server_registry.get_tools(server_id) {
let tool_registry = ToolRegistry::global(cx);
for tool_name in tools {
tool_registry.unregister_tool_by_name(&tool_name);
context_server_registry.unregister_tool(&server_id, &tool_name);
}
}
}
},
)
.detach();
}
fn init_language_model_settings(cx: &mut AppContext) {
update_active_language_model_from_settings(cx);

View file

@ -1,4 +1,6 @@
use crate::slash_command::file_command::codeblock_fence_for_path;
use crate::slash_command_working_set::SlashCommandWorkingSet;
use crate::ToolWorkingSet;
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings},
humanize_token_count,
@ -7,21 +9,20 @@ use crate::{
slash_command::{
default_command::DefaultSlashCommand,
docs_command::{DocsSlashCommand, DocsSlashCommandArgs},
file_command, SlashCommandCompletionProvider, SlashCommandRegistry,
file_command, SlashCommandCompletionProvider,
},
slash_command_picker,
terminal_inline_assistant::TerminalInlineAssistant,
Assist, AssistantPatch, AssistantPatchStatus, CacheStatus, ConfirmCommand, Content, Context,
ContextEvent, ContextId, ContextStore, ContextStoreEvent, CopyCode, CycleMessageRole,
DeployHistory, DeployPromptLibrary, Edit, InlineAssistant, InsertDraggedFiles,
InsertIntoEditor, InvokedSlashCommandStatus, Message, MessageId, MessageMetadata,
MessageStatus, ModelPickerDelegate, ModelSelector, NewContext, ParsedSlashCommand,
PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, RequestType,
SavedContextMetadata, SlashCommandId, Split, ToggleFocus, ToggleModelSelector,
InsertIntoEditor, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId,
MessageMetadata, MessageStatus, ModelPickerDelegate, ModelSelector, NewContext,
ParsedSlashCommand, PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata,
RequestType, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
};
use anyhow::Result;
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use assistant_tool::ToolRegistry;
use client::{proto, zed_urls, Client, Status};
use collections::{hash_map, BTreeSet, HashMap, HashSet};
use editor::{
@ -112,7 +113,8 @@ pub fn init(cx: &mut AppContext) {
.register_action(ContextEditor::copy_code)
.register_action(ContextEditor::insert_dragged_files)
.register_action(AssistantPanel::show_configuration)
.register_action(AssistantPanel::create_new_context);
.register_action(AssistantPanel::create_new_context)
.register_action(AssistantPanel::restart_context_servers);
},
)
.detach();
@ -315,10 +317,12 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move {
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
let tools = Arc::new(ToolWorkingSet::default());
let context_store = workspace
.update(&mut cx, |workspace, cx| {
let project = workspace.project().clone();
ContextStore::new(project, prompt_builder.clone(), cx)
ContextStore::new(project, prompt_builder.clone(), slash_commands, tools, cx)
})?
.await?;
@ -1294,6 +1298,24 @@ impl AssistantPanel {
.active_provider()
.map_or(None, |provider| Some(provider.authenticate(cx)))
}
fn restart_context_servers(
workspace: &mut Workspace,
_action: &context_servers::Restart,
cx: &mut ViewContext<Workspace>,
) {
let Some(assistant_panel) = workspace.panel::<AssistantPanel>(cx) else {
return;
};
assistant_panel.update(cx, |assistant_panel, cx| {
assistant_panel
.context_store
.update(cx, |context_store, cx| {
context_store.restart_context_servers(cx);
});
});
}
}
impl Render for AssistantPanel {
@ -1468,6 +1490,8 @@ enum AssistError {
pub struct ContextEditor {
context: Model<Context>,
fs: Arc<dyn Fs>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
workspace: WeakView<Workspace>,
project: Model<Project>,
lsp_adapter_delegate: Option<Arc<dyn LspAdapterDelegate>>,
@ -1477,7 +1501,7 @@ pub struct ContextEditor {
scroll_position: Option<ScrollPosition>,
remote_id: Option<workspace::ViewId>,
pending_slash_command_creases: HashMap<Range<language::Anchor>, CreaseId>,
invoked_slash_command_creases: HashMap<SlashCommandId, CreaseId>,
invoked_slash_command_creases: HashMap<InvokedSlashCommandId, CreaseId>,
pending_tool_use_creases: HashMap<Range<language::Anchor>, CreaseId>,
_subscriptions: Vec<Subscription>,
patches: HashMap<Range<language::Anchor>, PatchViewState>,
@ -1536,8 +1560,12 @@ impl ContextEditor {
let sections = context.read(cx).slash_command_output_sections().to_vec();
let patch_ranges = context.read(cx).patch_ranges().collect::<Vec<_>>();
let slash_commands = context.read(cx).slash_commands.clone();
let tools = context.read(cx).tools.clone();
let mut this = Self {
context,
slash_commands,
tools,
editor,
lsp_adapter_delegate,
blocks: Default::default(),
@ -1688,7 +1716,7 @@ impl ContextEditor {
}
pub fn insert_command(&mut self, name: &str, cx: &mut ViewContext<Self>) {
if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
if let Some(command) = self.slash_commands.command(name, cx) {
self.editor.update(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| s.try_cancel());
@ -1770,7 +1798,7 @@ impl ContextEditor {
workspace: WeakView<Workspace>,
cx: &mut ViewContext<Self>,
) {
if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
if let Some(command) = self.slash_commands.command(name, cx) {
let context = self.context.read(cx);
let sections = context
.slash_command_output_sections()
@ -2043,8 +2071,7 @@ impl ContextEditor {
.collect::<Vec<_>>();
for tool_use in pending_tool_uses {
let tool_registry = ToolRegistry::global(cx);
if let Some(tool) = tool_registry.tool(&tool_use.name) {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), cx);
self.context.update(cx, |context, cx| {
@ -2108,7 +2135,7 @@ impl ContextEditor {
fn update_invoked_slash_command(
&mut self,
command_id: SlashCommandId,
command_id: InvokedSlashCommandId,
cx: &mut ViewContext<Self>,
) {
self.editor.update(cx, |editor, cx| {
@ -3719,6 +3746,19 @@ impl ContextEditor {
})
}
fn render_inject_context_menu(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
slash_command_picker::SlashCommandSelector::new(
self.slash_commands.clone(),
cx.view().downgrade(),
Button::new("trigger", "Add Context")
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.tooltip(|cx| Tooltip::text("Type / to insert via keyboard", cx)),
)
}
fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
let last_error = self.last_error.as_ref()?;
@ -4133,11 +4173,7 @@ impl Render for ContextEditor {
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
.child(
h_flex()
.gap_1()
.child(render_inject_context_menu(cx.view().downgrade(), cx)),
)
.child(h_flex().gap_1().child(self.render_inject_context_menu(cx)))
.child(
h_flex()
.w_full()
@ -4419,24 +4455,6 @@ pub struct ContextEditorToolbarItem {
model_selector_menu_handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>,
}
fn render_inject_context_menu(
active_context_editor: WeakView<ContextEditor>,
cx: &mut WindowContext<'_>,
) -> impl IntoElement {
let commands = SlashCommandRegistry::global(cx);
slash_command_picker::SlashCommandSelector::new(
commands.clone(),
active_context_editor,
Button::new("trigger", "Add Context")
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.tooltip(|cx| Tooltip::text("Type / to insert via keyboard", cx)),
)
}
impl ContextEditorToolbarItem {
pub fn new(
workspace: &Workspace,
@ -5095,7 +5113,7 @@ fn make_lsp_adapter_delegate(
enum PendingSlashCommand {}
fn invoked_slash_command_fold_placeholder(
command_id: SlashCommandId,
command_id: InvokedSlashCommandId,
context: WeakModel<Context>,
) -> FoldPlaceholder {
FoldPlaceholder {

View file

@ -1,6 +1,8 @@
#[cfg(test)]
mod context_tests;
use crate::slash_command_working_set::SlashCommandWorkingSet;
use crate::ToolWorkingSet;
use crate::{
prompts::PromptBuilder,
slash_command::{file_command::FileCommandMetadata, SlashCommandLine},
@ -8,10 +10,8 @@ use crate::{
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandRegistry,
SlashCommandResult,
SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandResult,
};
use assistant_tool::ToolRegistry;
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
@ -95,13 +95,13 @@ pub enum ContextOperation {
version: clock::Global,
},
SlashCommandStarted {
id: SlashCommandId,
id: InvokedSlashCommandId,
output_range: Range<language::Anchor>,
name: String,
version: clock::Global,
},
SlashCommandFinished {
id: SlashCommandId,
id: InvokedSlashCommandId,
timestamp: clock::Lamport,
error_message: Option<String>,
version: clock::Global,
@ -167,7 +167,7 @@ impl ContextOperation {
}),
proto::context_operation::Variant::SlashCommandStarted(message) => {
Ok(Self::SlashCommandStarted {
id: SlashCommandId(language::proto::deserialize_timestamp(
id: InvokedSlashCommandId(language::proto::deserialize_timestamp(
message.id.context("invalid id")?,
)),
output_range: language::proto::deserialize_anchor_range(
@ -198,7 +198,7 @@ impl ContextOperation {
}
proto::context_operation::Variant::SlashCommandCompleted(message) => {
Ok(Self::SlashCommandFinished {
id: SlashCommandId(language::proto::deserialize_timestamp(
id: InvokedSlashCommandId(language::proto::deserialize_timestamp(
message.id.context("invalid id")?,
)),
timestamp: language::proto::deserialize_timestamp(
@ -372,7 +372,7 @@ pub enum ContextEvent {
updated: Vec<Range<language::Anchor>>,
},
InvokedSlashCommandChanged {
command_id: SlashCommandId,
command_id: InvokedSlashCommandId,
},
ParsedSlashCommandsUpdated {
removed: Vec<Range<language::Anchor>>,
@ -513,7 +513,7 @@ struct PendingCompletion {
}
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct SlashCommandId(clock::Lamport);
pub struct InvokedSlashCommandId(clock::Lamport);
#[derive(Clone, Debug)]
pub struct XmlTag {
@ -543,8 +543,10 @@ pub struct Context {
operations: Vec<ContextOperation>,
buffer: Model<Buffer>,
parsed_slash_commands: Vec<ParsedSlashCommand>,
invoked_slash_commands: HashMap<SlashCommandId, InvokedSlashCommand>,
invoked_slash_commands: HashMap<InvokedSlashCommandId, InvokedSlashCommand>,
edits_since_last_parse: language::Subscription,
pub(crate) slash_commands: Arc<SlashCommandWorkingSet>,
pub(crate) tools: Arc<ToolWorkingSet>,
slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
message_anchors: Vec<MessageAnchor>,
@ -598,6 +600,8 @@ impl Context {
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
cx: &mut ModelContext<Self>,
) -> Self {
Self::new(
@ -606,6 +610,8 @@ impl Context {
language::Capability::ReadWrite,
language_registry,
prompt_builder,
slash_commands,
tools,
project,
telemetry,
cx,
@ -619,6 +625,8 @@ impl Context {
capability: language::Capability,
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
@ -663,6 +671,8 @@ impl Context {
telemetry,
project,
language_registry,
slash_commands,
tools,
patches: Vec::new(),
xml_tags: Vec::new(),
prompt_builder,
@ -738,6 +748,8 @@ impl Context {
path: PathBuf,
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
project: Option<Model<Project>>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut ModelContext<Self>,
@ -749,6 +761,8 @@ impl Context {
language::Capability::ReadWrite,
language_registry,
prompt_builder,
slash_commands,
tools,
project,
telemetry,
cx,
@ -1086,7 +1100,7 @@ impl Context {
pub fn invoked_slash_command(
&self,
command_id: &SlashCommandId,
command_id: &InvokedSlashCommandId,
) -> Option<&InvokedSlashCommand> {
self.invoked_slash_commands.get(command_id)
}
@ -1455,7 +1469,7 @@ impl Context {
})
.map(ToOwned::to_owned)
.collect::<SmallVec<_>>();
if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
if let Some(command) = self.slash_commands.command(name, cx) {
if !command.requires_argument() || !arguments.is_empty() {
let start_ix = offset + command_line.name.start - 1;
let end_ix = offset
@ -1847,7 +1861,7 @@ impl Context {
cx: &mut ModelContext<Self>,
) {
let version = self.version.clone();
let command_id = SlashCommandId(self.next_timestamp());
let command_id = InvokedSlashCommandId(self.next_timestamp());
const PENDING_OUTPUT_END_MARKER: &str = "";
@ -2227,9 +2241,9 @@ impl Context {
let mut request = self.to_completion_request(request_type, cx);
if cx.has_flag::<ToolUseFeatureFlag>() {
let tool_registry = ToolRegistry::global(cx);
request.tools = tool_registry
.tools()
request.tools = self
.tools
.tools(cx)
.into_iter()
.map(|tool| LanguageModelRequestTool {
name: tool.name(),

View file

@ -1,8 +1,10 @@
use super::{AssistantEdit, MessageCacheMetadata};
use crate::slash_command_working_set::SlashCommandWorkingSet;
use crate::ToolWorkingSet;
use crate::{
assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus,
Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
SlashCommandId,
Context, ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId,
MessageStatus, PromptBuilder,
};
use anyhow::Result;
use assistant_slash_command::{
@ -49,8 +51,17 @@ fn test_inserting_and_removing_messages(cx: &mut AppContext) {
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let context = cx.new_model(|cx| {
Context::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
});
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
@ -182,8 +193,17 @@ fn test_message_splitting(cx: &mut AppContext) {
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let context = cx.new_model(|cx| {
Context::local(
registry.clone(),
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
});
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
@ -277,8 +297,17 @@ fn test_messages_for_offsets(cx: &mut AppContext) {
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let context = cx.new_model(|cx| {
Context::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
});
let buffer = context.read(cx).buffer.clone();
let message_1 = context.read(cx).message_anchors[0].clone();
@ -383,13 +412,22 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
let context = cx.new_model(|cx| {
Context::local(
registry.clone(),
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
});
#[derive(Default)]
struct ContextRanges {
parsed_commands: HashSet<Range<language::Anchor>>,
command_outputs: HashMap<SlashCommandId, Range<language::Anchor>>,
command_outputs: HashMap<InvokedSlashCommandId, Range<language::Anchor>>,
output_sections: HashSet<Range<language::Anchor>>,
}
@ -671,6 +709,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
Some(project),
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
});
@ -934,6 +974,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
Default::default(),
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
None,
None,
cx,
@ -1042,8 +1084,17 @@ async fn test_serialization(cx: &mut TestAppContext) {
cx.update(assistant_panel::init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
let context = cx.new_model(|cx| {
Context::local(
registry.clone(),
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
});
let buffer = context.read_with(cx, |context, _| context.buffer.clone());
let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
let message_1 = context.update(cx, |context, cx| {
@ -1083,6 +1134,8 @@ async fn test_serialization(cx: &mut TestAppContext) {
Default::default(),
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
None,
None,
cx,
@ -1141,6 +1194,8 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
language::Capability::ReadWrite,
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
None,
None,
cx,
@ -1394,8 +1449,17 @@ fn test_mark_cache_anchors(cx: &mut AppContext) {
assistant_panel::init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context =
cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
let context = cx.new_model(|cx| {
Context::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
});
let buffer = context.read(cx).buffer.clone();
// Create a test cache configuration

View file

@ -1,10 +1,16 @@
use crate::slash_command::context_server_command;
use crate::{
prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion,
SavedContext, SavedContextMetadata,
prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context,
ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata,
};
use crate::{tools, SlashCommandId, ToolId, ToolWorkingSet};
use anyhow::{anyhow, Context as _, Result};
use client::{proto, telemetry::Telemetry, Client, TypedEnvelope};
use clock::ReplicaId;
use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter;
use context_servers::manager::{ContextServerManager, ContextServerSettings};
use context_servers::CONTEXT_SERVERS_NAMESPACE;
use fs::Fs;
use futures::StreamExt;
use fuzzy::StringMatchCandidate;
@ -16,6 +22,7 @@ use paths::contexts_dir;
use project::Project;
use regex::Regex;
use rpc::AnyProtoClient;
use settings::{Settings as _, SettingsStore};
use std::{
cmp::Reverse,
ffi::OsStr,
@ -43,9 +50,14 @@ pub struct RemoteContextMetadata {
pub struct ContextStore {
contexts: Vec<ContextHandle>,
contexts_metadata: Vec<SavedContextMetadata>,
context_server_manager: Model<ContextServerManager>,
context_server_slash_command_ids: HashMap<String, Vec<SlashCommandId>>,
context_server_tool_ids: HashMap<String, Vec<ToolId>>,
host_contexts: Vec<RemoteContextMetadata>,
fs: Arc<dyn Fs>,
languages: Arc<LanguageRegistry>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
telemetry: Arc<Telemetry>,
_watch_updates: Task<Option<()>>,
client: Arc<Client>,
@ -87,6 +99,8 @@ impl ContextStore {
pub fn new(
project: Model<Project>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
tools: Arc<ToolWorkingSet>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
let fs = project.read(cx).fs().clone();
@ -97,12 +111,18 @@ impl ContextStore {
let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await;
let this = cx.new_model(|cx: &mut ModelContext<Self>| {
let context_server_manager = cx.new_model(|_cx| ContextServerManager::new());
let mut this = Self {
contexts: Vec::new(),
contexts_metadata: Vec::new(),
context_server_manager,
context_server_slash_command_ids: HashMap::default(),
context_server_tool_ids: HashMap::default(),
host_contexts: Vec::new(),
fs,
languages,
slash_commands,
tools,
telemetry,
_watch_updates: cx.spawn(|this, mut cx| {
async move {
@ -127,15 +147,48 @@ impl ContextStore {
};
this.handle_project_changed(project, cx);
this.synchronize_contexts(cx);
this.register_context_server_handlers(cx);
this
})?;
this.update(&mut cx, |this, cx| this.reload(cx))?
.await
.log_err();
this.update(&mut cx, |this, cx| {
this.watch_context_server_settings(cx);
})
.log_err();
Ok(this)
})
}
fn watch_context_server_settings(&self, cx: &mut ModelContext<Self>) {
cx.observe_global::<SettingsStore>(move |this, cx| {
this.context_server_manager.update(cx, |manager, cx| {
let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
settings::SettingsLocation {
worktree_id: worktree.read(cx).id(),
path: Path::new(""),
}
});
let settings = ContextServerSettings::get(location, cx);
manager.maintain_servers(settings, cx);
let has_any_context_servers = !manager.servers().is_empty();
CommandPaletteFilter::update_global(cx, |filter, _cx| {
if has_any_context_servers {
filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
} else {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
}
});
})
})
.detach();
}
async fn handle_advertise_contexts(
this: Model<Self>,
envelope: TypedEnvelope<proto::AdvertiseContexts>,
@ -342,6 +395,8 @@ impl ContextStore {
Some(self.project.clone()),
Some(self.telemetry.clone()),
self.prompt_builder.clone(),
self.slash_commands.clone(),
self.tools.clone(),
cx,
)
});
@ -364,6 +419,8 @@ impl ContextStore {
let project = self.project.clone();
let telemetry = self.telemetry.clone();
let prompt_builder = self.prompt_builder.clone();
let slash_commands = self.slash_commands.clone();
let tools = self.tools.clone();
let request = self.client.request(proto::CreateContext { project_id });
cx.spawn(|this, mut cx| async move {
let response = request.await?;
@ -376,6 +433,8 @@ impl ContextStore {
capability,
language_registry,
prompt_builder,
slash_commands,
tools,
Some(project),
Some(telemetry),
cx,
@ -425,6 +484,8 @@ impl ContextStore {
}
});
let prompt_builder = self.prompt_builder.clone();
let slash_commands = self.slash_commands.clone();
let tools = self.tools.clone();
cx.spawn(|this, mut cx| async move {
let saved_context = load.await?;
@ -434,6 +495,8 @@ impl ContextStore {
path.clone(),
languages,
prompt_builder,
slash_commands,
tools,
Some(project),
Some(telemetry),
cx,
@ -500,6 +563,8 @@ impl ContextStore {
context_id: context_id.to_proto(),
});
let prompt_builder = self.prompt_builder.clone();
let slash_commands = self.slash_commands.clone();
let tools = self.tools.clone();
cx.spawn(|this, mut cx| async move {
let response = request.await?;
let context_proto = response.context.context("invalid context")?;
@ -510,6 +575,8 @@ impl ContextStore {
capability,
language_registry,
prompt_builder,
slash_commands,
tools,
Some(project),
Some(telemetry),
cx,
@ -745,4 +812,114 @@ impl ContextStore {
})
})
}
pub fn restart_context_servers(&mut self, cx: &mut ModelContext<Self>) {
cx.update_model(
&self.context_server_manager,
|context_server_manager, cx| {
for server in context_server_manager.servers() {
context_server_manager
.restart_server(&server.id, cx)
.detach_and_log_err(cx);
}
},
);
}
fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
cx.subscribe(
&self.context_server_manager.clone(),
Self::handle_context_server_event,
)
.detach();
}
fn handle_context_server_event(
&mut self,
context_server_manager: Model<ContextServerManager>,
event: &context_servers::manager::Event,
cx: &mut ModelContext<Self>,
) {
let slash_command_working_set = self.slash_commands.clone();
let tool_working_set = self.tools.clone();
match event {
context_servers::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
|this, mut cx| async move {
let Some(protocol) = server.client.read().clone() else {
return;
};
if protocol.capable(context_servers::protocol::ServerCapability::Prompts) {
if let Some(prompts) = protocol.list_prompts().await.log_err() {
let slash_command_ids = prompts
.into_iter()
.filter(context_server_command::acceptable_prompt)
.map(|prompt| {
log::info!(
"registering context server command: {:?}",
prompt.name
);
slash_command_working_set.insert(Arc::new(
context_server_command::ContextServerSlashCommand::new(
context_server_manager.clone(),
&server,
prompt,
),
))
})
.collect::<Vec<_>>();
this.update(&mut cx, |this, _cx| {
this.context_server_slash_command_ids
.insert(server_id.clone(), slash_command_ids);
})
.log_err();
}
}
if protocol.capable(context_servers::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools.tools.into_iter().map(|tool| {
log::info!("registering context server tool: {:?}", tool.name);
tool_working_set.insert(
Arc::new(tools::context_server_tool::ContextServerTool::new(
context_server_manager.clone(),
server.id.clone(),
tool,
)),
)
}).collect::<Vec<_>>();
this.update(&mut cx, |this, _cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
})
.log_err();
}
}
}
})
.detach();
}
}
context_servers::manager::Event::ServerStopped { server_id } => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
slash_command_working_set.remove(&slash_command_ids);
}
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
}
}
}
}
}

View file

@ -8,7 +8,7 @@ use context_servers::{
manager::{ContextServer, ContextServerManager},
types::Prompt,
};
use gpui::{AppContext, Task, WeakView, WindowContext};
use gpui::{AppContext, Model, Task, WeakView, WindowContext};
use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
@ -19,15 +19,21 @@ use workspace::Workspace;
use crate::slash_command::create_label_for_command;
pub struct ContextServerSlashCommand {
server_manager: Model<ContextServerManager>,
server_id: String,
prompt: Prompt,
}
impl ContextServerSlashCommand {
pub fn new(server: &Arc<ContextServer>, prompt: Prompt) -> Self {
pub fn new(
server_manager: Model<ContextServerManager>,
server: &Arc<ContextServer>,
prompt: Prompt,
) -> Self {
Self {
server_id: server.id.clone(),
prompt,
server_manager,
}
}
}
@ -74,18 +80,14 @@ impl SlashCommand for ContextServerSlashCommand {
_workspace: Option<WeakView<Workspace>>,
cx: &mut WindowContext,
) -> Task<Result<Vec<ArgumentCompletion>>> {
let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else {
return Task::ready(Err(anyhow!("Failed to complete argument")));
};
let server_id = self.server_id.clone();
let prompt_name = self.prompt.name.clone();
let manager = ContextServerManager::global(cx);
let manager = manager.read(cx);
let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
Ok(tp) => tp,
Err(e) => {
return Task::ready(Err(e));
}
};
if let Some(server) = manager.get_server(&server_id) {
if let Some(server) = self.server_manager.read(cx).get_server(&server_id) {
cx.foreground_executor().spawn(async move {
let Some(protocol) = server.client.read().clone() else {
return Err(anyhow!("Context server not initialized"));
@ -100,7 +102,7 @@ impl SlashCommand for ContextServerSlashCommand {
},
),
arg_name,
arg_val,
arg_value,
)
.await?;
@ -138,8 +140,7 @@ impl SlashCommand for ContextServerSlashCommand {
Err(e) => return Task::ready(Err(e)),
};
let manager = ContextServerManager::global(cx);
let manager = manager.read(cx);
let manager = self.server_manager.read(cx);
if let Some(server) = manager.get_server(&server_id) {
cx.foreground_executor().spawn(async move {
let Some(protocol) = server.client.read().clone() else {

View file

@ -1,16 +1,15 @@
use std::sync::Arc;
use assistant_slash_command::SlashCommandRegistry;
use gpui::{AnyElement, DismissEvent, SharedString, Task, WeakView};
use picker::{Picker, PickerDelegate, PickerEditorPosition};
use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverTrigger};
use crate::assistant_panel::ContextEditor;
use crate::SlashCommandWorkingSet;
#[derive(IntoElement)]
pub(super) struct SlashCommandSelector<T: PopoverTrigger> {
registry: Arc<SlashCommandRegistry>,
working_set: Arc<SlashCommandWorkingSet>,
active_context_editor: WeakView<ContextEditor>,
trigger: T,
}
@ -51,12 +50,12 @@ pub(crate) struct SlashCommandDelegate {
impl<T: PopoverTrigger> SlashCommandSelector<T> {
pub(crate) fn new(
registry: Arc<SlashCommandRegistry>,
working_set: Arc<SlashCommandWorkingSet>,
active_context_editor: WeakView<ContextEditor>,
trigger: T,
) -> Self {
SlashCommandSelector {
registry,
working_set,
active_context_editor,
trigger,
}
@ -231,11 +230,11 @@ impl PickerDelegate for SlashCommandDelegate {
impl<T: PopoverTrigger> RenderOnce for SlashCommandSelector<T> {
fn render(self, cx: &mut WindowContext) -> impl IntoElement {
let all_models = self
.registry
.featured_command_names()
.working_set
.featured_command_names(cx)
.into_iter()
.filter_map(|command_name| {
let command = self.registry.command(&command_name)?;
let command = self.working_set.command(&command_name, cx)?;
let menu_text = SharedString::from(Arc::from(command.menu_text()));
let label = command.label(cx);
let args = label.filter_range.end.ne(&label.text.len()).then(|| {

View file

@ -0,0 +1,66 @@
use assistant_slash_command::{SlashCommand, SlashCommandRegistry};
use collections::HashMap;
use gpui::AppContext;
use parking_lot::Mutex;
use std::sync::Arc;
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct SlashCommandId(usize);
/// A working set of slash commands for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct SlashCommandWorkingSet {
state: Mutex<WorkingSetState>,
}
#[derive(Default)]
struct WorkingSetState {
context_server_commands_by_id: HashMap<SlashCommandId, Arc<dyn SlashCommand>>,
context_server_commands_by_name: HashMap<String, Arc<dyn SlashCommand>>,
next_command_id: SlashCommandId,
}
impl SlashCommandWorkingSet {
pub fn command(&self, name: &str, cx: &AppContext) -> Option<Arc<dyn SlashCommand>> {
self.state
.lock()
.context_server_commands_by_name
.get(name)
.cloned()
.or_else(|| SlashCommandRegistry::global(cx).command(name))
}
pub fn featured_command_names(&self, cx: &AppContext) -> Vec<Arc<str>> {
SlashCommandRegistry::global(cx).featured_command_names()
}
pub fn insert(&self, command: Arc<dyn SlashCommand>) -> SlashCommandId {
let mut state = self.state.lock();
let command_id = state.next_command_id;
state.next_command_id.0 += 1;
state
.context_server_commands_by_id
.insert(command_id, command.clone());
state.slash_commands_changed();
command_id
}
pub fn remove(&self, command_ids_to_remove: &[SlashCommandId]) {
let mut state = self.state.lock();
state
.context_server_commands_by_id
.retain(|id, _| !command_ids_to_remove.contains(id));
state.slash_commands_changed();
}
}
impl WorkingSetState {
fn slash_commands_changed(&mut self) {
self.context_server_commands_by_name.clear();
self.context_server_commands_by_name.extend(
self.context_server_commands_by_id
.values()
.map(|command| (command.name(), command.clone())),
);
}
}

View file

@ -0,0 +1,75 @@
use assistant_tool::{Tool, ToolRegistry};
use collections::HashMap;
use gpui::AppContext;
use parking_lot::Mutex;
use std::sync::Arc;
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
state: Mutex<WorkingSetState>,
}
#[derive(Default)]
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
next_tool_id: ToolId,
}
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &AppContext) -> Option<Arc<dyn Tool>> {
self.state
.lock()
.context_server_tools_by_name
.get(name)
.cloned()
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &AppContext) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(
self.state
.lock()
.context_server_tools_by_id
.values()
.cloned(),
);
tools
}
pub fn insert(&self, command: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock();
let command_id = state.next_tool_id;
state.next_tool_id.0 += 1;
state
.context_server_tools_by_id
.insert(command_id, command.clone());
state.tools_changed();
command_id
}
pub fn remove(&self, command_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock();
state
.context_server_tools_by_id
.retain(|id, _| !command_ids_to_remove.contains(id));
state.tools_changed();
}
}
impl WorkingSetState {
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|command| (command.name(), command.clone())),
);
}
}

View file

@ -2,16 +2,22 @@ use anyhow::{anyhow, bail};
use assistant_tool::Tool;
use context_servers::manager::ContextServerManager;
use context_servers::types;
use gpui::Task;
use gpui::{Model, Task};
pub struct ContextServerTool {
server_manager: Model<ContextServerManager>,
server_id: String,
tool: types::Tool,
}
impl ContextServerTool {
pub fn new(server_id: impl Into<String>, tool: types::Tool) -> Self {
pub fn new(
server_manager: Model<ContextServerManager>,
server_id: impl Into<String>,
tool: types::Tool,
) -> Self {
Self {
server_manager,
server_id: server_id.into(),
tool,
}
@ -45,9 +51,7 @@ impl Tool for ContextServerTool {
_workspace: gpui::WeakView<workspace::Workspace>,
cx: &mut ui::WindowContext,
) -> gpui::Task<gpui::Result<String>> {
let manager = ContextServerManager::global(cx);
let manager = manager.read(cx);
if let Some(server) = manager.get_server(&self.server_id) {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
cx.foreground_executor().spawn({
let tool_name = self.tool.name.clone();
async move {

View file

@ -6,7 +6,7 @@ use crate::{
},
};
use anyhow::{anyhow, Result};
use assistant::{ContextStore, PromptBuilder};
use assistant::{ContextStore, PromptBuilder, SlashCommandWorkingSet, ToolWorkingSet};
use call::{room, ActiveCall, ParticipantLocation, Room};
use client::{User, RECEIVE_TIMEOUT};
use collections::{HashMap, HashSet};
@ -6489,11 +6489,27 @@ async fn test_context_collaboration_with_reconnect(
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context_store_a = cx_a
.update(|cx| ContextStore::new(project_a.clone(), prompt_builder.clone(), cx))
.update(|cx| {
ContextStore::new(
project_a.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
})
.await
.unwrap();
let context_store_b = cx_b
.update(|cx| ContextStore::new(project_b.clone(), prompt_builder.clone(), cx))
.update(|cx| {
ContextStore::new(
project_b.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
Arc::new(ToolWorkingSet::default()),
cx,
)
})
.await
.unwrap();

View file

@ -27,4 +27,3 @@ settings.workspace = true
smol.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace.workspace = true

View file

@ -1,40 +1,23 @@
use gpui::{actions, AppContext, Context, ViewContext};
use manager::ContextServerManager;
use workspace::Workspace;
use command_palette_hooks::CommandPaletteFilter;
use gpui::{actions, AppContext};
use settings::Settings;
use crate::manager::ContextServerSettings;
pub mod client;
pub mod manager;
pub mod protocol;
mod registry;
pub mod types;
pub use registry::*;
actions!(context_servers, [Restart]);
/// The namespace for the context servers actions.
const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers";
pub fn init(cx: &mut AppContext) {
log::info!("initializing context server client");
manager::init(cx);
ContextServerRegistry::register(cx);
ContextServerSettings::register(cx);
cx.observe_new_views(
|workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
workspace.register_action(restart_servers);
},
)
.detach();
}
fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewContext<Workspace>) {
let model = ContextServerManager::global(cx);
cx.update_model(&model, |manager, cx| {
for server in manager.servers() {
manager
.restart_server(&server.id, cx)
.detach_and_log_err(cx);
}
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
});
}

View file

@ -14,18 +14,17 @@
//! The module also includes initialization logic to set up the context server system
//! and react to changes in settings.
use std::path::Path;
use std::sync::Arc;
use collections::{HashMap, HashSet};
use command_palette_hooks::CommandPaletteFilter;
use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task};
use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
use log;
use parking_lot::RwLock;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources, SettingsStore};
use std::path::Path;
use std::sync::Arc;
use settings::{Settings, SettingsSources};
use crate::CONTEXT_SERVERS_NAMESPACE;
use crate::{
client::{self, Client},
types,
@ -124,7 +123,6 @@ pub enum Event {
ServerStopped { server_id: String },
}
impl Global for ContextServerManager {}
impl EventEmitter<Event> for ContextServerManager {}
impl Default for ContextServerManager {
@ -140,14 +138,11 @@ impl ContextServerManager {
pending_servers: HashSet::default(),
}
}
pub fn global(cx: &AppContext) -> Model<Self> {
cx.global::<GlobalContextServerManager>().0.clone()
}
pub fn add_server(
&mut self,
config: ServerConfig,
cx: &mut ModelContext<Self>,
cx: &ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
let server_id = config.id.clone();
@ -179,11 +174,7 @@ impl ContextServerManager {
self.servers.get(id).cloned()
}
pub fn remove_server(
&mut self,
id: &str,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
pub fn remove_server(&mut self, id: &str, cx: &ModelContext<Self>) -> Task<anyhow::Result<()>> {
let id = id.to_string();
cx.spawn(|this, mut cx| async move {
if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
@ -229,75 +220,38 @@ impl ContextServerManager {
self.servers.values().cloned().collect()
}
pub fn model(cx: &mut AppContext) -> Model<Self> {
cx.new_model(|_cx| ContextServerManager::new())
pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
let current_servers = self
.servers()
.into_iter()
.map(|server| (server.id.clone(), server.config.clone()))
.collect::<HashMap<_, _>>();
let new_servers = settings
.servers
.iter()
.map(|config| (config.id.clone(), config.clone()))
.collect::<HashMap<_, _>>();
let servers_to_add = new_servers
.values()
.filter(|config| !current_servers.contains_key(&config.id))
.cloned()
.collect::<Vec<_>>();
let servers_to_remove = current_servers
.keys()
.filter(|id| !new_servers.contains_key(*id))
.cloned()
.collect::<Vec<_>>();
log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
self.add_server(config, cx).detach_and_log_err(cx);
}
for id in servers_to_remove {
self.remove_server(&id, cx).detach_and_log_err(cx);
}
}
}
pub struct GlobalContextServerManager(Model<ContextServerManager>);
impl Global for GlobalContextServerManager {}
impl GlobalContextServerManager {
fn register(cx: &mut AppContext) {
let model = ContextServerManager::model(cx);
cx.set_global(Self(model));
}
}
pub fn init(cx: &mut AppContext) {
ContextServerSettings::register(cx);
GlobalContextServerManager::register(cx);
CommandPaletteFilter::update_global(cx, |filter, _cx| {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
});
cx.observe_global::<SettingsStore>(|cx| {
let manager = ContextServerManager::global(cx);
cx.update_model(&manager, |manager, cx| {
let settings = ContextServerSettings::get_global(cx);
let current_servers = manager
.servers()
.into_iter()
.map(|server| (server.id.clone(), server.config.clone()))
.collect::<HashMap<_, _>>();
let new_servers = settings
.servers
.iter()
.map(|config| (config.id.clone(), config.clone()))
.collect::<HashMap<_, _>>();
let servers_to_add = new_servers
.values()
.filter(|config| !current_servers.contains_key(&config.id))
.cloned()
.collect::<Vec<_>>();
let servers_to_remove = current_servers
.keys()
.filter(|id| !new_servers.contains_key(*id))
.cloned()
.collect::<Vec<_>>();
log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
manager.add_server(config, cx).detach_and_log_err(cx);
}
for id in servers_to_remove {
manager.remove_server(&id, cx).detach_and_log_err(cx);
}
let has_any_context_servers = !manager.servers().is_empty();
CommandPaletteFilter::update_global(cx, |filter, _cx| {
if has_any_context_servers {
filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
} else {
filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
}
});
})
})
.detach();
}

View file

@ -1,69 +0,0 @@
use std::sync::Arc;
use collections::HashMap;
use gpui::{AppContext, Global, ReadGlobal};
use parking_lot::RwLock;
struct GlobalContextServerRegistry(Arc<ContextServerRegistry>);
impl Global for GlobalContextServerRegistry {}
pub struct ContextServerRegistry {
command_registry: RwLock<HashMap<String, Vec<Arc<str>>>>,
tool_registry: RwLock<HashMap<String, Vec<Arc<str>>>>,
}
impl ContextServerRegistry {
pub fn global(cx: &AppContext) -> Arc<Self> {
GlobalContextServerRegistry::global(cx).0.clone()
}
pub fn register(cx: &mut AppContext) {
cx.set_global(GlobalContextServerRegistry(Arc::new(
ContextServerRegistry {
command_registry: RwLock::new(HashMap::default()),
tool_registry: RwLock::new(HashMap::default()),
},
)))
}
pub fn register_command(&self, server_id: String, command_name: &str) {
let mut registry = self.command_registry.write();
registry
.entry(server_id)
.or_default()
.push(command_name.into());
}
pub fn unregister_command(&self, server_id: &str, command_name: &str) {
let mut registry = self.command_registry.write();
if let Some(commands) = registry.get_mut(server_id) {
commands.retain(|name| name.as_ref() != command_name);
}
}
pub fn get_commands(&self, server_id: &str) -> Option<Vec<Arc<str>>> {
let registry = self.command_registry.read();
registry.get(server_id).cloned()
}
pub fn register_tool(&self, server_id: String, tool_name: &str) {
let mut registry = self.tool_registry.write();
registry
.entry(server_id)
.or_default()
.push(tool_name.into());
}
pub fn unregister_tool(&self, server_id: &str, tool_name: &str) {
let mut registry = self.tool_registry.write();
if let Some(tools) = registry.get_mut(server_id) {
tools.retain(|name| name.as_ref() != tool_name);
}
}
pub fn get_tools(&self, server_id: &str) -> Option<Vec<Arc<str>>> {
let registry = self.tool_registry.read();
registry.get(server_id).cloned()
}
}