diff --git a/crates/assistant2/src/profile_selector.rs b/crates/assistant2/src/profile_selector.rs index b9b2c45773..c0ba19016c 100644 --- a/crates/assistant2/src/profile_selector.rs +++ b/crates/assistant2/src/profile_selector.rs @@ -78,7 +78,7 @@ impl ProfileSelector { thread_store .update(cx, |this, cx| { - this.load_default_profile(cx); + this.load_profile_by_id(&profile_id, cx); }) .log_err(); } diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index b2ace1a2ac..aa2e0aa971 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use std::sync::Arc; use anyhow::{anyhow, Result}; -use assistant_settings::AssistantSettings; +use assistant_settings::{AgentProfile, AssistantSettings}; use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; @@ -187,35 +187,42 @@ impl ThreadStore { }) } - pub fn load_default_profile(&self, cx: &mut Context) { + fn load_default_profile(&self, cx: &Context) { let assistant_settings = AssistantSettings::get_global(cx); - if let Some(profile) = assistant_settings - .profiles - .get(&assistant_settings.default_profile) - { - self.tools.disable_source(ToolSource::Native, cx); + self.load_profile_by_id(&assistant_settings.default_profile, cx); + } + + pub fn load_profile_by_id(&self, profile_id: &Arc, cx: &Context) { + let assistant_settings = AssistantSettings::get_global(cx); + + if let Some(profile) = assistant_settings.profiles.get(profile_id) { + self.load_profile(profile); + } + } + + pub fn load_profile(&self, profile: &AgentProfile) { + self.tools.disable_all_tools(); + self.tools.enable( + ToolSource::Native, + &profile + .tools + .iter() + .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) + .collect::>(), + ); + + for (context_server_id, preset) in &profile.context_servers { self.tools.enable( - ToolSource::Native, - &profile + ToolSource::ContextServer { + id: context_server_id.clone().into(), + }, + &preset .tools .iter() .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) .collect::>(), - ); - - for (context_server_id, preset) in &profile.context_servers { - self.tools.enable( - ToolSource::ContextServer { - id: context_server_id.clone().into(), - }, - &preset - .tools - .iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool.clone())) - .collect::>(), - ) - } + ) } } diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 2c1acad053..82a4455920 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -19,7 +19,7 @@ pub struct ToolWorkingSet { struct WorkingSetState { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, - disabled_tools_by_source: HashMap>>, + enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } @@ -41,38 +41,23 @@ impl ToolWorkingSet { self.state.lock().tools_by_source(cx) } - pub fn are_all_tools_enabled(&self) -> bool { - let state = self.state.lock(); - state.disabled_tools_by_source.is_empty() - } - - pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool { - let state = self.state.lock(); - !state.disabled_tools_by_source.contains_key(source) - } - pub fn enabled_tools(&self, cx: &App) -> Vec> { self.state.lock().enabled_tools(cx) } - pub fn enable_all_tools(&self) { + pub fn disable_all_tools(&self) { let mut state = self.state.lock(); - state.disabled_tools_by_source.clear(); + state.disable_all_tools(); } - pub fn disable_all_tools(&self, cx: &App) { + pub fn enable_source(&self, source: ToolSource, cx: &App) { let mut state = self.state.lock(); - state.disable_all_tools(cx); + state.enable_source(source, cx); } - pub fn enable_source(&self, source: &ToolSource) { + pub fn disable_source(&self, source: &ToolSource) { let mut state = self.state.lock(); - state.enable_source(source); - } - - pub fn disable_source(&self, source: ToolSource, cx: &App) { - let mut state = self.state.lock(); - state.disable_source(source, cx); + state.disable_source(source); } pub fn insert(&self, tool: Arc) -> ToolId { @@ -159,40 +144,36 @@ impl WorkingSetState { } fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { - !self.is_disabled(source, name) + self.enabled_tools_by_source + .get(source) + .map_or(false, |enabled_tools| enabled_tools.contains(name)) } fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.disabled_tools_by_source - .get(source) - .map_or(false, |disabled_tools| disabled_tools.contains(name)) + !self.is_enabled(source, name) } fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc]) { - self.disabled_tools_by_source + self.enabled_tools_by_source .entry(source) .or_default() - .retain(|name| !tools_to_enable.contains(name)); + .extend(tools_to_enable.into_iter().cloned()); } fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc]) { - self.disabled_tools_by_source + self.enabled_tools_by_source .entry(source) .or_default() - .extend(tools_to_disable.into_iter().cloned()); + .retain(|name| !tools_to_disable.contains(name)); } - fn enable_source(&mut self, source: &ToolSource) { - self.disabled_tools_by_source.remove(source); - } - - fn disable_source(&mut self, source: ToolSource, cx: &App) { + fn enable_source(&mut self, source: ToolSource, cx: &App) { let tools_by_source = self.tools_by_source(cx); let Some(tools) = tools_by_source.get(&source) else { return; }; - self.disabled_tools_by_source.insert( + self.enabled_tools_by_source.insert( source, tools .into_iter() @@ -201,16 +182,11 @@ impl WorkingSetState { ); } - fn disable_all_tools(&mut self, cx: &App) { - let tools = self.tools_by_source(cx); + fn disable_source(&mut self, source: &ToolSource) { + self.enabled_tools_by_source.remove(source); + } - for (source, tools) in tools { - let tool_names = tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(); - - self.disable(source, &tool_names); - } + fn disable_all_tools(&mut self) { + self.enabled_tools_by_source.clear(); } }