diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 02e0876189..b2c75193a4 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -46,56 +46,104 @@ impl ToolWorkingSet { } pub fn tools(&self, cx: &App) -> Vec> { - let mut tools = ToolRegistry::global(cx).tools(); - tools.extend( - self.state - .lock() - .context_server_tools_by_id - .values() - .cloned(), - ); + self.state.lock().tools(cx) + } - tools + pub fn tools_by_source(&self, cx: &App) -> IndexMap>> { + self.state.lock().tools_by_source(cx) } pub fn are_all_tools_enabled(&self) -> bool { let state = self.state.lock(); - state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled } + pub fn enabled_tools(&self, cx: &App) -> Vec> { + self.state.lock().enabled_tools(cx) + } + pub fn enable_all_tools(&self) { let mut state = self.state.lock(); - state.disabled_tools_by_source.clear(); - state.is_scripting_tool_disabled = false; + state.enable_scripting_tool(); } pub fn disable_all_tools(&self, cx: &App) { - let tools = self.tools_by_source(cx); - - for (source, tools) in tools { - let tool_names = tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(); - - self.disable(source, &tool_names); - } - - self.disable_scripting_tool(); + let mut state = self.state.lock(); + state.disable_all_tools(cx); } - pub fn enabled_tools(&self, cx: &App) -> Vec> { - let all_tools = self.tools(cx); - - all_tools - .into_iter() - .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into())) - .collect() + pub fn insert(&self, tool: Arc) -> ToolId { + let mut state = self.state.lock(); + let tool_id = state.next_tool_id; + state.next_tool_id.0 += 1; + state + .context_server_tools_by_id + .insert(tool_id, tool.clone()); + state.tools_changed(); + tool_id } - pub fn tools_by_source(&self, cx: &App) -> IndexMap>> { + pub fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { + self.state.lock().is_enabled(source, name) + } + + pub fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { + self.state.lock().is_disabled(source, name) + } + + pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc]) { + let mut state = self.state.lock(); + state.enable(source, tools_to_enable); + } + + pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc]) { + let mut state = self.state.lock(); + state.disable(source, tools_to_disable); + } + + pub fn remove(&self, tool_ids_to_remove: &[ToolId]) { + let mut state = self.state.lock(); + state + .context_server_tools_by_id + .retain(|id, _| !tool_ids_to_remove.contains(id)); + state.tools_changed(); + } + + pub fn is_scripting_tool_enabled(&self) -> bool { + let state = self.state.lock(); + !state.is_scripting_tool_disabled + } + + pub fn enable_scripting_tool(&self) { + let mut state = self.state.lock(); + state.enable_scripting_tool(); + } + + pub fn disable_scripting_tool(&self) { + let mut state = self.state.lock(); + state.disable_scripting_tool(); + } +} + +impl WorkingSetState { + fn tools_changed(&mut self) { + self.context_server_tools_by_name.clear(); + self.context_server_tools_by_name.extend( + self.context_server_tools_by_id + .values() + .map(|tool| (tool.name(), tool.clone())), + ); + } + + fn tools(&self, cx: &App) -> Vec> { + let mut tools = ToolRegistry::global(cx).tools(); + tools.extend(self.context_server_tools_by_id.values().cloned()); + + tools + } + + fn tools_by_source(&self, cx: &App) -> IndexMap>> { let mut tools_by_source = IndexMap::default(); for tool in self.tools(cx) { @@ -114,78 +162,59 @@ impl ToolWorkingSet { tools_by_source } - pub fn insert(&self, tool: Arc) -> ToolId { - let mut state = self.state.lock(); - let tool_id = state.next_tool_id; - state.next_tool_id.0 += 1; - state - .context_server_tools_by_id - .insert(tool_id, tool.clone()); - state.tools_changed(); - tool_id + fn enabled_tools(&self, cx: &App) -> Vec> { + let all_tools = self.tools(cx); + + all_tools + .into_iter() + .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into())) + .collect() } - pub fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { + fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { !self.is_disabled(source, name) } - pub fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - let state = self.state.lock(); - state - .disabled_tools_by_source + fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { + self.disabled_tools_by_source .get(source) .map_or(false, |disabled_tools| disabled_tools.contains(name)) } - pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc]) { - let mut state = self.state.lock(); - state - .disabled_tools_by_source + fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc]) { + self.disabled_tools_by_source .entry(source) .or_default() .retain(|name| !tools_to_enable.contains(name)); } - pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc]) { - let mut state = self.state.lock(); - state - .disabled_tools_by_source + fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc]) { + self.disabled_tools_by_source .entry(source) .or_default() .extend(tools_to_disable.into_iter().cloned()); } - pub fn remove(&self, tool_ids_to_remove: &[ToolId]) { - let mut state = self.state.lock(); - state - .context_server_tools_by_id - .retain(|id, _| !tool_ids_to_remove.contains(id)); - state.tools_changed(); + fn disable_all_tools(&mut self, cx: &App) { + let tools = self.tools_by_source(cx); + + for (source, tools) in tools { + let tool_names = tools + .into_iter() + .map(|tool| tool.name().into()) + .collect::>(); + + self.disable(source, &tool_names); + } + + self.disable_scripting_tool(); } - pub fn is_scripting_tool_enabled(&self) -> bool { - let state = self.state.lock(); - !state.is_scripting_tool_disabled + fn enable_scripting_tool(&mut self) { + self.is_scripting_tool_disabled = false; } - pub fn enable_scripting_tool(&self) { - let mut state = self.state.lock(); - state.is_scripting_tool_disabled = false; - } - - pub fn disable_scripting_tool(&self) { - let mut state = self.state.lock(); - state.is_scripting_tool_disabled = true; - } -} - -impl WorkingSetState { - fn tools_changed(&mut self) { - self.context_server_tools_by_name.clear(); - self.context_server_tools_by_name.extend( - self.context_server_tools_by_id - .values() - .map(|tool| (tool.name(), tool.clone())), - ); + fn disable_scripting_tool(&mut self) { + self.is_scripting_tool_disabled = true; } }