assistant2: Allow profiles to enable all context servers (#27847)

This PR adds a new `enable_all_context_servers` field to agent profiles
to allow them to enable all context servers without having to opt into
them individually.

The "Write" profile will now have all context servers enabled out of the
box.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-01 11:25:23 -04:00 committed by GitHub
parent ab31eb5d51
commit 12037dc2c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 68 additions and 29 deletions

View file

@ -227,6 +227,10 @@ impl ManageProfilesModal {
.as_ref()
.map(|profile| profile.tools.clone())
.unwrap_or_default(),
enable_all_context_servers: base_profile
.as_ref()
.map(|profile| profile.enable_all_context_servers)
.unwrap_or_default(),
context_servers: base_profile
.map(|profile| profile.context_servers)
.unwrap_or_default(),

View file

@ -191,8 +191,8 @@ impl PickerDelegate for ToolPickerDelegate {
let active_profile_id = &AssistantSettings::get_global(cx).default_profile;
if active_profile_id == &self.profile_id {
self.thread_store
.update(cx, |this, _cx| {
this.load_profile(&self.profile);
.update(cx, |this, cx| {
this.load_profile(&self.profile, cx);
})
.log_err();
}
@ -212,6 +212,9 @@ impl PickerDelegate for ToolPickerDelegate {
.or_insert_with(|| AgentProfileContent {
name: default_profile.name.into(),
tools: default_profile.tools,
enable_all_context_servers: Some(
default_profile.enable_all_context_servers,
),
context_servers: default_profile
.context_servers
.into_iter()

View file

@ -12,7 +12,8 @@ use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use futures::FutureExt as _;
use futures::future::{self, BoxFuture, Shared};
use gpui::{
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*,
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
prelude::*,
};
use heed::Database;
use heed::types::SerdeBincode;
@ -20,7 +21,7 @@ use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::Project;
use prompt_store::PromptBuilder;
use serde::{Deserialize, Serialize};
use settings::Settings as _;
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
@ -36,6 +37,7 @@ pub struct ThreadStore {
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
_subscriptions: Vec<Subscription>,
}
impl ThreadStore {
@ -50,6 +52,10 @@ impl ThreadStore {
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let settings_subscription =
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
});
let this = Self {
project,
@ -58,6 +64,7 @@ impl ThreadStore {
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
_subscriptions: vec![settings_subscription],
};
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
@ -197,11 +204,11 @@ impl ThreadStore {
let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
self.load_profile(profile);
self.load_profile(profile, cx);
}
}
pub fn load_profile(&self, profile: &AgentProfile) {
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
self.tools.disable_all_tools();
self.tools.enable(
ToolSource::Native,
@ -212,17 +219,28 @@ impl ThreadStore {
.collect::<Vec<_>>(),
);
for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
if profile.enable_all_context_servers {
for context_server in self.context_server_manager.read(cx).all_servers() {
self.tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
}
} else {
for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
}
}
}
@ -273,8 +291,9 @@ impl ThreadStore {
})
.collect::<Vec<_>>();
this.update(cx, |this, _cx| {
this.update(cx, |this, cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
@ -287,6 +306,7 @@ impl ThreadStore {
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
self.load_default_profile(cx);
}
}
}

View file

@ -9,6 +9,7 @@ pub struct AgentProfile {
/// The name of the profile.
pub name: SharedString,
pub tools: IndexMap<Arc<str>, bool>,
pub enable_all_context_servers: bool,
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
}

View file

@ -352,6 +352,7 @@ impl AssistantSettingsContent {
AgentProfileContent {
name: profile.name.into(),
tools: profile.tools,
enable_all_context_servers: Some(profile.enable_all_context_servers),
context_servers: profile
.context_servers
.into_iter()
@ -485,6 +486,8 @@ impl Default for LanguageModelSelection {
pub struct AgentProfileContent {
pub name: Arc<str>,
pub tools: IndexMap<Arc<str>, bool>,
/// Whether all context servers are enabled by default.
pub enable_all_context_servers: Option<bool>,
#[serde(default)]
pub context_servers: IndexMap<Arc<str>, ContextServerPresetContent>,
}
@ -607,6 +610,9 @@ impl Settings for AssistantSettings {
AgentProfile {
name: profile.name.into(),
tools: profile.tools,
enable_all_context_servers: profile
.enable_all_context_servers
.unwrap_or_default(),
context_servers: profile
.context_servers
.into_iter()

View file

@ -19,6 +19,7 @@ pub struct ToolWorkingSet {
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
enabled_sources: HashSet<ToolSource>,
enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
next_tool_id: ToolId,
}
@ -168,21 +169,22 @@ impl WorkingSetState {
}
fn enable_source(&mut self, source: ToolSource, cx: &App) {
let tools_by_source = self.tools_by_source(cx);
let Some(tools) = tools_by_source.get(&source) else {
return;
};
self.enabled_sources.insert(source.clone());
self.enabled_tools_by_source.insert(
source,
tools
.into_iter()
.map(|tool| tool.name().into())
.collect::<HashSet<_>>(),
);
let tools_by_source = self.tools_by_source(cx);
if let Some(tools) = tools_by_source.get(&source) {
self.enabled_tools_by_source.insert(
source,
tools
.into_iter()
.map(|tool| tool.name().into())
.collect::<HashSet<_>>(),
);
}
}
fn disable_source(&mut self, source: &ToolSource) {
self.enabled_sources.remove(source);
self.enabled_tools_by_source.remove(source);
}