From 7e7f25df6cfed456c44a09eb1b320738394461b5 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 7 Nov 2024 18:23:25 -0500 Subject: [PATCH] Scope slash commands, context servers, and tools to individual Assistant Panel instances (#20372) This PR reworks how the Assistant Panel references slash commands, context servers, and tools. Previously we were always reading them from the global registries, but now we store individual collections on each Assistant Panel instance so that there can be different ones registered for each project. Release Notes: - N/A --------- Co-authored-by: Max Co-authored-by: Antonio Co-authored-by: Joseph Co-authored-by: Max Brunsfeld --- Cargo.lock | 1 - crates/assistant/src/assistant.rs | 121 +----------- crates/assistant/src/assistant_panel.rs | 94 +++++---- crates/assistant/src/context.rs | 46 +++-- crates/assistant/src/context/context_tests.rs | 94 +++++++-- crates/assistant/src/context_store.rs | 181 +++++++++++++++++- .../slash_command/context_server_command.rs | 29 +-- crates/assistant/src/slash_command_picker.rs | 15 +- .../src/slash_command_working_set.rs | 66 +++++++ crates/assistant/src/tool_working_set.rs | 75 ++++++++ .../src/tools/context_server_tool.rs | 14 +- crates/collab/src/tests/integration_tests.rs | 22 ++- crates/context_servers/Cargo.toml | 1 - crates/context_servers/src/context_servers.rs | 35 +--- crates/context_servers/src/manager.rs | 126 ++++-------- crates/context_servers/src/registry.rs | 69 ------- 16 files changed, 592 insertions(+), 397 deletions(-) create mode 100644 crates/assistant/src/slash_command_working_set.rs create mode 100644 crates/assistant/src/tool_working_set.rs delete mode 100644 crates/context_servers/src/registry.rs diff --git a/Cargo.lock b/Cargo.lock index 56ccdae0a3..604c5da77a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2829,7 +2829,6 @@ dependencies = [ "smol", "url", "util", - "workspace", ] [[package]] diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index ef059e2090..236311753f 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -12,10 +12,14 @@ mod prompts; mod slash_command; pub(crate) mod slash_command_picker; pub mod slash_command_settings; +mod slash_command_working_set; mod streaming_diff; mod terminal_inline_assistant; +mod tool_working_set; mod tools; +pub use crate::slash_command_working_set::{SlashCommandId, SlashCommandWorkingSet}; +pub use crate::tool_working_set::{ToolId, ToolWorkingSet}; pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; use assistant_settings::AssistantSettings; use assistant_slash_command::SlashCommandRegistry; @@ -23,12 +27,11 @@ use assistant_tool::ToolRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; pub use context::*; -use context_servers::ContextServerRegistry; pub use context_store::*; use feature_flags::FeatureFlagAppExt; use fs::Fs; +use gpui::impl_actions; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; -use gpui::{impl_actions, Context as _}; use indexed_docs::IndexedDocsRegistry; pub(crate) use inline_assistant::*; use language_model::{ @@ -43,10 +46,9 @@ use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsStore}; use slash_command::search_command::SearchSlashCommandFeatureFlag; use slash_command::{ - auto_command, cargo_workspace_command, context_server_command, default_command, delta_command, - diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command, - prompt_command, search_command, selection_command, symbols_command, tab_command, - terminal_command, + auto_command, cargo_workspace_command, default_command, delta_command, diagnostics_command, + docs_command, fetch_command, file_command, now_command, project_command, prompt_command, + search_command, selection_command, symbols_command, tab_command, terminal_command, }; use std::path::PathBuf; use std::sync::Arc; @@ -281,116 +283,9 @@ pub fn init( }) .detach(); - register_context_server_handlers(cx); - prompt_builder } -fn register_context_server_handlers(cx: &mut AppContext) { - cx.subscribe( - &context_servers::manager::ContextServerManager::global(cx), - |manager, event, cx| match event { - context_servers::manager::Event::ServerStarted { server_id } => { - cx.update_model( - &manager, - |manager: &mut context_servers::manager::ContextServerManager, cx| { - let slash_command_registry = SlashCommandRegistry::global(cx); - let context_server_registry = ContextServerRegistry::global(cx); - if let Some(server) = manager.get_server(server_id) { - cx.spawn(|_, _| async move { - let Some(protocol) = server.client.read().clone() else { - return; - }; - - if protocol.capable(context_servers::protocol::ServerCapability::Prompts) { - if let Some(prompts) = protocol.list_prompts().await.log_err() { - for prompt in prompts - .into_iter() - .filter(context_server_command::acceptable_prompt) - { - log::info!( - "registering context server command: {:?}", - prompt.name - ); - context_server_registry.register_command( - server.id.clone(), - prompt.name.as_str(), - ); - slash_command_registry.register_command( - context_server_command::ContextServerSlashCommand::new( - &server, prompt, - ), - true, - ); - } - } - } - }) - .detach(); - } - }, - ); - - cx.update_model( - &manager, - |manager: &mut context_servers::manager::ContextServerManager, cx| { - let tool_registry = ToolRegistry::global(cx); - let context_server_registry = ContextServerRegistry::global(cx); - if let Some(server) = manager.get_server(server_id) { - cx.spawn(|_, _| async move { - let Some(protocol) = server.client.read().clone() else { - return; - }; - - if protocol.capable(context_servers::protocol::ServerCapability::Tools) { - if let Some(tools) = protocol.list_tools().await.log_err() { - for tool in tools.tools { - log::info!( - "registering context server tool: {:?}", - tool.name - ); - context_server_registry.register_tool( - server.id.clone(), - tool.name.as_str(), - ); - tool_registry.register_tool( - tools::context_server_tool::ContextServerTool::new( - server.id.clone(), - tool - ), - ); - } - } - } - }) - .detach(); - } - }, - ); - } - context_servers::manager::Event::ServerStopped { server_id } => { - let slash_command_registry = SlashCommandRegistry::global(cx); - let context_server_registry = ContextServerRegistry::global(cx); - if let Some(commands) = context_server_registry.get_commands(server_id) { - for command_name in commands { - slash_command_registry.unregister_command_by_name(&command_name); - context_server_registry.unregister_command(&server_id, &command_name); - } - } - - if let Some(tools) = context_server_registry.get_tools(server_id) { - let tool_registry = ToolRegistry::global(cx); - for tool_name in tools { - tool_registry.unregister_tool_by_name(&tool_name); - context_server_registry.unregister_tool(&server_id, &tool_name); - } - } - } - }, - ) - .detach(); -} - fn init_language_model_settings(cx: &mut AppContext) { update_active_language_model_from_settings(cx); diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 7562901106..52139fce8c 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,4 +1,6 @@ use crate::slash_command::file_command::codeblock_fence_for_path; +use crate::slash_command_working_set::SlashCommandWorkingSet; +use crate::ToolWorkingSet; use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, humanize_token_count, @@ -7,21 +9,20 @@ use crate::{ slash_command::{ default_command::DefaultSlashCommand, docs_command::{DocsSlashCommand, DocsSlashCommandArgs}, - file_command, SlashCommandCompletionProvider, SlashCommandRegistry, + file_command, SlashCommandCompletionProvider, }, slash_command_picker, terminal_inline_assistant::TerminalInlineAssistant, Assist, AssistantPatch, AssistantPatchStatus, CacheStatus, ConfirmCommand, Content, Context, ContextEvent, ContextId, ContextStore, ContextStoreEvent, CopyCode, CycleMessageRole, DeployHistory, DeployPromptLibrary, Edit, InlineAssistant, InsertDraggedFiles, - InsertIntoEditor, InvokedSlashCommandStatus, Message, MessageId, MessageMetadata, - MessageStatus, ModelPickerDelegate, ModelSelector, NewContext, ParsedSlashCommand, - PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, RequestType, - SavedContextMetadata, SlashCommandId, Split, ToggleFocus, ToggleModelSelector, + InsertIntoEditor, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId, + MessageMetadata, MessageStatus, ModelPickerDelegate, ModelSelector, NewContext, + ParsedSlashCommand, PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, + RequestType, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, }; use anyhow::Result; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; -use assistant_tool::ToolRegistry; use client::{proto, zed_urls, Client, Status}; use collections::{hash_map, BTreeSet, HashMap, HashSet}; use editor::{ @@ -112,7 +113,8 @@ pub fn init(cx: &mut AppContext) { .register_action(ContextEditor::copy_code) .register_action(ContextEditor::insert_dragged_files) .register_action(AssistantPanel::show_configuration) - .register_action(AssistantPanel::create_new_context); + .register_action(AssistantPanel::create_new_context) + .register_action(AssistantPanel::restart_context_servers); }, ) .detach(); @@ -315,10 +317,12 @@ impl AssistantPanel { cx: AsyncWindowContext, ) -> Task>> { cx.spawn(|mut cx| async move { + let slash_commands = Arc::new(SlashCommandWorkingSet::default()); + let tools = Arc::new(ToolWorkingSet::default()); let context_store = workspace .update(&mut cx, |workspace, cx| { let project = workspace.project().clone(); - ContextStore::new(project, prompt_builder.clone(), cx) + ContextStore::new(project, prompt_builder.clone(), slash_commands, tools, cx) })? .await?; @@ -1294,6 +1298,24 @@ impl AssistantPanel { .active_provider() .map_or(None, |provider| Some(provider.authenticate(cx))) } + + fn restart_context_servers( + workspace: &mut Workspace, + _action: &context_servers::Restart, + cx: &mut ViewContext, + ) { + let Some(assistant_panel) = workspace.panel::(cx) else { + return; + }; + + assistant_panel.update(cx, |assistant_panel, cx| { + assistant_panel + .context_store + .update(cx, |context_store, cx| { + context_store.restart_context_servers(cx); + }); + }); + } } impl Render for AssistantPanel { @@ -1468,6 +1490,8 @@ enum AssistError { pub struct ContextEditor { context: Model, fs: Arc, + slash_commands: Arc, + tools: Arc, workspace: WeakView, project: Model, lsp_adapter_delegate: Option>, @@ -1477,7 +1501,7 @@ pub struct ContextEditor { scroll_position: Option, remote_id: Option, pending_slash_command_creases: HashMap, CreaseId>, - invoked_slash_command_creases: HashMap, + invoked_slash_command_creases: HashMap, pending_tool_use_creases: HashMap, CreaseId>, _subscriptions: Vec, patches: HashMap, PatchViewState>, @@ -1536,8 +1560,12 @@ impl ContextEditor { let sections = context.read(cx).slash_command_output_sections().to_vec(); let patch_ranges = context.read(cx).patch_ranges().collect::>(); + let slash_commands = context.read(cx).slash_commands.clone(); + let tools = context.read(cx).tools.clone(); let mut this = Self { context, + slash_commands, + tools, editor, lsp_adapter_delegate, blocks: Default::default(), @@ -1688,7 +1716,7 @@ impl ContextEditor { } pub fn insert_command(&mut self, name: &str, cx: &mut ViewContext) { - if let Some(command) = SlashCommandRegistry::global(cx).command(name) { + if let Some(command) = self.slash_commands.command(name, cx) { self.editor.update(cx, |editor, cx| { editor.transact(cx, |editor, cx| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| s.try_cancel()); @@ -1770,7 +1798,7 @@ impl ContextEditor { workspace: WeakView, cx: &mut ViewContext, ) { - if let Some(command) = SlashCommandRegistry::global(cx).command(name) { + if let Some(command) = self.slash_commands.command(name, cx) { let context = self.context.read(cx); let sections = context .slash_command_output_sections() @@ -2043,8 +2071,7 @@ impl ContextEditor { .collect::>(); for tool_use in pending_tool_uses { - let tool_registry = ToolRegistry::global(cx); - if let Some(tool) = tool_registry.tool(&tool_use.name) { + if let Some(tool) = self.tools.tool(&tool_use.name, cx) { let task = tool.run(tool_use.input, self.workspace.clone(), cx); self.context.update(cx, |context, cx| { @@ -2108,7 +2135,7 @@ impl ContextEditor { fn update_invoked_slash_command( &mut self, - command_id: SlashCommandId, + command_id: InvokedSlashCommandId, cx: &mut ViewContext, ) { self.editor.update(cx, |editor, cx| { @@ -3719,6 +3746,19 @@ impl ContextEditor { }) } + fn render_inject_context_menu(&self, cx: &mut ViewContext) -> impl IntoElement { + slash_command_picker::SlashCommandSelector::new( + self.slash_commands.clone(), + cx.view().downgrade(), + Button::new("trigger", "Add Context") + .icon(IconName::Plus) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .tooltip(|cx| Tooltip::text("Type / to insert via keyboard", cx)), + ) + } + fn render_last_error(&self, cx: &mut ViewContext) -> Option { let last_error = self.last_error.as_ref()?; @@ -4133,11 +4173,7 @@ impl Render for ContextEditor { .border_t_1() .border_color(cx.theme().colors().border_variant) .bg(cx.theme().colors().editor_background) - .child( - h_flex() - .gap_1() - .child(render_inject_context_menu(cx.view().downgrade(), cx)), - ) + .child(h_flex().gap_1().child(self.render_inject_context_menu(cx))) .child( h_flex() .w_full() @@ -4419,24 +4455,6 @@ pub struct ContextEditorToolbarItem { model_selector_menu_handle: PopoverMenuHandle>, } -fn render_inject_context_menu( - active_context_editor: WeakView, - cx: &mut WindowContext<'_>, -) -> impl IntoElement { - let commands = SlashCommandRegistry::global(cx); - - slash_command_picker::SlashCommandSelector::new( - commands.clone(), - active_context_editor, - Button::new("trigger", "Add Context") - .icon(IconName::Plus) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .tooltip(|cx| Tooltip::text("Type / to insert via keyboard", cx)), - ) -} - impl ContextEditorToolbarItem { pub fn new( workspace: &Workspace, @@ -5095,7 +5113,7 @@ fn make_lsp_adapter_delegate( enum PendingSlashCommand {} fn invoked_slash_command_fold_placeholder( - command_id: SlashCommandId, + command_id: InvokedSlashCommandId, context: WeakModel, ) -> FoldPlaceholder { FoldPlaceholder { diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index dd1e397300..da058969b6 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,6 +1,8 @@ #[cfg(test)] mod context_tests; +use crate::slash_command_working_set::SlashCommandWorkingSet; +use crate::ToolWorkingSet; use crate::{ prompts::PromptBuilder, slash_command::{file_command::FileCommandMetadata, SlashCommandLine}, @@ -8,10 +10,8 @@ use crate::{ }; use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::{ - SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandRegistry, - SlashCommandResult, + SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandResult, }; -use assistant_tool::ToolRegistry; use client::{self, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::{HashMap, HashSet}; @@ -95,13 +95,13 @@ pub enum ContextOperation { version: clock::Global, }, SlashCommandStarted { - id: SlashCommandId, + id: InvokedSlashCommandId, output_range: Range, name: String, version: clock::Global, }, SlashCommandFinished { - id: SlashCommandId, + id: InvokedSlashCommandId, timestamp: clock::Lamport, error_message: Option, version: clock::Global, @@ -167,7 +167,7 @@ impl ContextOperation { }), proto::context_operation::Variant::SlashCommandStarted(message) => { Ok(Self::SlashCommandStarted { - id: SlashCommandId(language::proto::deserialize_timestamp( + id: InvokedSlashCommandId(language::proto::deserialize_timestamp( message.id.context("invalid id")?, )), output_range: language::proto::deserialize_anchor_range( @@ -198,7 +198,7 @@ impl ContextOperation { } proto::context_operation::Variant::SlashCommandCompleted(message) => { Ok(Self::SlashCommandFinished { - id: SlashCommandId(language::proto::deserialize_timestamp( + id: InvokedSlashCommandId(language::proto::deserialize_timestamp( message.id.context("invalid id")?, )), timestamp: language::proto::deserialize_timestamp( @@ -372,7 +372,7 @@ pub enum ContextEvent { updated: Vec>, }, InvokedSlashCommandChanged { - command_id: SlashCommandId, + command_id: InvokedSlashCommandId, }, ParsedSlashCommandsUpdated { removed: Vec>, @@ -513,7 +513,7 @@ struct PendingCompletion { } #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] -pub struct SlashCommandId(clock::Lamport); +pub struct InvokedSlashCommandId(clock::Lamport); #[derive(Clone, Debug)] pub struct XmlTag { @@ -543,8 +543,10 @@ pub struct Context { operations: Vec, buffer: Model, parsed_slash_commands: Vec, - invoked_slash_commands: HashMap, + invoked_slash_commands: HashMap, edits_since_last_parse: language::Subscription, + pub(crate) slash_commands: Arc, + pub(crate) tools: Arc, slash_command_output_sections: Vec>, pending_tool_uses_by_id: HashMap, PendingToolUse>, message_anchors: Vec, @@ -598,6 +600,8 @@ impl Context { project: Option>, telemetry: Option>, prompt_builder: Arc, + slash_commands: Arc, + tools: Arc, cx: &mut ModelContext, ) -> Self { Self::new( @@ -606,6 +610,8 @@ impl Context { language::Capability::ReadWrite, language_registry, prompt_builder, + slash_commands, + tools, project, telemetry, cx, @@ -619,6 +625,8 @@ impl Context { capability: language::Capability, language_registry: Arc, prompt_builder: Arc, + slash_commands: Arc, + tools: Arc, project: Option>, telemetry: Option>, cx: &mut ModelContext, @@ -663,6 +671,8 @@ impl Context { telemetry, project, language_registry, + slash_commands, + tools, patches: Vec::new(), xml_tags: Vec::new(), prompt_builder, @@ -738,6 +748,8 @@ impl Context { path: PathBuf, language_registry: Arc, prompt_builder: Arc, + slash_commands: Arc, + tools: Arc, project: Option>, telemetry: Option>, cx: &mut ModelContext, @@ -749,6 +761,8 @@ impl Context { language::Capability::ReadWrite, language_registry, prompt_builder, + slash_commands, + tools, project, telemetry, cx, @@ -1086,7 +1100,7 @@ impl Context { pub fn invoked_slash_command( &self, - command_id: &SlashCommandId, + command_id: &InvokedSlashCommandId, ) -> Option<&InvokedSlashCommand> { self.invoked_slash_commands.get(command_id) } @@ -1455,7 +1469,7 @@ impl Context { }) .map(ToOwned::to_owned) .collect::>(); - if let Some(command) = SlashCommandRegistry::global(cx).command(name) { + if let Some(command) = self.slash_commands.command(name, cx) { if !command.requires_argument() || !arguments.is_empty() { let start_ix = offset + command_line.name.start - 1; let end_ix = offset @@ -1847,7 +1861,7 @@ impl Context { cx: &mut ModelContext, ) { let version = self.version.clone(); - let command_id = SlashCommandId(self.next_timestamp()); + let command_id = InvokedSlashCommandId(self.next_timestamp()); const PENDING_OUTPUT_END_MARKER: &str = "…"; @@ -2227,9 +2241,9 @@ impl Context { let mut request = self.to_completion_request(request_type, cx); if cx.has_flag::() { - let tool_registry = ToolRegistry::global(cx); - request.tools = tool_registry - .tools() + request.tools = self + .tools + .tools(cx) .into_iter() .map(|tool| LanguageModelRequestTool { name: tool.name(), diff --git a/crates/assistant/src/context/context_tests.rs b/crates/assistant/src/context/context_tests.rs index 615bdc5d02..c68aeb25e2 100644 --- a/crates/assistant/src/context/context_tests.rs +++ b/crates/assistant/src/context/context_tests.rs @@ -1,8 +1,10 @@ use super::{AssistantEdit, MessageCacheMetadata}; +use crate::slash_command_working_set::SlashCommandWorkingSet; +use crate::ToolWorkingSet; use crate::{ assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus, - Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder, - SlashCommandId, + Context, ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId, + MessageStatus, PromptBuilder, }; use anyhow::Result; use assistant_slash_command::{ @@ -49,8 +51,17 @@ fn test_inserting_and_removing_messages(cx: &mut AppContext) { assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -182,8 +193,17 @@ fn test_message_splitting(cx: &mut AppContext) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -277,8 +297,17 @@ fn test_messages_for_offsets(cx: &mut AppContext) { assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -383,13 +412,22 @@ async fn test_slash_commands(cx: &mut TestAppContext) { let registry = Arc::new(LanguageRegistry::test(cx.executor())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); #[derive(Default)] struct ContextRanges { parsed_commands: HashSet>, - command_outputs: HashMap>, + command_outputs: HashMap>, output_sections: HashSet>, } @@ -671,6 +709,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) { Some(project), None, prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), cx, ) }); @@ -934,6 +974,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) { Default::default(), registry.clone(), prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), None, None, cx, @@ -1042,8 +1084,17 @@ async fn test_serialization(cx: &mut TestAppContext) { cx.update(assistant_panel::init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read_with(cx, |context, _| context.buffer.clone()); let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); let message_1 = context.update(cx, |context, cx| { @@ -1083,6 +1134,8 @@ async fn test_serialization(cx: &mut TestAppContext) { Default::default(), registry.clone(), prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), None, None, cx, @@ -1141,6 +1194,8 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std language::Capability::ReadWrite, registry.clone(), prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), None, None, cx, @@ -1394,8 +1449,17 @@ fn test_mark_cache_anchors(cx: &mut AppContext) { assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read(cx).buffer.clone(); // Create a test cache configuration diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index f4f03dda37..ad8ae9808f 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -1,10 +1,16 @@ +use crate::slash_command::context_server_command; use crate::{ - prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion, - SavedContext, SavedContextMetadata, + prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context, + ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata, }; +use crate::{tools, SlashCommandId, ToolId, ToolWorkingSet}; use anyhow::{anyhow, Context as _, Result}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; use clock::ReplicaId; +use collections::HashMap; +use command_palette_hooks::CommandPaletteFilter; +use context_servers::manager::{ContextServerManager, ContextServerSettings}; +use context_servers::CONTEXT_SERVERS_NAMESPACE; use fs::Fs; use futures::StreamExt; use fuzzy::StringMatchCandidate; @@ -16,6 +22,7 @@ use paths::contexts_dir; use project::Project; use regex::Regex; use rpc::AnyProtoClient; +use settings::{Settings as _, SettingsStore}; use std::{ cmp::Reverse, ffi::OsStr, @@ -43,9 +50,14 @@ pub struct RemoteContextMetadata { pub struct ContextStore { contexts: Vec, contexts_metadata: Vec, + context_server_manager: Model, + context_server_slash_command_ids: HashMap>, + context_server_tool_ids: HashMap>, host_contexts: Vec, fs: Arc, languages: Arc, + slash_commands: Arc, + tools: Arc, telemetry: Arc, _watch_updates: Task>, client: Arc, @@ -87,6 +99,8 @@ impl ContextStore { pub fn new( project: Model, prompt_builder: Arc, + slash_commands: Arc, + tools: Arc, cx: &mut AppContext, ) -> Task>> { let fs = project.read(cx).fs().clone(); @@ -97,12 +111,18 @@ impl ContextStore { let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; let this = cx.new_model(|cx: &mut ModelContext| { + let context_server_manager = cx.new_model(|_cx| ContextServerManager::new()); let mut this = Self { contexts: Vec::new(), contexts_metadata: Vec::new(), + context_server_manager, + context_server_slash_command_ids: HashMap::default(), + context_server_tool_ids: HashMap::default(), host_contexts: Vec::new(), fs, languages, + slash_commands, + tools, telemetry, _watch_updates: cx.spawn(|this, mut cx| { async move { @@ -127,15 +147,48 @@ impl ContextStore { }; this.handle_project_changed(project, cx); this.synchronize_contexts(cx); + this.register_context_server_handlers(cx); this })?; this.update(&mut cx, |this, cx| this.reload(cx))? .await .log_err(); + + this.update(&mut cx, |this, cx| { + this.watch_context_server_settings(cx); + }) + .log_err(); + Ok(this) }) } + fn watch_context_server_settings(&self, cx: &mut ModelContext) { + cx.observe_global::(move |this, cx| { + this.context_server_manager.update(cx, |manager, cx| { + let location = this.project.read(cx).worktrees(cx).next().map(|worktree| { + settings::SettingsLocation { + worktree_id: worktree.read(cx).id(), + path: Path::new(""), + } + }); + let settings = ContextServerSettings::get(location, cx); + + manager.maintain_servers(settings, cx); + + let has_any_context_servers = !manager.servers().is_empty(); + CommandPaletteFilter::update_global(cx, |filter, _cx| { + if has_any_context_servers { + filter.show_namespace(CONTEXT_SERVERS_NAMESPACE); + } else { + filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); + } + }); + }) + }) + .detach(); + } + async fn handle_advertise_contexts( this: Model, envelope: TypedEnvelope, @@ -342,6 +395,8 @@ impl ContextStore { Some(self.project.clone()), Some(self.telemetry.clone()), self.prompt_builder.clone(), + self.slash_commands.clone(), + self.tools.clone(), cx, ) }); @@ -364,6 +419,8 @@ impl ContextStore { let project = self.project.clone(); let telemetry = self.telemetry.clone(); let prompt_builder = self.prompt_builder.clone(); + let slash_commands = self.slash_commands.clone(); + let tools = self.tools.clone(); let request = self.client.request(proto::CreateContext { project_id }); cx.spawn(|this, mut cx| async move { let response = request.await?; @@ -376,6 +433,8 @@ impl ContextStore { capability, language_registry, prompt_builder, + slash_commands, + tools, Some(project), Some(telemetry), cx, @@ -425,6 +484,8 @@ impl ContextStore { } }); let prompt_builder = self.prompt_builder.clone(); + let slash_commands = self.slash_commands.clone(); + let tools = self.tools.clone(); cx.spawn(|this, mut cx| async move { let saved_context = load.await?; @@ -434,6 +495,8 @@ impl ContextStore { path.clone(), languages, prompt_builder, + slash_commands, + tools, Some(project), Some(telemetry), cx, @@ -500,6 +563,8 @@ impl ContextStore { context_id: context_id.to_proto(), }); let prompt_builder = self.prompt_builder.clone(); + let slash_commands = self.slash_commands.clone(); + let tools = self.tools.clone(); cx.spawn(|this, mut cx| async move { let response = request.await?; let context_proto = response.context.context("invalid context")?; @@ -510,6 +575,8 @@ impl ContextStore { capability, language_registry, prompt_builder, + slash_commands, + tools, Some(project), Some(telemetry), cx, @@ -745,4 +812,114 @@ impl ContextStore { }) }) } + + pub fn restart_context_servers(&mut self, cx: &mut ModelContext) { + cx.update_model( + &self.context_server_manager, + |context_server_manager, cx| { + for server in context_server_manager.servers() { + context_server_manager + .restart_server(&server.id, cx) + .detach_and_log_err(cx); + } + }, + ); + } + + fn register_context_server_handlers(&self, cx: &mut ModelContext) { + cx.subscribe( + &self.context_server_manager.clone(), + Self::handle_context_server_event, + ) + .detach(); + } + + fn handle_context_server_event( + &mut self, + context_server_manager: Model, + event: &context_servers::manager::Event, + cx: &mut ModelContext, + ) { + let slash_command_working_set = self.slash_commands.clone(); + let tool_working_set = self.tools.clone(); + match event { + context_servers::manager::Event::ServerStarted { server_id } => { + if let Some(server) = context_server_manager.read(cx).get_server(server_id) { + let context_server_manager = context_server_manager.clone(); + cx.spawn({ + let server = server.clone(); + let server_id = server_id.clone(); + |this, mut cx| async move { + let Some(protocol) = server.client.read().clone() else { + return; + }; + + if protocol.capable(context_servers::protocol::ServerCapability::Prompts) { + if let Some(prompts) = protocol.list_prompts().await.log_err() { + let slash_command_ids = prompts + .into_iter() + .filter(context_server_command::acceptable_prompt) + .map(|prompt| { + log::info!( + "registering context server command: {:?}", + prompt.name + ); + slash_command_working_set.insert(Arc::new( + context_server_command::ContextServerSlashCommand::new( + context_server_manager.clone(), + &server, + prompt, + ), + )) + }) + .collect::>(); + + this.update(&mut cx, |this, _cx| { + this.context_server_slash_command_ids + .insert(server_id.clone(), slash_command_ids); + }) + .log_err(); + } + } + + if protocol.capable(context_servers::protocol::ServerCapability::Tools) { + if let Some(tools) = protocol.list_tools().await.log_err() { + let tool_ids = tools.tools.into_iter().map(|tool| { + log::info!("registering context server tool: {:?}", tool.name); + tool_working_set.insert( + Arc::new(tools::context_server_tool::ContextServerTool::new( + context_server_manager.clone(), + server.id.clone(), + tool, + )), + ) + + }).collect::>(); + + + this.update(&mut cx, |this, _cx| { + this.context_server_tool_ids + .insert(server_id, tool_ids); + }) + .log_err(); + } + } + } + }) + .detach(); + } + } + context_servers::manager::Event::ServerStopped { server_id } => { + if let Some(slash_command_ids) = + self.context_server_slash_command_ids.remove(server_id) + { + slash_command_working_set.remove(&slash_command_ids); + } + + if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { + tool_working_set.remove(&tool_ids); + } + } + } + } } diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 5b22e76bf8..997f289c9b 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -8,7 +8,7 @@ use context_servers::{ manager::{ContextServer, ContextServerManager}, types::Prompt, }; -use gpui::{AppContext, Task, WeakView, WindowContext}; +use gpui::{AppContext, Model, Task, WeakView, WindowContext}; use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate}; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -19,15 +19,21 @@ use workspace::Workspace; use crate::slash_command::create_label_for_command; pub struct ContextServerSlashCommand { + server_manager: Model, server_id: String, prompt: Prompt, } impl ContextServerSlashCommand { - pub fn new(server: &Arc, prompt: Prompt) -> Self { + pub fn new( + server_manager: Model, + server: &Arc, + prompt: Prompt, + ) -> Self { Self { server_id: server.id.clone(), prompt, + server_manager, } } } @@ -74,18 +80,14 @@ impl SlashCommand for ContextServerSlashCommand { _workspace: Option>, cx: &mut WindowContext, ) -> Task>> { + let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else { + return Task::ready(Err(anyhow!("Failed to complete argument"))); + }; + let server_id = self.server_id.clone(); let prompt_name = self.prompt.name.clone(); - let manager = ContextServerManager::global(cx); - let manager = manager.read(cx); - let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) { - Ok(tp) => tp, - Err(e) => { - return Task::ready(Err(e)); - } - }; - if let Some(server) = manager.get_server(&server_id) { + if let Some(server) = self.server_manager.read(cx).get_server(&server_id) { cx.foreground_executor().spawn(async move { let Some(protocol) = server.client.read().clone() else { return Err(anyhow!("Context server not initialized")); @@ -100,7 +102,7 @@ impl SlashCommand for ContextServerSlashCommand { }, ), arg_name, - arg_val, + arg_value, ) .await?; @@ -138,8 +140,7 @@ impl SlashCommand for ContextServerSlashCommand { Err(e) => return Task::ready(Err(e)), }; - let manager = ContextServerManager::global(cx); - let manager = manager.read(cx); + let manager = self.server_manager.read(cx); if let Some(server) = manager.get_server(&server_id) { cx.foreground_executor().spawn(async move { let Some(protocol) = server.client.read().clone() else { diff --git a/crates/assistant/src/slash_command_picker.rs b/crates/assistant/src/slash_command_picker.rs index b376dd8881..8e797d6184 100644 --- a/crates/assistant/src/slash_command_picker.rs +++ b/crates/assistant/src/slash_command_picker.rs @@ -1,16 +1,15 @@ use std::sync::Arc; -use assistant_slash_command::SlashCommandRegistry; - use gpui::{AnyElement, DismissEvent, SharedString, Task, WeakView}; use picker::{Picker, PickerDelegate, PickerEditorPosition}; use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverTrigger}; use crate::assistant_panel::ContextEditor; +use crate::SlashCommandWorkingSet; #[derive(IntoElement)] pub(super) struct SlashCommandSelector { - registry: Arc, + working_set: Arc, active_context_editor: WeakView, trigger: T, } @@ -51,12 +50,12 @@ pub(crate) struct SlashCommandDelegate { impl SlashCommandSelector { pub(crate) fn new( - registry: Arc, + working_set: Arc, active_context_editor: WeakView, trigger: T, ) -> Self { SlashCommandSelector { - registry, + working_set, active_context_editor, trigger, } @@ -231,11 +230,11 @@ impl PickerDelegate for SlashCommandDelegate { impl RenderOnce for SlashCommandSelector { fn render(self, cx: &mut WindowContext) -> impl IntoElement { let all_models = self - .registry - .featured_command_names() + .working_set + .featured_command_names(cx) .into_iter() .filter_map(|command_name| { - let command = self.registry.command(&command_name)?; + let command = self.working_set.command(&command_name, cx)?; let menu_text = SharedString::from(Arc::from(command.menu_text())); let label = command.label(cx); let args = label.filter_range.end.ne(&label.text.len()).then(|| { diff --git a/crates/assistant/src/slash_command_working_set.rs b/crates/assistant/src/slash_command_working_set.rs new file mode 100644 index 0000000000..93979557c1 --- /dev/null +++ b/crates/assistant/src/slash_command_working_set.rs @@ -0,0 +1,66 @@ +use assistant_slash_command::{SlashCommand, SlashCommandRegistry}; +use collections::HashMap; +use gpui::AppContext; +use parking_lot::Mutex; +use std::sync::Arc; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] +pub struct SlashCommandId(usize); + +/// A working set of slash commands for use in one instance of the Assistant Panel. +#[derive(Default)] +pub struct SlashCommandWorkingSet { + state: Mutex, +} + +#[derive(Default)] +struct WorkingSetState { + context_server_commands_by_id: HashMap>, + context_server_commands_by_name: HashMap>, + next_command_id: SlashCommandId, +} + +impl SlashCommandWorkingSet { + pub fn command(&self, name: &str, cx: &AppContext) -> Option> { + self.state + .lock() + .context_server_commands_by_name + .get(name) + .cloned() + .or_else(|| SlashCommandRegistry::global(cx).command(name)) + } + + pub fn featured_command_names(&self, cx: &AppContext) -> Vec> { + SlashCommandRegistry::global(cx).featured_command_names() + } + + pub fn insert(&self, command: Arc) -> SlashCommandId { + let mut state = self.state.lock(); + let command_id = state.next_command_id; + state.next_command_id.0 += 1; + state + .context_server_commands_by_id + .insert(command_id, command.clone()); + state.slash_commands_changed(); + command_id + } + + pub fn remove(&self, command_ids_to_remove: &[SlashCommandId]) { + let mut state = self.state.lock(); + state + .context_server_commands_by_id + .retain(|id, _| !command_ids_to_remove.contains(id)); + state.slash_commands_changed(); + } +} + +impl WorkingSetState { + fn slash_commands_changed(&mut self) { + self.context_server_commands_by_name.clear(); + self.context_server_commands_by_name.extend( + self.context_server_commands_by_id + .values() + .map(|command| (command.name(), command.clone())), + ); + } +} diff --git a/crates/assistant/src/tool_working_set.rs b/crates/assistant/src/tool_working_set.rs new file mode 100644 index 0000000000..aa2bb7a530 --- /dev/null +++ b/crates/assistant/src/tool_working_set.rs @@ -0,0 +1,75 @@ +use assistant_tool::{Tool, ToolRegistry}; +use collections::HashMap; +use gpui::AppContext; +use parking_lot::Mutex; +use std::sync::Arc; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] +pub struct ToolId(usize); + +/// A working set of tools for use in one instance of the Assistant Panel. +#[derive(Default)] +pub struct ToolWorkingSet { + state: Mutex, +} + +#[derive(Default)] +struct WorkingSetState { + context_server_tools_by_id: HashMap>, + context_server_tools_by_name: HashMap>, + next_tool_id: ToolId, +} + +impl ToolWorkingSet { + pub fn tool(&self, name: &str, cx: &AppContext) -> Option> { + self.state + .lock() + .context_server_tools_by_name + .get(name) + .cloned() + .or_else(|| ToolRegistry::global(cx).tool(name)) + } + + pub fn tools(&self, cx: &AppContext) -> Vec> { + let mut tools = ToolRegistry::global(cx).tools(); + tools.extend( + self.state + .lock() + .context_server_tools_by_id + .values() + .cloned(), + ); + + tools + } + + pub fn insert(&self, command: Arc) -> ToolId { + let mut state = self.state.lock(); + let command_id = state.next_tool_id; + state.next_tool_id.0 += 1; + state + .context_server_tools_by_id + .insert(command_id, command.clone()); + state.tools_changed(); + command_id + } + + pub fn remove(&self, command_ids_to_remove: &[ToolId]) { + let mut state = self.state.lock(); + state + .context_server_tools_by_id + .retain(|id, _| !command_ids_to_remove.contains(id)); + state.tools_changed(); + } +} + +impl WorkingSetState { + fn tools_changed(&mut self) { + self.context_server_tools_by_name.clear(); + self.context_server_tools_by_name.extend( + self.context_server_tools_by_id + .values() + .map(|command| (command.name(), command.clone())), + ); + } +} diff --git a/crates/assistant/src/tools/context_server_tool.rs b/crates/assistant/src/tools/context_server_tool.rs index 93edb32b75..72bb87191f 100644 --- a/crates/assistant/src/tools/context_server_tool.rs +++ b/crates/assistant/src/tools/context_server_tool.rs @@ -2,16 +2,22 @@ use anyhow::{anyhow, bail}; use assistant_tool::Tool; use context_servers::manager::ContextServerManager; use context_servers::types; -use gpui::Task; +use gpui::{Model, Task}; pub struct ContextServerTool { + server_manager: Model, server_id: String, tool: types::Tool, } impl ContextServerTool { - pub fn new(server_id: impl Into, tool: types::Tool) -> Self { + pub fn new( + server_manager: Model, + server_id: impl Into, + tool: types::Tool, + ) -> Self { Self { + server_manager, server_id: server_id.into(), tool, } @@ -45,9 +51,7 @@ impl Tool for ContextServerTool { _workspace: gpui::WeakView, cx: &mut ui::WindowContext, ) -> gpui::Task> { - let manager = ContextServerManager::global(cx); - let manager = manager.read(cx); - if let Some(server) = manager.get_server(&self.server_id) { + if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) { cx.foreground_executor().spawn({ let tool_name = self.tool.name.clone(); async move { diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index f4c401eb78..8c819d6da7 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -6,7 +6,7 @@ use crate::{ }, }; use anyhow::{anyhow, Result}; -use assistant::{ContextStore, PromptBuilder}; +use assistant::{ContextStore, PromptBuilder, SlashCommandWorkingSet, ToolWorkingSet}; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{User, RECEIVE_TIMEOUT}; use collections::{HashMap, HashSet}; @@ -6489,11 +6489,27 @@ async fn test_context_collaboration_with_reconnect( let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let context_store_a = cx_a - .update(|cx| ContextStore::new(project_a.clone(), prompt_builder.clone(), cx)) + .update(|cx| { + ContextStore::new( + project_a.clone(), + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }) .await .unwrap(); let context_store_b = cx_b - .update(|cx| ContextStore::new(project_b.clone(), prompt_builder.clone(), cx)) + .update(|cx| { + ContextStore::new( + project_b.clone(), + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }) .await .unwrap(); diff --git a/crates/context_servers/Cargo.toml b/crates/context_servers/Cargo.toml index 9c0336f121..73a393375c 100644 --- a/crates/context_servers/Cargo.toml +++ b/crates/context_servers/Cargo.toml @@ -27,4 +27,3 @@ settings.workspace = true smol.workspace = true url = { workspace = true, features = ["serde"] } util.workspace = true -workspace.workspace = true diff --git a/crates/context_servers/src/context_servers.rs b/crates/context_servers/src/context_servers.rs index 55634bb77c..c20b1ebdb1 100644 --- a/crates/context_servers/src/context_servers.rs +++ b/crates/context_servers/src/context_servers.rs @@ -1,40 +1,23 @@ -use gpui::{actions, AppContext, Context, ViewContext}; -use manager::ContextServerManager; -use workspace::Workspace; +use command_palette_hooks::CommandPaletteFilter; +use gpui::{actions, AppContext}; +use settings::Settings; + +use crate::manager::ContextServerSettings; pub mod client; pub mod manager; pub mod protocol; -mod registry; pub mod types; -pub use registry::*; - actions!(context_servers, [Restart]); /// The namespace for the context servers actions. -const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers"; +pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers"; pub fn init(cx: &mut AppContext) { - log::info!("initializing context server client"); - manager::init(cx); - ContextServerRegistry::register(cx); + ContextServerSettings::register(cx); - cx.observe_new_views( - |workspace: &mut Workspace, _cx: &mut ViewContext| { - workspace.register_action(restart_servers); - }, - ) - .detach(); -} - -fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewContext) { - let model = ContextServerManager::global(cx); - cx.update_model(&model, |manager, cx| { - for server in manager.servers() { - manager - .restart_server(&server.id, cx) - .detach_and_log_err(cx); - } + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); }); } diff --git a/crates/context_servers/src/manager.rs b/crates/context_servers/src/manager.rs index 3c21fd53fb..34583e559a 100644 --- a/crates/context_servers/src/manager.rs +++ b/crates/context_servers/src/manager.rs @@ -14,18 +14,17 @@ //! The module also includes initialization logic to set up the context server system //! and react to changes in settings. +use std::path::Path; +use std::sync::Arc; + use collections::{HashMap, HashSet}; -use command_palette_hooks::CommandPaletteFilter; -use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task}; +use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task}; use log; use parking_lot::RwLock; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources, SettingsStore}; -use std::path::Path; -use std::sync::Arc; +use settings::{Settings, SettingsSources}; -use crate::CONTEXT_SERVERS_NAMESPACE; use crate::{ client::{self, Client}, types, @@ -124,7 +123,6 @@ pub enum Event { ServerStopped { server_id: String }, } -impl Global for ContextServerManager {} impl EventEmitter for ContextServerManager {} impl Default for ContextServerManager { @@ -140,14 +138,11 @@ impl ContextServerManager { pending_servers: HashSet::default(), } } - pub fn global(cx: &AppContext) -> Model { - cx.global::().0.clone() - } pub fn add_server( &mut self, config: ServerConfig, - cx: &mut ModelContext, + cx: &ModelContext, ) -> Task> { let server_id = config.id.clone(); @@ -179,11 +174,7 @@ impl ContextServerManager { self.servers.get(id).cloned() } - pub fn remove_server( - &mut self, - id: &str, - cx: &mut ModelContext, - ) -> Task> { + pub fn remove_server(&mut self, id: &str, cx: &ModelContext) -> Task> { let id = id.to_string(); cx.spawn(|this, mut cx| async move { if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { @@ -229,75 +220,38 @@ impl ContextServerManager { self.servers.values().cloned().collect() } - pub fn model(cx: &mut AppContext) -> Model { - cx.new_model(|_cx| ContextServerManager::new()) + pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext) { + let current_servers = self + .servers() + .into_iter() + .map(|server| (server.id.clone(), server.config.clone())) + .collect::>(); + + let new_servers = settings + .servers + .iter() + .map(|config| (config.id.clone(), config.clone())) + .collect::>(); + + let servers_to_add = new_servers + .values() + .filter(|config| !current_servers.contains_key(&config.id)) + .cloned() + .collect::>(); + + let servers_to_remove = current_servers + .keys() + .filter(|id| !new_servers.contains_key(*id)) + .cloned() + .collect::>(); + + log::trace!("servers_to_add={:?}", servers_to_add); + for config in servers_to_add { + self.add_server(config, cx).detach_and_log_err(cx); + } + + for id in servers_to_remove { + self.remove_server(&id, cx).detach_and_log_err(cx); + } } } - -pub struct GlobalContextServerManager(Model); -impl Global for GlobalContextServerManager {} - -impl GlobalContextServerManager { - fn register(cx: &mut AppContext) { - let model = ContextServerManager::model(cx); - cx.set_global(Self(model)); - } -} - -pub fn init(cx: &mut AppContext) { - ContextServerSettings::register(cx); - GlobalContextServerManager::register(cx); - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); - }); - - cx.observe_global::(|cx| { - let manager = ContextServerManager::global(cx); - cx.update_model(&manager, |manager, cx| { - let settings = ContextServerSettings::get_global(cx); - let current_servers = manager - .servers() - .into_iter() - .map(|server| (server.id.clone(), server.config.clone())) - .collect::>(); - - let new_servers = settings - .servers - .iter() - .map(|config| (config.id.clone(), config.clone())) - .collect::>(); - - let servers_to_add = new_servers - .values() - .filter(|config| !current_servers.contains_key(&config.id)) - .cloned() - .collect::>(); - - let servers_to_remove = current_servers - .keys() - .filter(|id| !new_servers.contains_key(*id)) - .cloned() - .collect::>(); - - log::trace!("servers_to_add={:?}", servers_to_add); - for config in servers_to_add { - manager.add_server(config, cx).detach_and_log_err(cx); - } - - for id in servers_to_remove { - manager.remove_server(&id, cx).detach_and_log_err(cx); - } - - let has_any_context_servers = !manager.servers().is_empty(); - CommandPaletteFilter::update_global(cx, |filter, _cx| { - if has_any_context_servers { - filter.show_namespace(CONTEXT_SERVERS_NAMESPACE); - } else { - filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); - } - }); - }) - }) - .detach(); -} diff --git a/crates/context_servers/src/registry.rs b/crates/context_servers/src/registry.rs deleted file mode 100644 index 5490187034..0000000000 --- a/crates/context_servers/src/registry.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::sync::Arc; - -use collections::HashMap; -use gpui::{AppContext, Global, ReadGlobal}; -use parking_lot::RwLock; - -struct GlobalContextServerRegistry(Arc); - -impl Global for GlobalContextServerRegistry {} - -pub struct ContextServerRegistry { - command_registry: RwLock>>>, - tool_registry: RwLock>>>, -} - -impl ContextServerRegistry { - pub fn global(cx: &AppContext) -> Arc { - GlobalContextServerRegistry::global(cx).0.clone() - } - - pub fn register(cx: &mut AppContext) { - cx.set_global(GlobalContextServerRegistry(Arc::new( - ContextServerRegistry { - command_registry: RwLock::new(HashMap::default()), - tool_registry: RwLock::new(HashMap::default()), - }, - ))) - } - - pub fn register_command(&self, server_id: String, command_name: &str) { - let mut registry = self.command_registry.write(); - registry - .entry(server_id) - .or_default() - .push(command_name.into()); - } - - pub fn unregister_command(&self, server_id: &str, command_name: &str) { - let mut registry = self.command_registry.write(); - if let Some(commands) = registry.get_mut(server_id) { - commands.retain(|name| name.as_ref() != command_name); - } - } - - pub fn get_commands(&self, server_id: &str) -> Option>> { - let registry = self.command_registry.read(); - registry.get(server_id).cloned() - } - - pub fn register_tool(&self, server_id: String, tool_name: &str) { - let mut registry = self.tool_registry.write(); - registry - .entry(server_id) - .or_default() - .push(tool_name.into()); - } - - pub fn unregister_tool(&self, server_id: &str, tool_name: &str) { - let mut registry = self.tool_registry.write(); - if let Some(tools) = registry.get_mut(server_id) { - tools.retain(|name| name.as_ref() != tool_name); - } - } - - pub fn get_tools(&self, server_id: &str) -> Option>> { - let registry = self.tool_registry.read(); - registry.get(server_id).cloned() - } -}