assistant2: Allow profiles to enable all context servers (#27847)
This PR adds a new `enable_all_context_servers` field to agent profiles to allow them to enable all context servers without having to opt into them individually. The "Write" profile will now have all context servers enabled out of the box. Release Notes: - N/A
This commit is contained in:
parent
ab31eb5d51
commit
12037dc2c6
7 changed files with 68 additions and 29 deletions
|
@ -227,6 +227,10 @@ impl ManageProfilesModal {
|
|||
.as_ref()
|
||||
.map(|profile| profile.tools.clone())
|
||||
.unwrap_or_default(),
|
||||
enable_all_context_servers: base_profile
|
||||
.as_ref()
|
||||
.map(|profile| profile.enable_all_context_servers)
|
||||
.unwrap_or_default(),
|
||||
context_servers: base_profile
|
||||
.map(|profile| profile.context_servers)
|
||||
.unwrap_or_default(),
|
||||
|
|
|
@ -191,8 +191,8 @@ impl PickerDelegate for ToolPickerDelegate {
|
|||
let active_profile_id = &AssistantSettings::get_global(cx).default_profile;
|
||||
if active_profile_id == &self.profile_id {
|
||||
self.thread_store
|
||||
.update(cx, |this, _cx| {
|
||||
this.load_profile(&self.profile);
|
||||
.update(cx, |this, cx| {
|
||||
this.load_profile(&self.profile, cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
@ -212,6 +212,9 @@ impl PickerDelegate for ToolPickerDelegate {
|
|||
.or_insert_with(|| AgentProfileContent {
|
||||
name: default_profile.name.into(),
|
||||
tools: default_profile.tools,
|
||||
enable_all_context_servers: Some(
|
||||
default_profile.enable_all_context_servers,
|
||||
),
|
||||
context_servers: default_profile
|
||||
.context_servers
|
||||
.into_iter()
|
||||
|
|
|
@ -12,7 +12,8 @@ use context_server::{ContextServerFactoryRegistry, ContextServerTool};
|
|||
use futures::FutureExt as _;
|
||||
use futures::future::{self, BoxFuture, Shared};
|
||||
use gpui::{
|
||||
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*,
|
||||
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
|
||||
prelude::*,
|
||||
};
|
||||
use heed::Database;
|
||||
use heed::types::SerdeBincode;
|
||||
|
@ -20,7 +21,7 @@ use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
|||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings as _;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
|
||||
|
@ -36,6 +37,7 @@ pub struct ThreadStore {
|
|||
context_server_manager: Entity<ContextServerManager>,
|
||||
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
|
||||
threads: Vec<SerializedThreadMetadata>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
impl ThreadStore {
|
||||
|
@ -50,6 +52,10 @@ impl ThreadStore {
|
|||
let context_server_manager = cx.new(|cx| {
|
||||
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
|
||||
});
|
||||
let settings_subscription =
|
||||
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
|
||||
this.load_default_profile(cx);
|
||||
});
|
||||
|
||||
let this = Self {
|
||||
project,
|
||||
|
@ -58,6 +64,7 @@ impl ThreadStore {
|
|||
context_server_manager,
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
_subscriptions: vec![settings_subscription],
|
||||
};
|
||||
this.load_default_profile(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
|
@ -197,11 +204,11 @@ impl ThreadStore {
|
|||
let assistant_settings = AssistantSettings::get_global(cx);
|
||||
|
||||
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
|
||||
self.load_profile(profile);
|
||||
self.load_profile(profile, cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_profile(&self, profile: &AgentProfile) {
|
||||
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
|
||||
self.tools.disable_all_tools();
|
||||
self.tools.enable(
|
||||
ToolSource::Native,
|
||||
|
@ -212,17 +219,28 @@ impl ThreadStore {
|
|||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
for (context_server_id, preset) in &profile.context_servers {
|
||||
self.tools.enable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.clone().into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
if profile.enable_all_context_servers {
|
||||
for context_server in self.context_server_manager.read(cx).all_servers() {
|
||||
self.tools.enable_source(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server.id().into(),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
for (context_server_id, preset) in &profile.context_servers {
|
||||
self.tools.enable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.clone().into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -273,8 +291,9 @@ impl ThreadStore {
|
|||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
this.update(cx, |this, _cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.context_server_tool_ids.insert(server_id, tool_ids);
|
||||
this.load_default_profile(cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
@ -287,6 +306,7 @@ impl ThreadStore {
|
|||
context_server::manager::Event::ServerStopped { server_id } => {
|
||||
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
self.load_default_profile(cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ pub struct AgentProfile {
|
|||
/// The name of the profile.
|
||||
pub name: SharedString,
|
||||
pub tools: IndexMap<Arc<str>, bool>,
|
||||
pub enable_all_context_servers: bool,
|
||||
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
|
||||
}
|
||||
|
||||
|
|
|
@ -352,6 +352,7 @@ impl AssistantSettingsContent {
|
|||
AgentProfileContent {
|
||||
name: profile.name.into(),
|
||||
tools: profile.tools,
|
||||
enable_all_context_servers: Some(profile.enable_all_context_servers),
|
||||
context_servers: profile
|
||||
.context_servers
|
||||
.into_iter()
|
||||
|
@ -485,6 +486,8 @@ impl Default for LanguageModelSelection {
|
|||
pub struct AgentProfileContent {
|
||||
pub name: Arc<str>,
|
||||
pub tools: IndexMap<Arc<str>, bool>,
|
||||
/// Whether all context servers are enabled by default.
|
||||
pub enable_all_context_servers: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub context_servers: IndexMap<Arc<str>, ContextServerPresetContent>,
|
||||
}
|
||||
|
@ -607,6 +610,9 @@ impl Settings for AssistantSettings {
|
|||
AgentProfile {
|
||||
name: profile.name.into(),
|
||||
tools: profile.tools,
|
||||
enable_all_context_servers: profile
|
||||
.enable_all_context_servers
|
||||
.unwrap_or_default(),
|
||||
context_servers: profile
|
||||
.context_servers
|
||||
.into_iter()
|
||||
|
|
|
@ -19,6 +19,7 @@ pub struct ToolWorkingSet {
|
|||
struct WorkingSetState {
|
||||
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
|
||||
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
|
||||
enabled_sources: HashSet<ToolSource>,
|
||||
enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
|
||||
next_tool_id: ToolId,
|
||||
}
|
||||
|
@ -168,21 +169,22 @@ impl WorkingSetState {
|
|||
}
|
||||
|
||||
fn enable_source(&mut self, source: ToolSource, cx: &App) {
|
||||
let tools_by_source = self.tools_by_source(cx);
|
||||
let Some(tools) = tools_by_source.get(&source) else {
|
||||
return;
|
||||
};
|
||||
self.enabled_sources.insert(source.clone());
|
||||
|
||||
self.enabled_tools_by_source.insert(
|
||||
source,
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| tool.name().into())
|
||||
.collect::<HashSet<_>>(),
|
||||
);
|
||||
let tools_by_source = self.tools_by_source(cx);
|
||||
if let Some(tools) = tools_by_source.get(&source) {
|
||||
self.enabled_tools_by_source.insert(
|
||||
source,
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| tool.name().into())
|
||||
.collect::<HashSet<_>>(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn disable_source(&mut self, source: &ToolSource) {
|
||||
self.enabled_sources.remove(source);
|
||||
self.enabled_tools_by_source.remove(source);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue