diff --git a/Cargo.lock b/Cargo.lock index 56ccdae0a3..604c5da77a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2829,7 +2829,6 @@ dependencies = [ "smol", "url", "util", - "workspace", ] [[package]] diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index ef059e2090..236311753f 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -12,10 +12,14 @@ mod prompts; mod slash_command; pub(crate) mod slash_command_picker; pub mod slash_command_settings; +mod slash_command_working_set; mod streaming_diff; mod terminal_inline_assistant; +mod tool_working_set; mod tools; +pub use crate::slash_command_working_set::{SlashCommandId, SlashCommandWorkingSet}; +pub use crate::tool_working_set::{ToolId, ToolWorkingSet}; pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; use assistant_settings::AssistantSettings; use assistant_slash_command::SlashCommandRegistry; @@ -23,12 +27,11 @@ use assistant_tool::ToolRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; pub use context::*; -use context_servers::ContextServerRegistry; pub use context_store::*; use feature_flags::FeatureFlagAppExt; use fs::Fs; +use gpui::impl_actions; use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal}; -use gpui::{impl_actions, Context as _}; use indexed_docs::IndexedDocsRegistry; pub(crate) use inline_assistant::*; use language_model::{ @@ -43,10 +46,9 @@ use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsStore}; use slash_command::search_command::SearchSlashCommandFeatureFlag; use slash_command::{ - auto_command, cargo_workspace_command, context_server_command, default_command, delta_command, - diagnostics_command, docs_command, fetch_command, file_command, now_command, project_command, - prompt_command, search_command, selection_command, symbols_command, tab_command, - terminal_command, + auto_command, cargo_workspace_command, default_command, delta_command, diagnostics_command, + docs_command, fetch_command, file_command, now_command, project_command, prompt_command, + search_command, selection_command, symbols_command, tab_command, terminal_command, }; use std::path::PathBuf; use std::sync::Arc; @@ -281,116 +283,9 @@ pub fn init( }) .detach(); - register_context_server_handlers(cx); - prompt_builder } -fn register_context_server_handlers(cx: &mut AppContext) { - cx.subscribe( - &context_servers::manager::ContextServerManager::global(cx), - |manager, event, cx| match event { - context_servers::manager::Event::ServerStarted { server_id } => { - cx.update_model( - &manager, - |manager: &mut context_servers::manager::ContextServerManager, cx| { - let slash_command_registry = SlashCommandRegistry::global(cx); - let context_server_registry = ContextServerRegistry::global(cx); - if let Some(server) = manager.get_server(server_id) { - cx.spawn(|_, _| async move { - let Some(protocol) = server.client.read().clone() else { - return; - }; - - if protocol.capable(context_servers::protocol::ServerCapability::Prompts) { - if let Some(prompts) = protocol.list_prompts().await.log_err() { - for prompt in prompts - .into_iter() - .filter(context_server_command::acceptable_prompt) - { - log::info!( - "registering context server command: {:?}", - prompt.name - ); - context_server_registry.register_command( - server.id.clone(), - prompt.name.as_str(), - ); - slash_command_registry.register_command( - context_server_command::ContextServerSlashCommand::new( - &server, prompt, - ), - true, - ); - } - } - } - }) - .detach(); - } - }, - ); - - cx.update_model( - &manager, - |manager: &mut context_servers::manager::ContextServerManager, cx| { - let tool_registry = ToolRegistry::global(cx); - let context_server_registry = ContextServerRegistry::global(cx); - if let Some(server) = manager.get_server(server_id) { - cx.spawn(|_, _| async move { - let Some(protocol) = server.client.read().clone() else { - return; - }; - - if protocol.capable(context_servers::protocol::ServerCapability::Tools) { - if let Some(tools) = protocol.list_tools().await.log_err() { - for tool in tools.tools { - log::info!( - "registering context server tool: {:?}", - tool.name - ); - context_server_registry.register_tool( - server.id.clone(), - tool.name.as_str(), - ); - tool_registry.register_tool( - tools::context_server_tool::ContextServerTool::new( - server.id.clone(), - tool - ), - ); - } - } - } - }) - .detach(); - } - }, - ); - } - context_servers::manager::Event::ServerStopped { server_id } => { - let slash_command_registry = SlashCommandRegistry::global(cx); - let context_server_registry = ContextServerRegistry::global(cx); - if let Some(commands) = context_server_registry.get_commands(server_id) { - for command_name in commands { - slash_command_registry.unregister_command_by_name(&command_name); - context_server_registry.unregister_command(&server_id, &command_name); - } - } - - if let Some(tools) = context_server_registry.get_tools(server_id) { - let tool_registry = ToolRegistry::global(cx); - for tool_name in tools { - tool_registry.unregister_tool_by_name(&tool_name); - context_server_registry.unregister_tool(&server_id, &tool_name); - } - } - } - }, - ) - .detach(); -} - fn init_language_model_settings(cx: &mut AppContext) { update_active_language_model_from_settings(cx); diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 7562901106..52139fce8c 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,4 +1,6 @@ use crate::slash_command::file_command::codeblock_fence_for_path; +use crate::slash_command_working_set::SlashCommandWorkingSet; +use crate::ToolWorkingSet; use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, humanize_token_count, @@ -7,21 +9,20 @@ use crate::{ slash_command::{ default_command::DefaultSlashCommand, docs_command::{DocsSlashCommand, DocsSlashCommandArgs}, - file_command, SlashCommandCompletionProvider, SlashCommandRegistry, + file_command, SlashCommandCompletionProvider, }, slash_command_picker, terminal_inline_assistant::TerminalInlineAssistant, Assist, AssistantPatch, AssistantPatchStatus, CacheStatus, ConfirmCommand, Content, Context, ContextEvent, ContextId, ContextStore, ContextStoreEvent, CopyCode, CycleMessageRole, DeployHistory, DeployPromptLibrary, Edit, InlineAssistant, InsertDraggedFiles, - InsertIntoEditor, InvokedSlashCommandStatus, Message, MessageId, MessageMetadata, - MessageStatus, ModelPickerDelegate, ModelSelector, NewContext, ParsedSlashCommand, - PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, RequestType, - SavedContextMetadata, SlashCommandId, Split, ToggleFocus, ToggleModelSelector, + InsertIntoEditor, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId, + MessageMetadata, MessageStatus, ModelPickerDelegate, ModelSelector, NewContext, + ParsedSlashCommand, PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, + RequestType, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, }; use anyhow::Result; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; -use assistant_tool::ToolRegistry; use client::{proto, zed_urls, Client, Status}; use collections::{hash_map, BTreeSet, HashMap, HashSet}; use editor::{ @@ -112,7 +113,8 @@ pub fn init(cx: &mut AppContext) { .register_action(ContextEditor::copy_code) .register_action(ContextEditor::insert_dragged_files) .register_action(AssistantPanel::show_configuration) - .register_action(AssistantPanel::create_new_context); + .register_action(AssistantPanel::create_new_context) + .register_action(AssistantPanel::restart_context_servers); }, ) .detach(); @@ -315,10 +317,12 @@ impl AssistantPanel { cx: AsyncWindowContext, ) -> Task>> { cx.spawn(|mut cx| async move { + let slash_commands = Arc::new(SlashCommandWorkingSet::default()); + let tools = Arc::new(ToolWorkingSet::default()); let context_store = workspace .update(&mut cx, |workspace, cx| { let project = workspace.project().clone(); - ContextStore::new(project, prompt_builder.clone(), cx) + ContextStore::new(project, prompt_builder.clone(), slash_commands, tools, cx) })? .await?; @@ -1294,6 +1298,24 @@ impl AssistantPanel { .active_provider() .map_or(None, |provider| Some(provider.authenticate(cx))) } + + fn restart_context_servers( + workspace: &mut Workspace, + _action: &context_servers::Restart, + cx: &mut ViewContext, + ) { + let Some(assistant_panel) = workspace.panel::(cx) else { + return; + }; + + assistant_panel.update(cx, |assistant_panel, cx| { + assistant_panel + .context_store + .update(cx, |context_store, cx| { + context_store.restart_context_servers(cx); + }); + }); + } } impl Render for AssistantPanel { @@ -1468,6 +1490,8 @@ enum AssistError { pub struct ContextEditor { context: Model, fs: Arc, + slash_commands: Arc, + tools: Arc, workspace: WeakView, project: Model, lsp_adapter_delegate: Option>, @@ -1477,7 +1501,7 @@ pub struct ContextEditor { scroll_position: Option, remote_id: Option, pending_slash_command_creases: HashMap, CreaseId>, - invoked_slash_command_creases: HashMap, + invoked_slash_command_creases: HashMap, pending_tool_use_creases: HashMap, CreaseId>, _subscriptions: Vec, patches: HashMap, PatchViewState>, @@ -1536,8 +1560,12 @@ impl ContextEditor { let sections = context.read(cx).slash_command_output_sections().to_vec(); let patch_ranges = context.read(cx).patch_ranges().collect::>(); + let slash_commands = context.read(cx).slash_commands.clone(); + let tools = context.read(cx).tools.clone(); let mut this = Self { context, + slash_commands, + tools, editor, lsp_adapter_delegate, blocks: Default::default(), @@ -1688,7 +1716,7 @@ impl ContextEditor { } pub fn insert_command(&mut self, name: &str, cx: &mut ViewContext) { - if let Some(command) = SlashCommandRegistry::global(cx).command(name) { + if let Some(command) = self.slash_commands.command(name, cx) { self.editor.update(cx, |editor, cx| { editor.transact(cx, |editor, cx| { editor.change_selections(Some(Autoscroll::fit()), cx, |s| s.try_cancel()); @@ -1770,7 +1798,7 @@ impl ContextEditor { workspace: WeakView, cx: &mut ViewContext, ) { - if let Some(command) = SlashCommandRegistry::global(cx).command(name) { + if let Some(command) = self.slash_commands.command(name, cx) { let context = self.context.read(cx); let sections = context .slash_command_output_sections() @@ -2043,8 +2071,7 @@ impl ContextEditor { .collect::>(); for tool_use in pending_tool_uses { - let tool_registry = ToolRegistry::global(cx); - if let Some(tool) = tool_registry.tool(&tool_use.name) { + if let Some(tool) = self.tools.tool(&tool_use.name, cx) { let task = tool.run(tool_use.input, self.workspace.clone(), cx); self.context.update(cx, |context, cx| { @@ -2108,7 +2135,7 @@ impl ContextEditor { fn update_invoked_slash_command( &mut self, - command_id: SlashCommandId, + command_id: InvokedSlashCommandId, cx: &mut ViewContext, ) { self.editor.update(cx, |editor, cx| { @@ -3719,6 +3746,19 @@ impl ContextEditor { }) } + fn render_inject_context_menu(&self, cx: &mut ViewContext) -> impl IntoElement { + slash_command_picker::SlashCommandSelector::new( + self.slash_commands.clone(), + cx.view().downgrade(), + Button::new("trigger", "Add Context") + .icon(IconName::Plus) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .icon_position(IconPosition::Start) + .tooltip(|cx| Tooltip::text("Type / to insert via keyboard", cx)), + ) + } + fn render_last_error(&self, cx: &mut ViewContext) -> Option { let last_error = self.last_error.as_ref()?; @@ -4133,11 +4173,7 @@ impl Render for ContextEditor { .border_t_1() .border_color(cx.theme().colors().border_variant) .bg(cx.theme().colors().editor_background) - .child( - h_flex() - .gap_1() - .child(render_inject_context_menu(cx.view().downgrade(), cx)), - ) + .child(h_flex().gap_1().child(self.render_inject_context_menu(cx))) .child( h_flex() .w_full() @@ -4419,24 +4455,6 @@ pub struct ContextEditorToolbarItem { model_selector_menu_handle: PopoverMenuHandle>, } -fn render_inject_context_menu( - active_context_editor: WeakView, - cx: &mut WindowContext<'_>, -) -> impl IntoElement { - let commands = SlashCommandRegistry::global(cx); - - slash_command_picker::SlashCommandSelector::new( - commands.clone(), - active_context_editor, - Button::new("trigger", "Add Context") - .icon(IconName::Plus) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .icon_position(IconPosition::Start) - .tooltip(|cx| Tooltip::text("Type / to insert via keyboard", cx)), - ) -} - impl ContextEditorToolbarItem { pub fn new( workspace: &Workspace, @@ -5095,7 +5113,7 @@ fn make_lsp_adapter_delegate( enum PendingSlashCommand {} fn invoked_slash_command_fold_placeholder( - command_id: SlashCommandId, + command_id: InvokedSlashCommandId, context: WeakModel, ) -> FoldPlaceholder { FoldPlaceholder { diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index dd1e397300..da058969b6 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,6 +1,8 @@ #[cfg(test)] mod context_tests; +use crate::slash_command_working_set::SlashCommandWorkingSet; +use crate::ToolWorkingSet; use crate::{ prompts::PromptBuilder, slash_command::{file_command::FileCommandMetadata, SlashCommandLine}, @@ -8,10 +10,8 @@ use crate::{ }; use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::{ - SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandRegistry, - SlashCommandResult, + SlashCommandContent, SlashCommandEvent, SlashCommandOutputSection, SlashCommandResult, }; -use assistant_tool::ToolRegistry; use client::{self, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::{HashMap, HashSet}; @@ -95,13 +95,13 @@ pub enum ContextOperation { version: clock::Global, }, SlashCommandStarted { - id: SlashCommandId, + id: InvokedSlashCommandId, output_range: Range, name: String, version: clock::Global, }, SlashCommandFinished { - id: SlashCommandId, + id: InvokedSlashCommandId, timestamp: clock::Lamport, error_message: Option, version: clock::Global, @@ -167,7 +167,7 @@ impl ContextOperation { }), proto::context_operation::Variant::SlashCommandStarted(message) => { Ok(Self::SlashCommandStarted { - id: SlashCommandId(language::proto::deserialize_timestamp( + id: InvokedSlashCommandId(language::proto::deserialize_timestamp( message.id.context("invalid id")?, )), output_range: language::proto::deserialize_anchor_range( @@ -198,7 +198,7 @@ impl ContextOperation { } proto::context_operation::Variant::SlashCommandCompleted(message) => { Ok(Self::SlashCommandFinished { - id: SlashCommandId(language::proto::deserialize_timestamp( + id: InvokedSlashCommandId(language::proto::deserialize_timestamp( message.id.context("invalid id")?, )), timestamp: language::proto::deserialize_timestamp( @@ -372,7 +372,7 @@ pub enum ContextEvent { updated: Vec>, }, InvokedSlashCommandChanged { - command_id: SlashCommandId, + command_id: InvokedSlashCommandId, }, ParsedSlashCommandsUpdated { removed: Vec>, @@ -513,7 +513,7 @@ struct PendingCompletion { } #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] -pub struct SlashCommandId(clock::Lamport); +pub struct InvokedSlashCommandId(clock::Lamport); #[derive(Clone, Debug)] pub struct XmlTag { @@ -543,8 +543,10 @@ pub struct Context { operations: Vec, buffer: Model, parsed_slash_commands: Vec, - invoked_slash_commands: HashMap, + invoked_slash_commands: HashMap, edits_since_last_parse: language::Subscription, + pub(crate) slash_commands: Arc, + pub(crate) tools: Arc, slash_command_output_sections: Vec>, pending_tool_uses_by_id: HashMap, PendingToolUse>, message_anchors: Vec, @@ -598,6 +600,8 @@ impl Context { project: Option>, telemetry: Option>, prompt_builder: Arc, + slash_commands: Arc, + tools: Arc, cx: &mut ModelContext, ) -> Self { Self::new( @@ -606,6 +610,8 @@ impl Context { language::Capability::ReadWrite, language_registry, prompt_builder, + slash_commands, + tools, project, telemetry, cx, @@ -619,6 +625,8 @@ impl Context { capability: language::Capability, language_registry: Arc, prompt_builder: Arc, + slash_commands: Arc, + tools: Arc, project: Option>, telemetry: Option>, cx: &mut ModelContext, @@ -663,6 +671,8 @@ impl Context { telemetry, project, language_registry, + slash_commands, + tools, patches: Vec::new(), xml_tags: Vec::new(), prompt_builder, @@ -738,6 +748,8 @@ impl Context { path: PathBuf, language_registry: Arc, prompt_builder: Arc, + slash_commands: Arc, + tools: Arc, project: Option>, telemetry: Option>, cx: &mut ModelContext, @@ -749,6 +761,8 @@ impl Context { language::Capability::ReadWrite, language_registry, prompt_builder, + slash_commands, + tools, project, telemetry, cx, @@ -1086,7 +1100,7 @@ impl Context { pub fn invoked_slash_command( &self, - command_id: &SlashCommandId, + command_id: &InvokedSlashCommandId, ) -> Option<&InvokedSlashCommand> { self.invoked_slash_commands.get(command_id) } @@ -1455,7 +1469,7 @@ impl Context { }) .map(ToOwned::to_owned) .collect::>(); - if let Some(command) = SlashCommandRegistry::global(cx).command(name) { + if let Some(command) = self.slash_commands.command(name, cx) { if !command.requires_argument() || !arguments.is_empty() { let start_ix = offset + command_line.name.start - 1; let end_ix = offset @@ -1847,7 +1861,7 @@ impl Context { cx: &mut ModelContext, ) { let version = self.version.clone(); - let command_id = SlashCommandId(self.next_timestamp()); + let command_id = InvokedSlashCommandId(self.next_timestamp()); const PENDING_OUTPUT_END_MARKER: &str = "…"; @@ -2227,9 +2241,9 @@ impl Context { let mut request = self.to_completion_request(request_type, cx); if cx.has_flag::() { - let tool_registry = ToolRegistry::global(cx); - request.tools = tool_registry - .tools() + request.tools = self + .tools + .tools(cx) .into_iter() .map(|tool| LanguageModelRequestTool { name: tool.name(), diff --git a/crates/assistant/src/context/context_tests.rs b/crates/assistant/src/context/context_tests.rs index 615bdc5d02..c68aeb25e2 100644 --- a/crates/assistant/src/context/context_tests.rs +++ b/crates/assistant/src/context/context_tests.rs @@ -1,8 +1,10 @@ use super::{AssistantEdit, MessageCacheMetadata}; +use crate::slash_command_working_set::SlashCommandWorkingSet; +use crate::ToolWorkingSet; use crate::{ assistant_panel, prompt_library, slash_command::file_command, AssistantEditKind, CacheStatus, - Context, ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder, - SlashCommandId, + Context, ContextEvent, ContextId, ContextOperation, InvokedSlashCommandId, MessageId, + MessageStatus, PromptBuilder, }; use anyhow::Result; use assistant_slash_command::{ @@ -49,8 +51,17 @@ fn test_inserting_and_removing_messages(cx: &mut AppContext) { assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -182,8 +193,17 @@ fn test_message_splitting(cx: &mut AppContext) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -277,8 +297,17 @@ fn test_messages_for_offsets(cx: &mut AppContext) { assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read(cx).buffer.clone(); let message_1 = context.read(cx).message_anchors[0].clone(); @@ -383,13 +412,22 @@ async fn test_slash_commands(cx: &mut TestAppContext) { let registry = Arc::new(LanguageRegistry::test(cx.executor())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); #[derive(Default)] struct ContextRanges { parsed_commands: HashSet>, - command_outputs: HashMap>, + command_outputs: HashMap>, output_sections: HashSet>, } @@ -671,6 +709,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) { Some(project), None, prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), cx, ) }); @@ -934,6 +974,8 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) { Default::default(), registry.clone(), prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), None, None, cx, @@ -1042,8 +1084,17 @@ async fn test_serialization(cx: &mut TestAppContext) { cx.update(assistant_panel::init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry.clone(), + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read_with(cx, |context, _| context.buffer.clone()); let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id); let message_1 = context.update(cx, |context, cx| { @@ -1083,6 +1134,8 @@ async fn test_serialization(cx: &mut TestAppContext) { Default::default(), registry.clone(), prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), None, None, cx, @@ -1141,6 +1194,8 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std language::Capability::ReadWrite, registry.clone(), prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), None, None, cx, @@ -1394,8 +1449,17 @@ fn test_mark_cache_anchors(cx: &mut AppContext) { assistant_panel::init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); - let context = - cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx)); + let context = cx.new_model(|cx| { + Context::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }); let buffer = context.read(cx).buffer.clone(); // Create a test cache configuration diff --git a/crates/assistant/src/context_store.rs b/crates/assistant/src/context_store.rs index f4f03dda37..ad8ae9808f 100644 --- a/crates/assistant/src/context_store.rs +++ b/crates/assistant/src/context_store.rs @@ -1,10 +1,16 @@ +use crate::slash_command::context_server_command; use crate::{ - prompts::PromptBuilder, Context, ContextEvent, ContextId, ContextOperation, ContextVersion, - SavedContext, SavedContextMetadata, + prompts::PromptBuilder, slash_command_working_set::SlashCommandWorkingSet, Context, + ContextEvent, ContextId, ContextOperation, ContextVersion, SavedContext, SavedContextMetadata, }; +use crate::{tools, SlashCommandId, ToolId, ToolWorkingSet}; use anyhow::{anyhow, Context as _, Result}; use client::{proto, telemetry::Telemetry, Client, TypedEnvelope}; use clock::ReplicaId; +use collections::HashMap; +use command_palette_hooks::CommandPaletteFilter; +use context_servers::manager::{ContextServerManager, ContextServerSettings}; +use context_servers::CONTEXT_SERVERS_NAMESPACE; use fs::Fs; use futures::StreamExt; use fuzzy::StringMatchCandidate; @@ -16,6 +22,7 @@ use paths::contexts_dir; use project::Project; use regex::Regex; use rpc::AnyProtoClient; +use settings::{Settings as _, SettingsStore}; use std::{ cmp::Reverse, ffi::OsStr, @@ -43,9 +50,14 @@ pub struct RemoteContextMetadata { pub struct ContextStore { contexts: Vec, contexts_metadata: Vec, + context_server_manager: Model, + context_server_slash_command_ids: HashMap>, + context_server_tool_ids: HashMap>, host_contexts: Vec, fs: Arc, languages: Arc, + slash_commands: Arc, + tools: Arc, telemetry: Arc, _watch_updates: Task>, client: Arc, @@ -87,6 +99,8 @@ impl ContextStore { pub fn new( project: Model, prompt_builder: Arc, + slash_commands: Arc, + tools: Arc, cx: &mut AppContext, ) -> Task>> { let fs = project.read(cx).fs().clone(); @@ -97,12 +111,18 @@ impl ContextStore { let (mut events, _) = fs.watch(contexts_dir(), CONTEXT_WATCH_DURATION).await; let this = cx.new_model(|cx: &mut ModelContext| { + let context_server_manager = cx.new_model(|_cx| ContextServerManager::new()); let mut this = Self { contexts: Vec::new(), contexts_metadata: Vec::new(), + context_server_manager, + context_server_slash_command_ids: HashMap::default(), + context_server_tool_ids: HashMap::default(), host_contexts: Vec::new(), fs, languages, + slash_commands, + tools, telemetry, _watch_updates: cx.spawn(|this, mut cx| { async move { @@ -127,15 +147,48 @@ impl ContextStore { }; this.handle_project_changed(project, cx); this.synchronize_contexts(cx); + this.register_context_server_handlers(cx); this })?; this.update(&mut cx, |this, cx| this.reload(cx))? .await .log_err(); + + this.update(&mut cx, |this, cx| { + this.watch_context_server_settings(cx); + }) + .log_err(); + Ok(this) }) } + fn watch_context_server_settings(&self, cx: &mut ModelContext) { + cx.observe_global::(move |this, cx| { + this.context_server_manager.update(cx, |manager, cx| { + let location = this.project.read(cx).worktrees(cx).next().map(|worktree| { + settings::SettingsLocation { + worktree_id: worktree.read(cx).id(), + path: Path::new(""), + } + }); + let settings = ContextServerSettings::get(location, cx); + + manager.maintain_servers(settings, cx); + + let has_any_context_servers = !manager.servers().is_empty(); + CommandPaletteFilter::update_global(cx, |filter, _cx| { + if has_any_context_servers { + filter.show_namespace(CONTEXT_SERVERS_NAMESPACE); + } else { + filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); + } + }); + }) + }) + .detach(); + } + async fn handle_advertise_contexts( this: Model, envelope: TypedEnvelope, @@ -342,6 +395,8 @@ impl ContextStore { Some(self.project.clone()), Some(self.telemetry.clone()), self.prompt_builder.clone(), + self.slash_commands.clone(), + self.tools.clone(), cx, ) }); @@ -364,6 +419,8 @@ impl ContextStore { let project = self.project.clone(); let telemetry = self.telemetry.clone(); let prompt_builder = self.prompt_builder.clone(); + let slash_commands = self.slash_commands.clone(); + let tools = self.tools.clone(); let request = self.client.request(proto::CreateContext { project_id }); cx.spawn(|this, mut cx| async move { let response = request.await?; @@ -376,6 +433,8 @@ impl ContextStore { capability, language_registry, prompt_builder, + slash_commands, + tools, Some(project), Some(telemetry), cx, @@ -425,6 +484,8 @@ impl ContextStore { } }); let prompt_builder = self.prompt_builder.clone(); + let slash_commands = self.slash_commands.clone(); + let tools = self.tools.clone(); cx.spawn(|this, mut cx| async move { let saved_context = load.await?; @@ -434,6 +495,8 @@ impl ContextStore { path.clone(), languages, prompt_builder, + slash_commands, + tools, Some(project), Some(telemetry), cx, @@ -500,6 +563,8 @@ impl ContextStore { context_id: context_id.to_proto(), }); let prompt_builder = self.prompt_builder.clone(); + let slash_commands = self.slash_commands.clone(); + let tools = self.tools.clone(); cx.spawn(|this, mut cx| async move { let response = request.await?; let context_proto = response.context.context("invalid context")?; @@ -510,6 +575,8 @@ impl ContextStore { capability, language_registry, prompt_builder, + slash_commands, + tools, Some(project), Some(telemetry), cx, @@ -745,4 +812,114 @@ impl ContextStore { }) }) } + + pub fn restart_context_servers(&mut self, cx: &mut ModelContext) { + cx.update_model( + &self.context_server_manager, + |context_server_manager, cx| { + for server in context_server_manager.servers() { + context_server_manager + .restart_server(&server.id, cx) + .detach_and_log_err(cx); + } + }, + ); + } + + fn register_context_server_handlers(&self, cx: &mut ModelContext) { + cx.subscribe( + &self.context_server_manager.clone(), + Self::handle_context_server_event, + ) + .detach(); + } + + fn handle_context_server_event( + &mut self, + context_server_manager: Model, + event: &context_servers::manager::Event, + cx: &mut ModelContext, + ) { + let slash_command_working_set = self.slash_commands.clone(); + let tool_working_set = self.tools.clone(); + match event { + context_servers::manager::Event::ServerStarted { server_id } => { + if let Some(server) = context_server_manager.read(cx).get_server(server_id) { + let context_server_manager = context_server_manager.clone(); + cx.spawn({ + let server = server.clone(); + let server_id = server_id.clone(); + |this, mut cx| async move { + let Some(protocol) = server.client.read().clone() else { + return; + }; + + if protocol.capable(context_servers::protocol::ServerCapability::Prompts) { + if let Some(prompts) = protocol.list_prompts().await.log_err() { + let slash_command_ids = prompts + .into_iter() + .filter(context_server_command::acceptable_prompt) + .map(|prompt| { + log::info!( + "registering context server command: {:?}", + prompt.name + ); + slash_command_working_set.insert(Arc::new( + context_server_command::ContextServerSlashCommand::new( + context_server_manager.clone(), + &server, + prompt, + ), + )) + }) + .collect::>(); + + this.update(&mut cx, |this, _cx| { + this.context_server_slash_command_ids + .insert(server_id.clone(), slash_command_ids); + }) + .log_err(); + } + } + + if protocol.capable(context_servers::protocol::ServerCapability::Tools) { + if let Some(tools) = protocol.list_tools().await.log_err() { + let tool_ids = tools.tools.into_iter().map(|tool| { + log::info!("registering context server tool: {:?}", tool.name); + tool_working_set.insert( + Arc::new(tools::context_server_tool::ContextServerTool::new( + context_server_manager.clone(), + server.id.clone(), + tool, + )), + ) + + }).collect::>(); + + + this.update(&mut cx, |this, _cx| { + this.context_server_tool_ids + .insert(server_id, tool_ids); + }) + .log_err(); + } + } + } + }) + .detach(); + } + } + context_servers::manager::Event::ServerStopped { server_id } => { + if let Some(slash_command_ids) = + self.context_server_slash_command_ids.remove(server_id) + { + slash_command_working_set.remove(&slash_command_ids); + } + + if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { + tool_working_set.remove(&tool_ids); + } + } + } + } } diff --git a/crates/assistant/src/slash_command/context_server_command.rs b/crates/assistant/src/slash_command/context_server_command.rs index 5b22e76bf8..997f289c9b 100644 --- a/crates/assistant/src/slash_command/context_server_command.rs +++ b/crates/assistant/src/slash_command/context_server_command.rs @@ -8,7 +8,7 @@ use context_servers::{ manager::{ContextServer, ContextServerManager}, types::Prompt, }; -use gpui::{AppContext, Task, WeakView, WindowContext}; +use gpui::{AppContext, Model, Task, WeakView, WindowContext}; use language::{BufferSnapshot, CodeLabel, LspAdapterDelegate}; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -19,15 +19,21 @@ use workspace::Workspace; use crate::slash_command::create_label_for_command; pub struct ContextServerSlashCommand { + server_manager: Model, server_id: String, prompt: Prompt, } impl ContextServerSlashCommand { - pub fn new(server: &Arc, prompt: Prompt) -> Self { + pub fn new( + server_manager: Model, + server: &Arc, + prompt: Prompt, + ) -> Self { Self { server_id: server.id.clone(), prompt, + server_manager, } } } @@ -74,18 +80,14 @@ impl SlashCommand for ContextServerSlashCommand { _workspace: Option>, cx: &mut WindowContext, ) -> Task>> { + let Ok((arg_name, arg_value)) = completion_argument(&self.prompt, arguments) else { + return Task::ready(Err(anyhow!("Failed to complete argument"))); + }; + let server_id = self.server_id.clone(); let prompt_name = self.prompt.name.clone(); - let manager = ContextServerManager::global(cx); - let manager = manager.read(cx); - let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) { - Ok(tp) => tp, - Err(e) => { - return Task::ready(Err(e)); - } - }; - if let Some(server) = manager.get_server(&server_id) { + if let Some(server) = self.server_manager.read(cx).get_server(&server_id) { cx.foreground_executor().spawn(async move { let Some(protocol) = server.client.read().clone() else { return Err(anyhow!("Context server not initialized")); @@ -100,7 +102,7 @@ impl SlashCommand for ContextServerSlashCommand { }, ), arg_name, - arg_val, + arg_value, ) .await?; @@ -138,8 +140,7 @@ impl SlashCommand for ContextServerSlashCommand { Err(e) => return Task::ready(Err(e)), }; - let manager = ContextServerManager::global(cx); - let manager = manager.read(cx); + let manager = self.server_manager.read(cx); if let Some(server) = manager.get_server(&server_id) { cx.foreground_executor().spawn(async move { let Some(protocol) = server.client.read().clone() else { diff --git a/crates/assistant/src/slash_command_picker.rs b/crates/assistant/src/slash_command_picker.rs index b376dd8881..8e797d6184 100644 --- a/crates/assistant/src/slash_command_picker.rs +++ b/crates/assistant/src/slash_command_picker.rs @@ -1,16 +1,15 @@ use std::sync::Arc; -use assistant_slash_command::SlashCommandRegistry; - use gpui::{AnyElement, DismissEvent, SharedString, Task, WeakView}; use picker::{Picker, PickerDelegate, PickerEditorPosition}; use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverTrigger}; use crate::assistant_panel::ContextEditor; +use crate::SlashCommandWorkingSet; #[derive(IntoElement)] pub(super) struct SlashCommandSelector { - registry: Arc, + working_set: Arc, active_context_editor: WeakView, trigger: T, } @@ -51,12 +50,12 @@ pub(crate) struct SlashCommandDelegate { impl SlashCommandSelector { pub(crate) fn new( - registry: Arc, + working_set: Arc, active_context_editor: WeakView, trigger: T, ) -> Self { SlashCommandSelector { - registry, + working_set, active_context_editor, trigger, } @@ -231,11 +230,11 @@ impl PickerDelegate for SlashCommandDelegate { impl RenderOnce for SlashCommandSelector { fn render(self, cx: &mut WindowContext) -> impl IntoElement { let all_models = self - .registry - .featured_command_names() + .working_set + .featured_command_names(cx) .into_iter() .filter_map(|command_name| { - let command = self.registry.command(&command_name)?; + let command = self.working_set.command(&command_name, cx)?; let menu_text = SharedString::from(Arc::from(command.menu_text())); let label = command.label(cx); let args = label.filter_range.end.ne(&label.text.len()).then(|| { diff --git a/crates/assistant/src/slash_command_working_set.rs b/crates/assistant/src/slash_command_working_set.rs new file mode 100644 index 0000000000..93979557c1 --- /dev/null +++ b/crates/assistant/src/slash_command_working_set.rs @@ -0,0 +1,66 @@ +use assistant_slash_command::{SlashCommand, SlashCommandRegistry}; +use collections::HashMap; +use gpui::AppContext; +use parking_lot::Mutex; +use std::sync::Arc; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] +pub struct SlashCommandId(usize); + +/// A working set of slash commands for use in one instance of the Assistant Panel. +#[derive(Default)] +pub struct SlashCommandWorkingSet { + state: Mutex, +} + +#[derive(Default)] +struct WorkingSetState { + context_server_commands_by_id: HashMap>, + context_server_commands_by_name: HashMap>, + next_command_id: SlashCommandId, +} + +impl SlashCommandWorkingSet { + pub fn command(&self, name: &str, cx: &AppContext) -> Option> { + self.state + .lock() + .context_server_commands_by_name + .get(name) + .cloned() + .or_else(|| SlashCommandRegistry::global(cx).command(name)) + } + + pub fn featured_command_names(&self, cx: &AppContext) -> Vec> { + SlashCommandRegistry::global(cx).featured_command_names() + } + + pub fn insert(&self, command: Arc) -> SlashCommandId { + let mut state = self.state.lock(); + let command_id = state.next_command_id; + state.next_command_id.0 += 1; + state + .context_server_commands_by_id + .insert(command_id, command.clone()); + state.slash_commands_changed(); + command_id + } + + pub fn remove(&self, command_ids_to_remove: &[SlashCommandId]) { + let mut state = self.state.lock(); + state + .context_server_commands_by_id + .retain(|id, _| !command_ids_to_remove.contains(id)); + state.slash_commands_changed(); + } +} + +impl WorkingSetState { + fn slash_commands_changed(&mut self) { + self.context_server_commands_by_name.clear(); + self.context_server_commands_by_name.extend( + self.context_server_commands_by_id + .values() + .map(|command| (command.name(), command.clone())), + ); + } +} diff --git a/crates/assistant/src/tool_working_set.rs b/crates/assistant/src/tool_working_set.rs new file mode 100644 index 0000000000..aa2bb7a530 --- /dev/null +++ b/crates/assistant/src/tool_working_set.rs @@ -0,0 +1,75 @@ +use assistant_tool::{Tool, ToolRegistry}; +use collections::HashMap; +use gpui::AppContext; +use parking_lot::Mutex; +use std::sync::Arc; + +#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] +pub struct ToolId(usize); + +/// A working set of tools for use in one instance of the Assistant Panel. +#[derive(Default)] +pub struct ToolWorkingSet { + state: Mutex, +} + +#[derive(Default)] +struct WorkingSetState { + context_server_tools_by_id: HashMap>, + context_server_tools_by_name: HashMap>, + next_tool_id: ToolId, +} + +impl ToolWorkingSet { + pub fn tool(&self, name: &str, cx: &AppContext) -> Option> { + self.state + .lock() + .context_server_tools_by_name + .get(name) + .cloned() + .or_else(|| ToolRegistry::global(cx).tool(name)) + } + + pub fn tools(&self, cx: &AppContext) -> Vec> { + let mut tools = ToolRegistry::global(cx).tools(); + tools.extend( + self.state + .lock() + .context_server_tools_by_id + .values() + .cloned(), + ); + + tools + } + + pub fn insert(&self, command: Arc) -> ToolId { + let mut state = self.state.lock(); + let command_id = state.next_tool_id; + state.next_tool_id.0 += 1; + state + .context_server_tools_by_id + .insert(command_id, command.clone()); + state.tools_changed(); + command_id + } + + pub fn remove(&self, command_ids_to_remove: &[ToolId]) { + let mut state = self.state.lock(); + state + .context_server_tools_by_id + .retain(|id, _| !command_ids_to_remove.contains(id)); + state.tools_changed(); + } +} + +impl WorkingSetState { + fn tools_changed(&mut self) { + self.context_server_tools_by_name.clear(); + self.context_server_tools_by_name.extend( + self.context_server_tools_by_id + .values() + .map(|command| (command.name(), command.clone())), + ); + } +} diff --git a/crates/assistant/src/tools/context_server_tool.rs b/crates/assistant/src/tools/context_server_tool.rs index 93edb32b75..72bb87191f 100644 --- a/crates/assistant/src/tools/context_server_tool.rs +++ b/crates/assistant/src/tools/context_server_tool.rs @@ -2,16 +2,22 @@ use anyhow::{anyhow, bail}; use assistant_tool::Tool; use context_servers::manager::ContextServerManager; use context_servers::types; -use gpui::Task; +use gpui::{Model, Task}; pub struct ContextServerTool { + server_manager: Model, server_id: String, tool: types::Tool, } impl ContextServerTool { - pub fn new(server_id: impl Into, tool: types::Tool) -> Self { + pub fn new( + server_manager: Model, + server_id: impl Into, + tool: types::Tool, + ) -> Self { Self { + server_manager, server_id: server_id.into(), tool, } @@ -45,9 +51,7 @@ impl Tool for ContextServerTool { _workspace: gpui::WeakView, cx: &mut ui::WindowContext, ) -> gpui::Task> { - let manager = ContextServerManager::global(cx); - let manager = manager.read(cx); - if let Some(server) = manager.get_server(&self.server_id) { + if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) { cx.foreground_executor().spawn({ let tool_name = self.tool.name.clone(); async move { diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index f4c401eb78..8c819d6da7 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -6,7 +6,7 @@ use crate::{ }, }; use anyhow::{anyhow, Result}; -use assistant::{ContextStore, PromptBuilder}; +use assistant::{ContextStore, PromptBuilder, SlashCommandWorkingSet, ToolWorkingSet}; use call::{room, ActiveCall, ParticipantLocation, Room}; use client::{User, RECEIVE_TIMEOUT}; use collections::{HashMap, HashSet}; @@ -6489,11 +6489,27 @@ async fn test_context_collaboration_with_reconnect( let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); let context_store_a = cx_a - .update(|cx| ContextStore::new(project_a.clone(), prompt_builder.clone(), cx)) + .update(|cx| { + ContextStore::new( + project_a.clone(), + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }) .await .unwrap(); let context_store_b = cx_b - .update(|cx| ContextStore::new(project_b.clone(), prompt_builder.clone(), cx)) + .update(|cx| { + ContextStore::new( + project_b.clone(), + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + Arc::new(ToolWorkingSet::default()), + cx, + ) + }) .await .unwrap(); diff --git a/crates/context_servers/Cargo.toml b/crates/context_servers/Cargo.toml index 9c0336f121..73a393375c 100644 --- a/crates/context_servers/Cargo.toml +++ b/crates/context_servers/Cargo.toml @@ -27,4 +27,3 @@ settings.workspace = true smol.workspace = true url = { workspace = true, features = ["serde"] } util.workspace = true -workspace.workspace = true diff --git a/crates/context_servers/src/context_servers.rs b/crates/context_servers/src/context_servers.rs index 55634bb77c..c20b1ebdb1 100644 --- a/crates/context_servers/src/context_servers.rs +++ b/crates/context_servers/src/context_servers.rs @@ -1,40 +1,23 @@ -use gpui::{actions, AppContext, Context, ViewContext}; -use manager::ContextServerManager; -use workspace::Workspace; +use command_palette_hooks::CommandPaletteFilter; +use gpui::{actions, AppContext}; +use settings::Settings; + +use crate::manager::ContextServerSettings; pub mod client; pub mod manager; pub mod protocol; -mod registry; pub mod types; -pub use registry::*; - actions!(context_servers, [Restart]); /// The namespace for the context servers actions. -const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers"; +pub const CONTEXT_SERVERS_NAMESPACE: &'static str = "context_servers"; pub fn init(cx: &mut AppContext) { - log::info!("initializing context server client"); - manager::init(cx); - ContextServerRegistry::register(cx); + ContextServerSettings::register(cx); - cx.observe_new_views( - |workspace: &mut Workspace, _cx: &mut ViewContext| { - workspace.register_action(restart_servers); - }, - ) - .detach(); -} - -fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewContext) { - let model = ContextServerManager::global(cx); - cx.update_model(&model, |manager, cx| { - for server in manager.servers() { - manager - .restart_server(&server.id, cx) - .detach_and_log_err(cx); - } + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); }); } diff --git a/crates/context_servers/src/manager.rs b/crates/context_servers/src/manager.rs index 3c21fd53fb..34583e559a 100644 --- a/crates/context_servers/src/manager.rs +++ b/crates/context_servers/src/manager.rs @@ -14,18 +14,17 @@ //! The module also includes initialization logic to set up the context server system //! and react to changes in settings. +use std::path::Path; +use std::sync::Arc; + use collections::{HashMap, HashSet}; -use command_palette_hooks::CommandPaletteFilter; -use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task}; +use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task}; use log; use parking_lot::RwLock; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsSources, SettingsStore}; -use std::path::Path; -use std::sync::Arc; +use settings::{Settings, SettingsSources}; -use crate::CONTEXT_SERVERS_NAMESPACE; use crate::{ client::{self, Client}, types, @@ -124,7 +123,6 @@ pub enum Event { ServerStopped { server_id: String }, } -impl Global for ContextServerManager {} impl EventEmitter for ContextServerManager {} impl Default for ContextServerManager { @@ -140,14 +138,11 @@ impl ContextServerManager { pending_servers: HashSet::default(), } } - pub fn global(cx: &AppContext) -> Model { - cx.global::().0.clone() - } pub fn add_server( &mut self, config: ServerConfig, - cx: &mut ModelContext, + cx: &ModelContext, ) -> Task> { let server_id = config.id.clone(); @@ -179,11 +174,7 @@ impl ContextServerManager { self.servers.get(id).cloned() } - pub fn remove_server( - &mut self, - id: &str, - cx: &mut ModelContext, - ) -> Task> { + pub fn remove_server(&mut self, id: &str, cx: &ModelContext) -> Task> { let id = id.to_string(); cx.spawn(|this, mut cx| async move { if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? { @@ -229,75 +220,38 @@ impl ContextServerManager { self.servers.values().cloned().collect() } - pub fn model(cx: &mut AppContext) -> Model { - cx.new_model(|_cx| ContextServerManager::new()) + pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext) { + let current_servers = self + .servers() + .into_iter() + .map(|server| (server.id.clone(), server.config.clone())) + .collect::>(); + + let new_servers = settings + .servers + .iter() + .map(|config| (config.id.clone(), config.clone())) + .collect::>(); + + let servers_to_add = new_servers + .values() + .filter(|config| !current_servers.contains_key(&config.id)) + .cloned() + .collect::>(); + + let servers_to_remove = current_servers + .keys() + .filter(|id| !new_servers.contains_key(*id)) + .cloned() + .collect::>(); + + log::trace!("servers_to_add={:?}", servers_to_add); + for config in servers_to_add { + self.add_server(config, cx).detach_and_log_err(cx); + } + + for id in servers_to_remove { + self.remove_server(&id, cx).detach_and_log_err(cx); + } } } - -pub struct GlobalContextServerManager(Model); -impl Global for GlobalContextServerManager {} - -impl GlobalContextServerManager { - fn register(cx: &mut AppContext) { - let model = ContextServerManager::model(cx); - cx.set_global(Self(model)); - } -} - -pub fn init(cx: &mut AppContext) { - ContextServerSettings::register(cx); - GlobalContextServerManager::register(cx); - - CommandPaletteFilter::update_global(cx, |filter, _cx| { - filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); - }); - - cx.observe_global::(|cx| { - let manager = ContextServerManager::global(cx); - cx.update_model(&manager, |manager, cx| { - let settings = ContextServerSettings::get_global(cx); - let current_servers = manager - .servers() - .into_iter() - .map(|server| (server.id.clone(), server.config.clone())) - .collect::>(); - - let new_servers = settings - .servers - .iter() - .map(|config| (config.id.clone(), config.clone())) - .collect::>(); - - let servers_to_add = new_servers - .values() - .filter(|config| !current_servers.contains_key(&config.id)) - .cloned() - .collect::>(); - - let servers_to_remove = current_servers - .keys() - .filter(|id| !new_servers.contains_key(*id)) - .cloned() - .collect::>(); - - log::trace!("servers_to_add={:?}", servers_to_add); - for config in servers_to_add { - manager.add_server(config, cx).detach_and_log_err(cx); - } - - for id in servers_to_remove { - manager.remove_server(&id, cx).detach_and_log_err(cx); - } - - let has_any_context_servers = !manager.servers().is_empty(); - CommandPaletteFilter::update_global(cx, |filter, _cx| { - if has_any_context_servers { - filter.show_namespace(CONTEXT_SERVERS_NAMESPACE); - } else { - filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE); - } - }); - }) - }) - .detach(); -} diff --git a/crates/context_servers/src/registry.rs b/crates/context_servers/src/registry.rs deleted file mode 100644 index 5490187034..0000000000 --- a/crates/context_servers/src/registry.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::sync::Arc; - -use collections::HashMap; -use gpui::{AppContext, Global, ReadGlobal}; -use parking_lot::RwLock; - -struct GlobalContextServerRegistry(Arc); - -impl Global for GlobalContextServerRegistry {} - -pub struct ContextServerRegistry { - command_registry: RwLock>>>, - tool_registry: RwLock>>>, -} - -impl ContextServerRegistry { - pub fn global(cx: &AppContext) -> Arc { - GlobalContextServerRegistry::global(cx).0.clone() - } - - pub fn register(cx: &mut AppContext) { - cx.set_global(GlobalContextServerRegistry(Arc::new( - ContextServerRegistry { - command_registry: RwLock::new(HashMap::default()), - tool_registry: RwLock::new(HashMap::default()), - }, - ))) - } - - pub fn register_command(&self, server_id: String, command_name: &str) { - let mut registry = self.command_registry.write(); - registry - .entry(server_id) - .or_default() - .push(command_name.into()); - } - - pub fn unregister_command(&self, server_id: &str, command_name: &str) { - let mut registry = self.command_registry.write(); - if let Some(commands) = registry.get_mut(server_id) { - commands.retain(|name| name.as_ref() != command_name); - } - } - - pub fn get_commands(&self, server_id: &str) -> Option>> { - let registry = self.command_registry.read(); - registry.get(server_id).cloned() - } - - pub fn register_tool(&self, server_id: String, tool_name: &str) { - let mut registry = self.tool_registry.write(); - registry - .entry(server_id) - .or_default() - .push(tool_name.into()); - } - - pub fn unregister_tool(&self, server_id: &str, tool_name: &str) { - let mut registry = self.tool_registry.write(); - if let Some(tools) = registry.get_mut(server_id) { - tools.retain(|name| name.as_ref() != tool_name); - } - } - - pub fn get_tools(&self, server_id: &str) -> Option>> { - let registry = self.tool_registry.read(); - registry.get(server_id).cloned() - } -}