assistant2: Add tool selector (#26480)

This PR adds a tool selector to Assistant 2 to facilitate customizing
the tools that the model sees:

<img width="1297" alt="Screenshot 2025-03-11 at 4 25 31 PM"
src="https://github.com/user-attachments/assets/7a656343-83bc-4546-9430-6a5f7ff1fd08"
/>

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-03-11 16:50:18 -04:00 committed by GitHub
parent 0cf6259fec
commit 4e6c37d23b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 166 additions and 33 deletions

View file

@ -16,6 +16,7 @@ mod terminal_inline_assistant;
mod thread; mod thread;
mod thread_history; mod thread_history;
mod thread_store; mod thread_store;
mod tool_selector;
mod tool_use; mod tool_use;
mod ui; mod ui;

View file

@ -28,6 +28,7 @@ use crate::context_store::{refresh_context_store_text, ContextStore};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::thread::{RequestKind, Thread}; use crate::thread::{RequestKind, Thread};
use crate::thread_store::ThreadStore; use crate::thread_store::ThreadStore;
use crate::tool_selector::ToolSelector;
use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker}; use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker};
pub struct MessageEditor { pub struct MessageEditor {
@ -39,6 +40,7 @@ pub struct MessageEditor {
inline_context_picker: Entity<ContextPicker>, inline_context_picker: Entity<ContextPicker>,
inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>, inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>, model_selector: Entity<AssistantModelSelector>,
tool_selector: Entity<ToolSelector>,
use_tools: bool, use_tools: bool,
edits_expanded: bool, edits_expanded: bool,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
@ -53,6 +55,7 @@ impl MessageEditor {
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let tools = thread.read(cx).tools().clone();
let context_store = cx.new(|_cx| ContextStore::new(workspace.clone())); let context_store = cx.new(|_cx| ContextStore::new(workspace.clone()));
let context_picker_menu_handle = PopoverMenuHandle::default(); let context_picker_menu_handle = PopoverMenuHandle::default();
let inline_context_picker_menu_handle = PopoverMenuHandle::default(); let inline_context_picker_menu_handle = PopoverMenuHandle::default();
@ -118,6 +121,7 @@ impl MessageEditor {
cx, cx,
) )
}), }),
tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)),
use_tools: false, use_tools: false,
edits_expanded: false, edits_expanded: false,
_subscriptions: subscriptions, _subscriptions: subscriptions,
@ -538,23 +542,25 @@ impl Render for MessageEditor {
h_flex() h_flex()
.justify_between() .justify_between()
.child( .child(
Switch::new("use-tools", self.use_tools.into()) h_flex().gap_2().child(self.tool_selector.clone()).child(
.label("Tools") Switch::new("use-tools", self.use_tools.into())
.on_click(cx.listener( .label("Tools")
|this, selection, _window, _cx| { .on_click(cx.listener(
this.use_tools = match selection { |this, selection, _window, _cx| {
ToggleState::Selected => true, this.use_tools = match selection {
ToggleState::Unselected ToggleState::Selected => true,
| ToggleState::Indeterminate => false, ToggleState::Unselected
}; | ToggleState::Indeterminate => false,
}, };
)) },
.key_binding(KeyBinding::for_action_in( ))
&ChatMode, .key_binding(KeyBinding::for_action_in(
&focus_handle, &ChatMode,
window, &focus_handle,
cx, window,
)), cx,
)),
),
) )
.child( .child(
h_flex().gap_1().child(self.model_selector.clone()).child( h_flex().gap_1().child(self.model_selector.clone()).child(

View file

@ -355,16 +355,13 @@ impl Thread {
input_schema: ScriptingTool::input_schema(), input_schema: ScriptingTool::input_schema(),
}); });
tools.extend( tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
self.tools() LanguageModelRequestTool {
.tools(cx) name: tool.name(),
.into_iter() description: tool.description(),
.map(|tool| LanguageModelRequestTool { input_schema: tool.input_schema(),
name: tool.name(), }
description: tool.description(), }));
input_schema: tool.input_schema(),
}),
);
request.tools = tools; request.tools = tools;
} }

View file

@ -0,0 +1,70 @@
use std::sync::Arc;
use assistant_tool::{ToolSource, ToolWorkingSet};
use gpui::Entity;
use ui::{prelude::*, ContextMenu, IconButtonShape, PopoverMenu, Tooltip};
pub struct ToolSelector {
tools: Arc<ToolWorkingSet>,
}
impl ToolSelector {
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
Self { tools }
}
fn build_context_menu(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> Entity<ContextMenu> {
ContextMenu::build(window, cx, |mut menu, _window, cx| {
let tools_by_source = self.tools.tools_by_source(cx);
for (source, tools) in tools_by_source {
menu = match source {
ToolSource::Native => menu.header("Zed"),
ToolSource::ContextServer { id } => menu.separator().header(id),
};
for tool in tools {
let source = tool.source();
let name = tool.name().into();
let is_enabled = self.tools.is_enabled(&source, &name);
menu =
menu.toggleable_entry(tool.name(), is_enabled, IconPosition::End, None, {
let tools = self.tools.clone();
move |_window, _cx| {
if is_enabled {
tools.disable(source.clone(), &[name.clone()]);
} else {
tools.enable(source.clone(), &[name.clone()]);
}
}
});
}
}
menu
})
}
}
impl Render for ToolSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<'_, Self>) -> impl IntoElement {
let this = cx.entity().clone();
PopoverMenu::new("tool-selector")
.menu(move |window, cx| {
Some(this.update(cx, |this, cx| this.build_context_menu(window, cx)))
})
.trigger_with_tooltip(
IconButton::new("tool-selector-button", IconName::SettingsAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::Small)
.icon_color(Color::Muted),
Tooltip::text("Customize Tools"),
)
.anchor(gpui::Corner::BottomLeft)
}
}

View file

@ -14,7 +14,7 @@ pub fn init(cx: &mut App) {
ToolRegistry::default_global(cx); ToolRegistry::default_global(cx);
} }
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub enum ToolSource { pub enum ToolSource {
/// A native tool built-in to Zed. /// A native tool built-in to Zed.
Native, Native,

View file

@ -1,10 +1,10 @@
use std::sync::Arc; use std::sync::Arc;
use collections::HashMap; use collections::{HashMap, HashSet, IndexMap};
use gpui::App; use gpui::App;
use parking_lot::Mutex; use parking_lot::Mutex;
use crate::{Tool, ToolRegistry}; use crate::{Tool, ToolRegistry, ToolSource};
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] #[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize); pub struct ToolId(usize);
@ -19,6 +19,7 @@ pub struct ToolWorkingSet {
struct WorkingSetState { struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>, context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>, context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
next_tool_id: ToolId, next_tool_id: ToolId,
} }
@ -45,6 +46,34 @@ impl ToolWorkingSet {
tools tools
} }
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let all_tools = self.tools(cx);
all_tools
.into_iter()
.filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
.collect()
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default();
for tool in self.tools(cx) {
tools_by_source
.entry(tool.source())
.or_insert_with(Vec::new)
.push(tool);
}
for tools in tools_by_source.values_mut() {
tools.sort_by_key(|tool| tool.name());
}
tools_by_source.sort_unstable_keys();
tools_by_source
}
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId { pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock(); let mut state = self.state.lock();
let tool_id = state.next_tool_id; let tool_id = state.next_tool_id;
@ -56,11 +85,41 @@ impl ToolWorkingSet {
tool_id tool_id
} }
pub fn remove(&self, command_ids_to_remove: &[ToolId]) { pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_disabled(source, name)
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
let state = self.state.lock();
state
.disabled_tools_by_source
.get(source)
.map_or(false, |disabled_tools| disabled_tools.contains(name))
}
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
let mut state = self.state.lock();
state
.disabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_enable.contains(name));
}
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
let mut state = self.state.lock();
state
.disabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_disable.into_iter().cloned());
}
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock(); let mut state = self.state.lock();
state state
.context_server_tools_by_id .context_server_tools_by_id
.retain(|id, _| !command_ids_to_remove.contains(id)); .retain(|id, _| !tool_ids_to_remove.contains(id));
state.tools_changed(); state.tools_changed();
} }
} }
@ -71,7 +130,7 @@ impl WorkingSetState {
self.context_server_tools_by_name.extend( self.context_server_tools_by_name.extend(
self.context_server_tools_by_id self.context_server_tools_by_id
.values() .values()
.map(|command| (command.name(), command.clone())), .map(|tool| (tool.name(), tool.clone())),
); );
} }
} }