assistant2: Rework enabled tool representation (#27527)

This PR reworks how we store enabled tools in the `ToolWorkingSet`.

We now track them based on which tools are explicitly enabled, rather
than by the tools that have been disabled.

Also fixed an issue where switching profiles wouldn't properly set the
right tools.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-26 16:26:26 -04:00 committed by GitHub
parent 435a36b9f9
commit 848a99c605
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 53 additions and 70 deletions

View file

@ -78,7 +78,7 @@ impl ProfileSelector {
thread_store thread_store
.update(cx, |this, cx| { .update(cx, |this, cx| {
this.load_default_profile(cx); this.load_profile_by_id(&profile_id, cx);
}) })
.log_err(); .log_err();
} }

View file

@ -3,7 +3,7 @@ use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use assistant_settings::AssistantSettings; use assistant_settings::{AgentProfile, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use collections::HashMap; use collections::HashMap;
@ -187,35 +187,42 @@ impl ThreadStore {
}) })
} }
pub fn load_default_profile(&self, cx: &mut Context<Self>) { fn load_default_profile(&self, cx: &Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx); let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings self.load_profile_by_id(&assistant_settings.default_profile, cx);
.profiles }
.get(&assistant_settings.default_profile)
{ pub fn load_profile_by_id(&self, profile_id: &Arc<str>, cx: &Context<Self>) {
self.tools.disable_source(ToolSource::Native, cx); let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
self.load_profile(profile);
}
}
pub fn load_profile(&self, profile: &AgentProfile) {
self.tools.disable_all_tools();
self.tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
);
for (context_server_id, preset) in &profile.context_servers {
self.tools.enable( self.tools.enable(
ToolSource::Native, ToolSource::ContextServer {
&profile id: context_server_id.clone().into(),
},
&preset
.tools .tools
.iter() .iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
); )
for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
}
} }
} }

View file

@ -19,7 +19,7 @@ pub struct ToolWorkingSet {
struct WorkingSetState { struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>, context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>, context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>, enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
next_tool_id: ToolId, next_tool_id: ToolId,
} }
@ -41,38 +41,23 @@ impl ToolWorkingSet {
self.state.lock().tools_by_source(cx) self.state.lock().tools_by_source(cx)
} }
pub fn are_all_tools_enabled(&self) -> bool {
let state = self.state.lock();
state.disabled_tools_by_source.is_empty()
}
pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool {
let state = self.state.lock();
!state.disabled_tools_by_source.contains_key(source)
}
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> { pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().enabled_tools(cx) self.state.lock().enabled_tools(cx)
} }
pub fn enable_all_tools(&self) { pub fn disable_all_tools(&self) {
let mut state = self.state.lock(); let mut state = self.state.lock();
state.disabled_tools_by_source.clear(); state.disable_all_tools();
} }
pub fn disable_all_tools(&self, cx: &App) { pub fn enable_source(&self, source: ToolSource, cx: &App) {
let mut state = self.state.lock(); let mut state = self.state.lock();
state.disable_all_tools(cx); state.enable_source(source, cx);
} }
pub fn enable_source(&self, source: &ToolSource) { pub fn disable_source(&self, source: &ToolSource) {
let mut state = self.state.lock(); let mut state = self.state.lock();
state.enable_source(source); state.disable_source(source);
}
pub fn disable_source(&self, source: ToolSource, cx: &App) {
let mut state = self.state.lock();
state.disable_source(source, cx);
} }
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId { pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
@ -159,40 +144,36 @@ impl WorkingSetState {
} }
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool { fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_disabled(source, name) self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
} }
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool { fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.disabled_tools_by_source !self.is_enabled(source, name)
.get(source)
.map_or(false, |disabled_tools| disabled_tools.contains(name))
} }
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) { fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
self.disabled_tools_by_source self.enabled_tools_by_source
.entry(source) .entry(source)
.or_default() .or_default()
.retain(|name| !tools_to_enable.contains(name)); .extend(tools_to_enable.into_iter().cloned());
} }
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) { fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
self.disabled_tools_by_source self.enabled_tools_by_source
.entry(source) .entry(source)
.or_default() .or_default()
.extend(tools_to_disable.into_iter().cloned()); .retain(|name| !tools_to_disable.contains(name));
} }
fn enable_source(&mut self, source: &ToolSource) { fn enable_source(&mut self, source: ToolSource, cx: &App) {
self.disabled_tools_by_source.remove(source);
}
fn disable_source(&mut self, source: ToolSource, cx: &App) {
let tools_by_source = self.tools_by_source(cx); let tools_by_source = self.tools_by_source(cx);
let Some(tools) = tools_by_source.get(&source) else { let Some(tools) = tools_by_source.get(&source) else {
return; return;
}; };
self.disabled_tools_by_source.insert( self.enabled_tools_by_source.insert(
source, source,
tools tools
.into_iter() .into_iter()
@ -201,16 +182,11 @@ impl WorkingSetState {
); );
} }
fn disable_all_tools(&mut self, cx: &App) { fn disable_source(&mut self, source: &ToolSource) {
let tools = self.tools_by_source(cx); self.enabled_tools_by_source.remove(source);
}
for (source, tools) in tools { fn disable_all_tools(&mut self) {
let tool_names = tools self.enabled_tools_by_source.clear();
.into_iter()
.map(|tool| tool.name().into())
.collect::<Vec<_>>();
self.disable(source, &tool_names);
}
} }
} }