From 848a99c605035b6bf26a3330319659b93fcd24f0 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 26 Mar 2025 16:26:26 -0400 Subject: [PATCH] assistant2: Rework enabled tool representation (#27527) This PR reworks how we store enabled tools in the `ToolWorkingSet`. We now track them based on which tools are explicitly enabled, rather than by the tools that have been disabled. Also fixed an issue where switching profiles wouldn't properly set the right tools. Release Notes: - N/A --- crates/assistant2/src/profile_selector.rs | 2 +- crates/assistant2/src/thread_store.rs | 53 ++++++++------- crates/assistant_tool/src/tool_working_set.rs | 68 ++++++------------- 3 files changed, 53 insertions(+), 70 deletions(-) diff --git a/crates/assistant2/src/profile_selector.rs b/crates/assistant2/src/profile_selector.rs index b9b2c45773..c0ba19016c 100644 --- a/crates/assistant2/src/profile_selector.rs +++ b/crates/assistant2/src/profile_selector.rs @@ -78,7 +78,7 @@ impl ProfileSelector { thread_store .update(cx, |this, cx| { - this.load_default_profile(cx); + this.load_profile_by_id(&profile_id, cx); }) .log_err(); } diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index b2ace1a2ac..aa2e0aa971 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use std::sync::Arc; use anyhow::{anyhow, Result}; -use assistant_settings::AssistantSettings; +use assistant_settings::{AgentProfile, AssistantSettings}; use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; @@ -187,35 +187,42 @@ impl ThreadStore { }) } - pub fn load_default_profile(&self, cx: &mut Context) { + fn load_default_profile(&self, cx: &Context) { let assistant_settings = AssistantSettings::get_global(cx); - if let Some(profile) = assistant_settings - .profiles - .get(&assistant_settings.default_profile) - { - self.tools.disable_source(ToolSource::Native, cx); + self.load_profile_by_id(&assistant_settings.default_profile, cx); + } + + pub fn load_profile_by_id(&self, profile_id: &Arc, cx: &Context) { + let assistant_settings = AssistantSettings::get_global(cx); + + if let Some(profile) = assistant_settings.profiles.get(profile_id) { + self.load_profile(profile); + } + } + + pub fn load_profile(&self, profile: &AgentProfile) { + self.tools.disable_all_tools(); + self.tools.enable( + ToolSource::Native, + &profile + .tools + .iter() + .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) + .collect::>(), + ); + + for (context_server_id, preset) in &profile.context_servers { self.tools.enable( - ToolSource::Native, - &profile + ToolSource::ContextServer { + id: context_server_id.clone().into(), + }, + &preset .tools .iter() .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) .collect::>(), - ); - - for (context_server_id, preset) in &profile.context_servers { - self.tools.enable( - ToolSource::ContextServer { - id: context_server_id.clone().into(), - }, - &preset - .tools - .iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) - .collect::>(), - ) - } + ) } } diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 2c1acad053..82a4455920 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -19,7 +19,7 @@ pub struct ToolWorkingSet { struct WorkingSetState { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, - disabled_tools_by_source: HashMap>>, + enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } @@ -41,38 +41,23 @@ impl ToolWorkingSet { self.state.lock().tools_by_source(cx) } - pub fn are_all_tools_enabled(&self) -> bool { - let state = self.state.lock(); - state.disabled_tools_by_source.is_empty() - } - - pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool { - let state = self.state.lock(); - !state.disabled_tools_by_source.contains_key(source) - } - pub fn enabled_tools(&self, cx: &App) -> Vec> { self.state.lock().enabled_tools(cx) } - pub fn enable_all_tools(&self) { + pub fn disable_all_tools(&self) { let mut state = self.state.lock(); - state.disabled_tools_by_source.clear(); + state.disable_all_tools(); } - pub fn disable_all_tools(&self, cx: &App) { + pub fn enable_source(&self, source: ToolSource, cx: &App) { let mut state = self.state.lock(); - state.disable_all_tools(cx); + state.enable_source(source, cx); } - pub fn enable_source(&self, source: &ToolSource) { + pub fn disable_source(&self, source: &ToolSource) { let mut state = self.state.lock(); - state.enable_source(source); - } - - pub fn disable_source(&self, source: ToolSource, cx: &App) { - let mut state = self.state.lock(); - state.disable_source(source, cx); + state.disable_source(source); } pub fn insert(&self, tool: Arc) -> ToolId { @@ -159,40 +144,36 @@ impl WorkingSetState { } fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { - !self.is_disabled(source, name) + self.enabled_tools_by_source + .get(source) + .map_or(false, |enabled_tools| enabled_tools.contains(name)) } fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.disabled_tools_by_source - .get(source) - .map_or(false, |disabled_tools| disabled_tools.contains(name)) + !self.is_enabled(source, name) } fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc]) { - self.disabled_tools_by_source + self.enabled_tools_by_source .entry(source) .or_default() - .retain(|name| !tools_to_enable.contains(name)); + .extend(tools_to_enable.into_iter().cloned()); } fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc]) { - self.disabled_tools_by_source + self.enabled_tools_by_source .entry(source) .or_default() - .extend(tools_to_disable.into_iter().cloned()); + .retain(|name| !tools_to_disable.contains(name)); } - fn enable_source(&mut self, source: &ToolSource) { - self.disabled_tools_by_source.remove(source); - } - - fn disable_source(&mut self, source: ToolSource, cx: &App) { + fn enable_source(&mut self, source: ToolSource, cx: &App) { let tools_by_source = self.tools_by_source(cx); let Some(tools) = tools_by_source.get(&source) else { return; }; - self.disabled_tools_by_source.insert( + self.enabled_tools_by_source.insert( source, tools .into_iter() @@ -201,16 +182,11 @@ impl WorkingSetState { ); } - fn disable_all_tools(&mut self, cx: &App) { - let tools = self.tools_by_source(cx); + fn disable_source(&mut self, source: &ToolSource) { + self.enabled_tools_by_source.remove(source); + } - for (source, tools) in tools { - let tool_names = tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(); - - self.disable(source, &tool_names); - } + fn disable_all_tools(&mut self) { + self.enabled_tools_by_source.clear(); } }