diff --git a/assets/settings/default.json b/assets/settings/default.json index 46dae6ccec..515ca6746d 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -637,6 +637,8 @@ "profiles": { "ask": { "name": "Ask", + // We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default. + // "enable_all_context_servers": true, "tools": { "diagnostics": true, "fetch": true, @@ -650,6 +652,7 @@ }, "write": { "name": "Write", + "enable_all_context_servers": true, "tools": { "bash": true, "batch-tool": true, diff --git a/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs b/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs index c05c351a0d..3cdff03440 100644 --- a/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs +++ b/crates/assistant2/src/assistant_configuration/manage_profiles_modal.rs @@ -227,6 +227,10 @@ impl ManageProfilesModal { .as_ref() .map(|profile| profile.tools.clone()) .unwrap_or_default(), + enable_all_context_servers: base_profile + .as_ref() + .map(|profile| profile.enable_all_context_servers) + .unwrap_or_default(), context_servers: base_profile .map(|profile| profile.context_servers) .unwrap_or_default(), diff --git a/crates/assistant2/src/assistant_configuration/tool_picker.rs b/crates/assistant2/src/assistant_configuration/tool_picker.rs index 5c1ca4bed5..d8bbf449a5 100644 --- a/crates/assistant2/src/assistant_configuration/tool_picker.rs +++ b/crates/assistant2/src/assistant_configuration/tool_picker.rs @@ -191,8 +191,8 @@ impl PickerDelegate for ToolPickerDelegate { let active_profile_id = &AssistantSettings::get_global(cx).default_profile; if active_profile_id == &self.profile_id { self.thread_store - .update(cx, |this, _cx| { - this.load_profile(&self.profile); + .update(cx, |this, cx| { + this.load_profile(&self.profile, cx); }) .log_err(); } @@ -212,6 +212,9 @@ impl PickerDelegate for ToolPickerDelegate { .or_insert_with(|| AgentProfileContent { name: default_profile.name.into(), tools: default_profile.tools, + enable_all_context_servers: Some( + default_profile.enable_all_context_servers, + ), context_servers: default_profile .context_servers .into_iter() diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index 86471215f8..bd540d5d26 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -12,7 +12,8 @@ use context_server::{ContextServerFactoryRegistry, ContextServerTool}; use futures::FutureExt as _; use futures::future::{self, BoxFuture, Shared}; use gpui::{ - App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*, + App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task, + prelude::*, }; use heed::Database; use heed::types::SerdeBincode; @@ -20,7 +21,7 @@ use language_model::{LanguageModelToolUseId, Role, TokenUsage}; use project::Project; use prompt_store::PromptBuilder; use serde::{Deserialize, Serialize}; -use settings::Settings as _; +use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId}; @@ -36,6 +37,7 @@ pub struct ThreadStore { context_server_manager: Entity, context_server_tool_ids: HashMap, Vec>, threads: Vec, + _subscriptions: Vec, } impl ThreadStore { @@ -50,6 +52,10 @@ impl ThreadStore { let context_server_manager = cx.new(|cx| { ContextServerManager::new(context_server_factory_registry, project.clone(), cx) }); + let settings_subscription = + cx.observe_global::(move |this: &mut Self, cx| { + this.load_default_profile(cx); + }); let this = Self { project, @@ -58,6 +64,7 @@ impl ThreadStore { context_server_manager, context_server_tool_ids: HashMap::default(), threads: Vec::new(), + _subscriptions: vec![settings_subscription], }; this.load_default_profile(cx); this.register_context_server_handlers(cx); @@ -197,11 +204,11 @@ impl ThreadStore { let assistant_settings = AssistantSettings::get_global(cx); if let Some(profile) = assistant_settings.profiles.get(profile_id) { - self.load_profile(profile); + self.load_profile(profile, cx); } } - pub fn load_profile(&self, profile: &AgentProfile) { + pub fn load_profile(&self, profile: &AgentProfile, cx: &Context) { self.tools.disable_all_tools(); self.tools.enable( ToolSource::Native, @@ -212,17 +219,28 @@ impl ThreadStore { .collect::>(), ); - for (context_server_id, preset) in &profile.context_servers { - self.tools.enable( - ToolSource::ContextServer { - id: context_server_id.clone().into(), - }, - &preset - .tools - .iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) - .collect::>(), - ) + if profile.enable_all_context_servers { + for context_server in self.context_server_manager.read(cx).all_servers() { + self.tools.enable_source( + ToolSource::ContextServer { + id: context_server.id().into(), + }, + cx, + ); + } + } else { + for (context_server_id, preset) in &profile.context_servers { + self.tools.enable( + ToolSource::ContextServer { + id: context_server_id.clone().into(), + }, + &preset + .tools + .iter() + .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) + .collect::>(), + ) + } } } @@ -273,8 +291,9 @@ impl ThreadStore { }) .collect::>(); - this.update(cx, |this, _cx| { + this.update(cx, |this, cx| { this.context_server_tool_ids.insert(server_id, tool_ids); + this.load_default_profile(cx); }) .log_err(); } @@ -287,6 +306,7 @@ impl ThreadStore { context_server::manager::Event::ServerStopped { server_id } => { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { tool_working_set.remove(&tool_ids); + self.load_default_profile(cx); } } } diff --git a/crates/assistant_settings/src/agent_profile.rs b/crates/assistant_settings/src/agent_profile.rs index 0e9459fcf0..00a7fbcedb 100644 --- a/crates/assistant_settings/src/agent_profile.rs +++ b/crates/assistant_settings/src/agent_profile.rs @@ -9,6 +9,7 @@ pub struct AgentProfile { /// The name of the profile. pub name: SharedString, pub tools: IndexMap, bool>, + pub enable_all_context_servers: bool, pub context_servers: IndexMap, ContextServerPreset>, } diff --git a/crates/assistant_settings/src/assistant_settings.rs b/crates/assistant_settings/src/assistant_settings.rs index 102fad134f..0a5af98ab5 100644 --- a/crates/assistant_settings/src/assistant_settings.rs +++ b/crates/assistant_settings/src/assistant_settings.rs @@ -352,6 +352,7 @@ impl AssistantSettingsContent { AgentProfileContent { name: profile.name.into(), tools: profile.tools, + enable_all_context_servers: Some(profile.enable_all_context_servers), context_servers: profile .context_servers .into_iter() @@ -485,6 +486,8 @@ impl Default for LanguageModelSelection { pub struct AgentProfileContent { pub name: Arc, pub tools: IndexMap, bool>, + /// Whether all context servers are enabled by default. + pub enable_all_context_servers: Option, #[serde(default)] pub context_servers: IndexMap, ContextServerPresetContent>, } @@ -607,6 +610,9 @@ impl Settings for AssistantSettings { AgentProfile { name: profile.name.into(), tools: profile.tools, + enable_all_context_servers: profile + .enable_all_context_servers + .unwrap_or_default(), context_servers: profile .context_servers .into_iter() diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 82a4455920..97060cfdad 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -19,6 +19,7 @@ pub struct ToolWorkingSet { struct WorkingSetState { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, + enabled_sources: HashSet, enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } @@ -168,21 +169,22 @@ impl WorkingSetState { } fn enable_source(&mut self, source: ToolSource, cx: &App) { - let tools_by_source = self.tools_by_source(cx); - let Some(tools) = tools_by_source.get(&source) else { - return; - }; + self.enabled_sources.insert(source.clone()); - self.enabled_tools_by_source.insert( - source, - tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(), - ); + let tools_by_source = self.tools_by_source(cx); + if let Some(tools) = tools_by_source.get(&source) { + self.enabled_tools_by_source.insert( + source, + tools + .into_iter() + .map(|tool| tool.name().into()) + .collect::>(), + ); + } } fn disable_source(&mut self, source: &ToolSource) { + self.enabled_sources.remove(source); self.enabled_tools_by_source.remove(source); }