From 4e6c37d23b7c964d02ca9093ff6a6c11fc9c4164 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 11 Mar 2025 16:50:18 -0400 Subject: [PATCH] assistant2: Add tool selector (#26480) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a tool selector to Assistant 2 to facilitate customizing the tools that the model sees: Screenshot 2025-03-11 at 4 25 31 PM Release Notes: - N/A --- crates/assistant2/src/assistant.rs | 1 + crates/assistant2/src/message_editor.rs | 40 ++++++----- crates/assistant2/src/thread.rs | 17 ++--- crates/assistant2/src/tool_selector.rs | 70 +++++++++++++++++++ crates/assistant_tool/src/assistant_tool.rs | 2 +- crates/assistant_tool/src/tool_working_set.rs | 69 ++++++++++++++++-- 6 files changed, 166 insertions(+), 33 deletions(-) create mode 100644 crates/assistant2/src/tool_selector.rs diff --git a/crates/assistant2/src/assistant.rs b/crates/assistant2/src/assistant.rs index ee84582e3b..a7cb831b53 100644 --- a/crates/assistant2/src/assistant.rs +++ b/crates/assistant2/src/assistant.rs @@ -16,6 +16,7 @@ mod terminal_inline_assistant; mod thread; mod thread_history; mod thread_store; +mod tool_selector; mod tool_use; mod ui; diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 67199e141f..2fa60dc5a4 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -28,6 +28,7 @@ use crate::context_store::{refresh_context_store_text, ContextStore}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::thread::{RequestKind, Thread}; use crate::thread_store::ThreadStore; +use crate::tool_selector::ToolSelector; use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker}; pub struct MessageEditor { @@ -39,6 +40,7 @@ pub struct MessageEditor { inline_context_picker: Entity, inline_context_picker_menu_handle: PopoverMenuHandle, model_selector: Entity, + tool_selector: Entity, use_tools: bool, edits_expanded: bool, _subscriptions: Vec, @@ -53,6 +55,7 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) -> Self { + let tools = thread.read(cx).tools().clone(); let context_store = cx.new(|_cx| ContextStore::new(workspace.clone())); let context_picker_menu_handle = PopoverMenuHandle::default(); let inline_context_picker_menu_handle = PopoverMenuHandle::default(); @@ -118,6 +121,7 @@ impl MessageEditor { cx, ) }), + tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)), use_tools: false, edits_expanded: false, _subscriptions: subscriptions, @@ -538,23 +542,25 @@ impl Render for MessageEditor { h_flex() .justify_between() .child( - Switch::new("use-tools", self.use_tools.into()) - .label("Tools") - .on_click(cx.listener( - |this, selection, _window, _cx| { - this.use_tools = match selection { - ToggleState::Selected => true, - ToggleState::Unselected - | ToggleState::Indeterminate => false, - }; - }, - )) - .key_binding(KeyBinding::for_action_in( - &ChatMode, - &focus_handle, - window, - cx, - )), + h_flex().gap_2().child(self.tool_selector.clone()).child( + Switch::new("use-tools", self.use_tools.into()) + .label("Tools") + .on_click(cx.listener( + |this, selection, _window, _cx| { + this.use_tools = match selection { + ToggleState::Selected => true, + ToggleState::Unselected + | ToggleState::Indeterminate => false, + }; + }, + )) + .key_binding(KeyBinding::for_action_in( + &ChatMode, + &focus_handle, + window, + cx, + )), + ), ) .child( h_flex().gap_1().child(self.model_selector.clone()).child( diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 6f8d72ef7c..8495ab2eb4 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -355,16 +355,13 @@ impl Thread { input_schema: ScriptingTool::input_schema(), }); - tools.extend( - self.tools() - .tools(cx) - .into_iter() - .map(|tool| LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema: tool.input_schema(), - }), - ); + tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| { + LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema: tool.input_schema(), + } + })); request.tools = tools; } diff --git a/crates/assistant2/src/tool_selector.rs b/crates/assistant2/src/tool_selector.rs new file mode 100644 index 0000000000..6432668a77 --- /dev/null +++ b/crates/assistant2/src/tool_selector.rs @@ -0,0 +1,70 @@ +use std::sync::Arc; + +use assistant_tool::{ToolSource, ToolWorkingSet}; +use gpui::Entity; +use ui::{prelude::*, ContextMenu, IconButtonShape, PopoverMenu, Tooltip}; + +pub struct ToolSelector { + tools: Arc, +} + +impl ToolSelector { + pub fn new(tools: Arc, _cx: &mut Context) -> Self { + Self { tools } + } + + fn build_context_menu( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Entity { + ContextMenu::build(window, cx, |mut menu, _window, cx| { + let tools_by_source = self.tools.tools_by_source(cx); + + for (source, tools) in tools_by_source { + menu = match source { + ToolSource::Native => menu.header("Zed"), + ToolSource::ContextServer { id } => menu.separator().header(id), + }; + + for tool in tools { + let source = tool.source(); + let name = tool.name().into(); + let is_enabled = self.tools.is_enabled(&source, &name); + + menu = + menu.toggleable_entry(tool.name(), is_enabled, IconPosition::End, None, { + let tools = self.tools.clone(); + move |_window, _cx| { + if is_enabled { + tools.disable(source.clone(), &[name.clone()]); + } else { + tools.enable(source.clone(), &[name.clone()]); + } + } + }); + } + } + + menu + }) + } +} + +impl Render for ToolSelector { + fn render(&mut self, _window: &mut Window, cx: &mut Context<'_, Self>) -> impl IntoElement { + let this = cx.entity().clone(); + PopoverMenu::new("tool-selector") + .menu(move |window, cx| { + Some(this.update(cx, |this, cx| this.build_context_menu(window, cx))) + }) + .trigger_with_tooltip( + IconButton::new("tool-selector-button", IconName::SettingsAlt) + .shape(IconButtonShape::Square) + .icon_size(IconSize::Small) + .icon_color(Color::Muted), + Tooltip::text("Customize Tools"), + ) + .anchor(gpui::Corner::BottomLeft) + } +} diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 1f025eda63..9980d9a47a 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -14,7 +14,7 @@ pub fn init(cx: &mut App) { ToolRegistry::default_global(cx); } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] pub enum ToolSource { /// A native tool built-in to Zed. Native, diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 4ce949443e..b4ef27643e 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use collections::HashMap; +use collections::{HashMap, HashSet, IndexMap}; use gpui::App; use parking_lot::Mutex; -use crate::{Tool, ToolRegistry}; +use crate::{Tool, ToolRegistry, ToolSource}; #[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] pub struct ToolId(usize); @@ -19,6 +19,7 @@ pub struct ToolWorkingSet { struct WorkingSetState { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, + disabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } @@ -45,6 +46,34 @@ impl ToolWorkingSet { tools } + pub fn enabled_tools(&self, cx: &App) -> Vec> { + let all_tools = self.tools(cx); + + all_tools + .into_iter() + .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into())) + .collect() + } + + pub fn tools_by_source(&self, cx: &App) -> IndexMap>> { + let mut tools_by_source = IndexMap::default(); + + for tool in self.tools(cx) { + tools_by_source + .entry(tool.source()) + .or_insert_with(Vec::new) + .push(tool); + } + + for tools in tools_by_source.values_mut() { + tools.sort_by_key(|tool| tool.name()); + } + + tools_by_source.sort_unstable_keys(); + + tools_by_source + } + pub fn insert(&self, tool: Arc) -> ToolId { let mut state = self.state.lock(); let tool_id = state.next_tool_id; @@ -56,11 +85,41 @@ impl ToolWorkingSet { tool_id } - pub fn remove(&self, command_ids_to_remove: &[ToolId]) { + pub fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { + !self.is_disabled(source, name) + } + + pub fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { + let state = self.state.lock(); + state + .disabled_tools_by_source + .get(source) + .map_or(false, |disabled_tools| disabled_tools.contains(name)) + } + + pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc]) { + let mut state = self.state.lock(); + state + .disabled_tools_by_source + .entry(source) + .or_default() + .retain(|name| !tools_to_enable.contains(name)); + } + + pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc]) { + let mut state = self.state.lock(); + state + .disabled_tools_by_source + .entry(source) + .or_default() + .extend(tools_to_disable.into_iter().cloned()); + } + + pub fn remove(&self, tool_ids_to_remove: &[ToolId]) { let mut state = self.state.lock(); state .context_server_tools_by_id - .retain(|id, _| !command_ids_to_remove.contains(id)); + .retain(|id, _| !tool_ids_to_remove.contains(id)); state.tools_changed(); } } @@ -71,7 +130,7 @@ impl WorkingSetState { self.context_server_tools_by_name.extend( self.context_server_tools_by_id .values() - .map(|command| (command.name(), command.clone())), + .map(|tool| (tool.name(), tool.clone())), ); } }