From 12037dc2c64d8789357d5f3f50fcabd98f0898db Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 1 Apr 2025 11:25:23 -0400 Subject: [PATCH] assistant2: Allow profiles to enable all context servers (#27847) This PR adds a new `enable_all_context_servers` field to agent profiles to allow them to enable all context servers without having to opt into them individually. The "Write" profile will now have all context servers enabled out of the box. Release Notes: - N/A --- assets/settings/default.json | 3 ++ .../manage_profiles_modal.rs | 4 ++ .../assistant_configuration/tool_picker.rs | 7 ++- crates/assistant2/src/thread_store.rs | 52 +++++++++++++------ .../assistant_settings/src/agent_profile.rs | 1 + .../src/assistant_settings.rs | 6 +++ crates/assistant_tool/src/tool_working_set.rs | 24 +++++---- 7 files changed, 68 insertions(+), 29 deletions(-) diff --git a/assets/settings/default.json b/assets/settings/default.json index 46dae6ccec..515ca6746d 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -637,6 +637,8 @@ "profiles": { "ask": { "name": "Ask", + // We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default. + // "enable_all_context_servers": true, "tools": { "diagnostics": true, "fetch": true, @@ -650,6 +652,7 @@ }, "write": { "name": "Write", + "enable_all_context_servers": true, "tools": { "bash": true, "batch-tool": true, diff --git a/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs b/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs index c05c351a0d..3cdff03440 100644 --- a/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs +++ b/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs @@ -227,6 +227,10 @@ impl ManageProfilesModal { .as_ref() .map(|profile| profile.tools.clone()) .unwrap_or_default(), + enable_all_context_servers: base_profile + .as_ref() + .map(|profile| profile.enable_all_context_servers) + .unwrap_or_default(), context_servers: base_profile .map(|profile| profile.context_servers) .unwrap_or_default(), diff --git a/crates/assistant2/src/assistant_configuration/tool_picker.rs b/crates/assistant2/src/assistant_configuration/tool_picker.rs index 5c1ca4bed5..d8bbf449a5 100644 --- a/crates/assistant2/src/assistant_configuration/tool_picker.rs +++ b/crates/assistant2/src/assistant_configuration/tool_picker.rs @@ -191,8 +191,8 @@ impl PickerDelegate for ToolPickerDelegate { let active_profile_id = &AssistantSettings::get_global(cx).default_profile; if active_profile_id == &self.profile_id { self.thread_store - .update(cx, |this, _cx| { - this.load_profile(&self.profile); + .update(cx, |this, cx| { + this.load_profile(&self.profile, cx); }) .log_err(); } @@ -212,6 +212,9 @@ impl PickerDelegate for ToolPickerDelegate { .or_insert_with(|| AgentProfileContent { name: default_profile.name.into(), tools: default_profile.tools, + enable_all_context_servers: Some( + default_profile.enable_all_context_servers, + ), context_servers: default_profile .context_servers .into_iter() diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index 86471215f8..bd540d5d26 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -12,7 +12,8 @@ use context_server::{ContextServerFactoryRegistry, ContextServerTool}; use futures::FutureExt as _; use futures::future::{self, BoxFuture, Shared}; use gpui::{ - App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*, + App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task, + prelude::*, }; use heed::Database; use heed::types::SerdeBincode; @@ -20,7 +21,7 @@ use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use project::Project; use prompt_store::PromptBuilder; use serde::{Deserialize, Serialize}; -use settings::Settings as _; +use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId}; @@ -36,6 +37,7 @@ pub struct ThreadStore { context_server_manager: Entity, context_server_tool_ids: HashMap, Vec>, threads: Vec, + _subscriptions: Vec, } impl ThreadStore { @@ -50,6 +52,10 @@ impl ThreadStore { let context_server_manager = cx.new(|cx| { ContextServerManager::new(context_server_factory_registry, project.clone(), cx) }); + let settings_subscription = + cx.observe_global::(move |this: &mut Self, cx| { + this.load_default_profile(cx); + }); let this = Self { project, @@ -58,6 +64,7 @@ impl ThreadStore { context_server_manager, context_server_tool_ids: HashMap::default(), threads: Vec::new(), + _subscriptions: vec![settings_subscription], }; this.load_default_profile(cx); this.register_context_server_handlers(cx); @@ -197,11 +204,11 @@ impl ThreadStore { let assistant_settings = AssistantSettings::get_global(cx); if let Some(profile) = assistant_settings.profiles.get(profile_id) { - self.load_profile(profile); + self.load_profile(profile, cx); } } - pub fn load_profile(&self, profile: &AgentProfile) { + pub fn load_profile(&self, profile: &AgentProfile, cx: &Context) { self.tools.disable_all_tools(); self.tools.enable( ToolSource::Native, @@ -212,17 +219,28 @@ impl ThreadStore { .collect::>(), ); - for (context_server_id, preset) in &profile.context_servers { - self.tools.enable( - ToolSource::ContextServer { - id: context_server_id.clone().into(), - }, - &preset - .tools - .iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) - .collect::>(), - ) + if profile.enable_all_context_servers { + for context_server in self.context_server_manager.read(cx).all_servers() { + self.tools.enable_source( + ToolSource::ContextServer { + id: context_server.id().into(), + }, + cx, + ); + } + } else { + for (context_server_id, preset) in &profile.context_servers { + self.tools.enable( + ToolSource::ContextServer { + id: context_server_id.clone().into(), + }, + &preset + .tools + .iter() + .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) + .collect::>(), + ) + } } } @@ -273,8 +291,9 @@ impl ThreadStore { }) .collect::>(); - this.update(cx, |this, _cx| { + this.update(cx, |this, cx| { this.context_server_tool_ids.insert(server_id, tool_ids); + this.load_default_profile(cx); }) .log_err(); } @@ -287,6 +306,7 @@ impl ThreadStore { context_server::manager::Event::ServerStopped { server_id } => { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { tool_working_set.remove(&tool_ids); + self.load_default_profile(cx); } } } diff --git a/crates/assistant_settings/src/agent_profile.rs b/crates/assistant_settings/src/agent_profile.rs index 0e9459fcf0..00a7fbcedb 100644 --- a/crates/assistant_settings/src/agent_profile.rs +++ b/crates/assistant_settings/src/agent_profile.rs @@ -9,6 +9,7 @@ pub struct AgentProfile { /// The name of the profile. pub name: SharedString, pub tools: IndexMap, bool>, + pub enable_all_context_servers: bool, pub context_servers: IndexMap, ContextServerPreset>, } diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index 102fad134f..0a5af98ab5 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -352,6 +352,7 @@ impl AssistantSettingsContent { AgentProfileContent { name: profile.name.into(), tools: profile.tools, + enable_all_context_servers: Some(profile.enable_all_context_servers), context_servers: profile .context_servers .into_iter() @@ -485,6 +486,8 @@ impl Default for LanguageModelSelection { pub struct AgentProfileContent { pub name: Arc, pub tools: IndexMap, bool>, + /// Whether all context servers are enabled by default. + pub enable_all_context_servers: Option, #[serde(default)] pub context_servers: IndexMap, ContextServerPresetContent>, } @@ -607,6 +610,9 @@ impl Settings for AssistantSettings { AgentProfile { name: profile.name.into(), tools: profile.tools, + enable_all_context_servers: profile + .enable_all_context_servers + .unwrap_or_default(), context_servers: profile .context_servers .into_iter() diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 82a4455920..97060cfdad 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -19,6 +19,7 @@ pub struct ToolWorkingSet { struct WorkingSetState { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, + enabled_sources: HashSet, enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } @@ -168,21 +169,22 @@ impl WorkingSetState { } fn enable_source(&mut self, source: ToolSource, cx: &App) { - let tools_by_source = self.tools_by_source(cx); - let Some(tools) = tools_by_source.get(&source) else { - return; - }; + self.enabled_sources.insert(source.clone()); - self.enabled_tools_by_source.insert( - source, - tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(), - ); + let tools_by_source = self.tools_by_source(cx); + if let Some(tools) = tools_by_source.get(&source) { + self.enabled_tools_by_source.insert( + source, + tools + .into_iter() + .map(|tool| tool.name().into()) + .collect::>(), + ); + } } fn disable_source(&mut self, source: &ToolSource) { + self.enabled_sources.remove(source); self.enabled_tools_by_source.remove(source); }