assistant2: Allow profiles to enable all context servers (#27847)

This PR adds a new `enable_all_context_servers` field to agent profiles
to allow them to enable all context servers without having to opt into
them individually.

The "Write" profile will now have all context servers enabled out of the
box.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-01 11:25:23 -04:00 committed by GitHub
parent ab31eb5d51
commit 12037dc2c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 68 additions and 29 deletions

View file

@ -637,6 +637,8 @@
"profiles": { "profiles": {
"ask": { "ask": {
"name": "Ask", "name": "Ask",
// We don't know which of the context server tools are safe for the "Ask" profile, so we don't enable them by default.
// "enable_all_context_servers": true,
"tools": { "tools": {
"diagnostics": true, "diagnostics": true,
"fetch": true, "fetch": true,
@ -650,6 +652,7 @@
}, },
"write": { "write": {
"name": "Write", "name": "Write",
"enable_all_context_servers": true,
"tools": { "tools": {
"bash": true, "bash": true,
"batch-tool": true, "batch-tool": true,

View file

@ -227,6 +227,10 @@ impl ManageProfilesModal {
.as_ref() .as_ref()
.map(|profile| profile.tools.clone()) .map(|profile| profile.tools.clone())
.unwrap_or_default(), .unwrap_or_default(),
enable_all_context_servers: base_profile
.as_ref()
.map(|profile| profile.enable_all_context_servers)
.unwrap_or_default(),
context_servers: base_profile context_servers: base_profile
.map(|profile| profile.context_servers) .map(|profile| profile.context_servers)
.unwrap_or_default(), .unwrap_or_default(),

View file

@ -191,8 +191,8 @@ impl PickerDelegate for ToolPickerDelegate {
let active_profile_id = &AssistantSettings::get_global(cx).default_profile; let active_profile_id = &AssistantSettings::get_global(cx).default_profile;
if active_profile_id == &self.profile_id { if active_profile_id == &self.profile_id {
self.thread_store self.thread_store
.update(cx, |this, _cx| { .update(cx, |this, cx| {
this.load_profile(&self.profile); this.load_profile(&self.profile, cx);
}) })
.log_err(); .log_err();
} }
@ -212,6 +212,9 @@ impl PickerDelegate for ToolPickerDelegate {
.or_insert_with(|| AgentProfileContent { .or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(), name: default_profile.name.into(),
tools: default_profile.tools, tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile context_servers: default_profile
.context_servers .context_servers
.into_iter() .into_iter()

View file

@ -12,7 +12,8 @@ use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use futures::FutureExt as _; use futures::FutureExt as _;
use futures::future::{self, BoxFuture, Shared}; use futures::future::{self, BoxFuture, Shared};
use gpui::{ use gpui::{
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
prelude::*,
}; };
use heed::Database; use heed::Database;
use heed::types::SerdeBincode; use heed::types::SerdeBincode;
@ -20,7 +21,7 @@ use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::Project; use project::Project;
use prompt_store::PromptBuilder; use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::Settings as _; use settings::{Settings as _, SettingsStore};
use util::ResultExt as _; use util::ResultExt as _;
use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId}; use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
@ -36,6 +37,7 @@ pub struct ThreadStore {
context_server_manager: Entity<ContextServerManager>, context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>, context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>, threads: Vec<SerializedThreadMetadata>,
_subscriptions: Vec<Subscription>,
} }
impl ThreadStore { impl ThreadStore {
@ -50,6 +52,10 @@ impl ThreadStore {
let context_server_manager = cx.new(|cx| { let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx) ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
}); });
let settings_subscription =
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
});
let this = Self { let this = Self {
project, project,
@ -58,6 +64,7 @@ impl ThreadStore {
context_server_manager, context_server_manager,
context_server_tool_ids: HashMap::default(), context_server_tool_ids: HashMap::default(),
threads: Vec::new(), threads: Vec::new(),
_subscriptions: vec![settings_subscription],
}; };
this.load_default_profile(cx); this.load_default_profile(cx);
this.register_context_server_handlers(cx); this.register_context_server_handlers(cx);
@ -197,11 +204,11 @@ impl ThreadStore {
let assistant_settings = AssistantSettings::get_global(cx); let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) { if let Some(profile) = assistant_settings.profiles.get(profile_id) {
self.load_profile(profile); self.load_profile(profile, cx);
} }
} }
pub fn load_profile(&self, profile: &AgentProfile) { pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
self.tools.disable_all_tools(); self.tools.disable_all_tools();
self.tools.enable( self.tools.enable(
ToolSource::Native, ToolSource::Native,
@ -212,17 +219,28 @@ impl ThreadStore {
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
); );
for (context_server_id, preset) in &profile.context_servers { if profile.enable_all_context_servers {
self.tools.enable( for context_server in self.context_server_manager.read(cx).all_servers() {
ToolSource::ContextServer { self.tools.enable_source(
id: context_server_id.clone().into(), ToolSource::ContextServer {
}, id: context_server.id().into(),
&preset },
.tools cx,
.iter() );
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) }
.collect::<Vec<_>>(), } else {
) for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
}
} }
} }
@ -273,8 +291,9 @@ impl ThreadStore {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
this.update(cx, |this, _cx| { this.update(cx, |this, cx| {
this.context_server_tool_ids.insert(server_id, tool_ids); this.context_server_tool_ids.insert(server_id, tool_ids);
this.load_default_profile(cx);
}) })
.log_err(); .log_err();
} }
@ -287,6 +306,7 @@ impl ThreadStore {
context_server::manager::Event::ServerStopped { server_id } => { context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids); tool_working_set.remove(&tool_ids);
self.load_default_profile(cx);
} }
} }
} }

View file

@ -9,6 +9,7 @@ pub struct AgentProfile {
/// The name of the profile. /// The name of the profile.
pub name: SharedString, pub name: SharedString,
pub tools: IndexMap<Arc<str>, bool>, pub tools: IndexMap<Arc<str>, bool>,
pub enable_all_context_servers: bool,
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>, pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
} }

View file

@ -352,6 +352,7 @@ impl AssistantSettingsContent {
AgentProfileContent { AgentProfileContent {
name: profile.name.into(), name: profile.name.into(),
tools: profile.tools, tools: profile.tools,
enable_all_context_servers: Some(profile.enable_all_context_servers),
context_servers: profile context_servers: profile
.context_servers .context_servers
.into_iter() .into_iter()
@ -485,6 +486,8 @@ impl Default for LanguageModelSelection {
pub struct AgentProfileContent { pub struct AgentProfileContent {
pub name: Arc<str>, pub name: Arc<str>,
pub tools: IndexMap<Arc<str>, bool>, pub tools: IndexMap<Arc<str>, bool>,
/// Whether all context servers are enabled by default.
pub enable_all_context_servers: Option<bool>,
#[serde(default)] #[serde(default)]
pub context_servers: IndexMap<Arc<str>, ContextServerPresetContent>, pub context_servers: IndexMap<Arc<str>, ContextServerPresetContent>,
} }
@ -607,6 +610,9 @@ impl Settings for AssistantSettings {
AgentProfile { AgentProfile {
name: profile.name.into(), name: profile.name.into(),
tools: profile.tools, tools: profile.tools,
enable_all_context_servers: profile
.enable_all_context_servers
.unwrap_or_default(),
context_servers: profile context_servers: profile
.context_servers .context_servers
.into_iter() .into_iter()

View file

@ -19,6 +19,7 @@ pub struct ToolWorkingSet {
struct WorkingSetState { struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>, context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>, context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
enabled_sources: HashSet<ToolSource>,
enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>, enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
next_tool_id: ToolId, next_tool_id: ToolId,
} }
@ -168,21 +169,22 @@ impl WorkingSetState {
} }
fn enable_source(&mut self, source: ToolSource, cx: &App) { fn enable_source(&mut self, source: ToolSource, cx: &App) {
let tools_by_source = self.tools_by_source(cx); self.enabled_sources.insert(source.clone());
let Some(tools) = tools_by_source.get(&source) else {
return;
};
self.enabled_tools_by_source.insert( let tools_by_source = self.tools_by_source(cx);
source, if let Some(tools) = tools_by_source.get(&source) {
tools self.enabled_tools_by_source.insert(
.into_iter() source,
.map(|tool| tool.name().into()) tools
.collect::<HashSet<_>>(), .into_iter()
); .map(|tool| tool.name().into())
.collect::<HashSet<_>>(),
);
}
} }
fn disable_source(&mut self, source: &ToolSource) { fn disable_source(&mut self, source: &ToolSource) {
self.enabled_sources.remove(source);
self.enabled_tools_by_source.remove(source); self.enabled_tools_by_source.remove(source);
} }