assistant2: Rework enabled tool representation (#27527)
This PR reworks how we store enabled tools in the `ToolWorkingSet`. We now track them based on which tools are explicitly enabled, rather than by the tools that have been disabled. Also fixed an issue where switching profiles wouldn't properly set the right tools. Release Notes: - N/A
This commit is contained in:
parent
435a36b9f9
commit
848a99c605
3 changed files with 53 additions and 70 deletions
|
@ -78,7 +78,7 @@ impl ProfileSelector {
|
|||
|
||||
thread_store
|
||||
.update(cx, |this, cx| {
|
||||
this.load_default_profile(cx);
|
||||
this.load_profile_by_id(&profile_id, cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::path::PathBuf;
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_settings::{AgentProfile, AssistantSettings};
|
||||
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
|
@ -187,14 +187,22 @@ impl ThreadStore {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn load_default_profile(&self, cx: &mut Context<Self>) {
|
||||
fn load_default_profile(&self, cx: &Context<Self>) {
|
||||
let assistant_settings = AssistantSettings::get_global(cx);
|
||||
|
||||
if let Some(profile) = assistant_settings
|
||||
.profiles
|
||||
.get(&assistant_settings.default_profile)
|
||||
{
|
||||
self.tools.disable_source(ToolSource::Native, cx);
|
||||
self.load_profile_by_id(&assistant_settings.default_profile, cx);
|
||||
}
|
||||
|
||||
pub fn load_profile_by_id(&self, profile_id: &Arc<str>, cx: &Context<Self>) {
|
||||
let assistant_settings = AssistantSettings::get_global(cx);
|
||||
|
||||
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
|
||||
self.load_profile(profile);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_profile(&self, profile: &AgentProfile) {
|
||||
self.tools.disable_all_tools();
|
||||
self.tools.enable(
|
||||
ToolSource::Native,
|
||||
&profile
|
||||
|
@ -217,7 +225,6 @@ impl ThreadStore {
|
|||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
||||
cx.subscribe(
|
||||
|
|
|
@ -19,7 +19,7 @@ pub struct ToolWorkingSet {
|
|||
struct WorkingSetState {
|
||||
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
|
||||
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
|
||||
disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
|
||||
enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
|
||||
next_tool_id: ToolId,
|
||||
}
|
||||
|
||||
|
@ -41,38 +41,23 @@ impl ToolWorkingSet {
|
|||
self.state.lock().tools_by_source(cx)
|
||||
}
|
||||
|
||||
pub fn are_all_tools_enabled(&self) -> bool {
|
||||
let state = self.state.lock();
|
||||
state.disabled_tools_by_source.is_empty()
|
||||
}
|
||||
|
||||
pub fn are_all_tools_from_source_enabled(&self, source: &ToolSource) -> bool {
|
||||
let state = self.state.lock();
|
||||
!state.disabled_tools_by_source.contains_key(source)
|
||||
}
|
||||
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
self.state.lock().enabled_tools(cx)
|
||||
}
|
||||
|
||||
pub fn enable_all_tools(&self) {
|
||||
pub fn disable_all_tools(&self) {
|
||||
let mut state = self.state.lock();
|
||||
state.disabled_tools_by_source.clear();
|
||||
state.disable_all_tools();
|
||||
}
|
||||
|
||||
pub fn disable_all_tools(&self, cx: &App) {
|
||||
pub fn enable_source(&self, source: ToolSource, cx: &App) {
|
||||
let mut state = self.state.lock();
|
||||
state.disable_all_tools(cx);
|
||||
state.enable_source(source, cx);
|
||||
}
|
||||
|
||||
pub fn enable_source(&self, source: &ToolSource) {
|
||||
pub fn disable_source(&self, source: &ToolSource) {
|
||||
let mut state = self.state.lock();
|
||||
state.enable_source(source);
|
||||
}
|
||||
|
||||
pub fn disable_source(&self, source: ToolSource, cx: &App) {
|
||||
let mut state = self.state.lock();
|
||||
state.disable_source(source, cx);
|
||||
state.disable_source(source);
|
||||
}
|
||||
|
||||
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
|
||||
|
@ -159,40 +144,36 @@ impl WorkingSetState {
|
|||
}
|
||||
|
||||
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
!self.is_disabled(source, name)
|
||||
self.enabled_tools_by_source
|
||||
.get(source)
|
||||
.map_or(false, |enabled_tools| enabled_tools.contains(name))
|
||||
}
|
||||
|
||||
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
self.disabled_tools_by_source
|
||||
.get(source)
|
||||
.map_or(false, |disabled_tools| disabled_tools.contains(name))
|
||||
!self.is_enabled(source, name)
|
||||
}
|
||||
|
||||
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
|
||||
self.disabled_tools_by_source
|
||||
self.enabled_tools_by_source
|
||||
.entry(source)
|
||||
.or_default()
|
||||
.retain(|name| !tools_to_enable.contains(name));
|
||||
.extend(tools_to_enable.into_iter().cloned());
|
||||
}
|
||||
|
||||
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
|
||||
self.disabled_tools_by_source
|
||||
self.enabled_tools_by_source
|
||||
.entry(source)
|
||||
.or_default()
|
||||
.extend(tools_to_disable.into_iter().cloned());
|
||||
.retain(|name| !tools_to_disable.contains(name));
|
||||
}
|
||||
|
||||
fn enable_source(&mut self, source: &ToolSource) {
|
||||
self.disabled_tools_by_source.remove(source);
|
||||
}
|
||||
|
||||
fn disable_source(&mut self, source: ToolSource, cx: &App) {
|
||||
fn enable_source(&mut self, source: ToolSource, cx: &App) {
|
||||
let tools_by_source = self.tools_by_source(cx);
|
||||
let Some(tools) = tools_by_source.get(&source) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.disabled_tools_by_source.insert(
|
||||
self.enabled_tools_by_source.insert(
|
||||
source,
|
||||
tools
|
||||
.into_iter()
|
||||
|
@ -201,16 +182,11 @@ impl WorkingSetState {
|
|||
);
|
||||
}
|
||||
|
||||
fn disable_all_tools(&mut self, cx: &App) {
|
||||
let tools = self.tools_by_source(cx);
|
||||
|
||||
for (source, tools) in tools {
|
||||
let tool_names = tools
|
||||
.into_iter()
|
||||
.map(|tool| tool.name().into())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self.disable(source, &tool_names);
|
||||
fn disable_source(&mut self, source: &ToolSource) {
|
||||
self.enabled_tools_by_source.remove(source);
|
||||
}
|
||||
|
||||
fn disable_all_tools(&mut self) {
|
||||
self.enabled_tools_by_source.clear();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue