agent: Make ToolWorkingSet an Entity (#28757)

Motivation is to emit events when enabled tools change, want to use this
in #28755

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-04-15 06:42:31 -06:00 committed by GitHub
parent 7e1b419243
commit e26f0a331f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 183 additions and 202 deletions

View file

@ -894,6 +894,7 @@ mod tests {
use super::*;
use crate::{ThreadStore, thread_store};
use assistant_settings::AssistantSettings;
use assistant_tool::ToolWorkingSet;
use context_server::ContextServerSettings;
use editor::EditorSettings;
use gpui::TestAppContext;
@ -937,7 +938,7 @@ mod tests {
.update(|cx| {
ThreadStore::load(
project.clone(),
Arc::default(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)

View file

@ -29,7 +29,7 @@ pub struct AssistantConfiguration {
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_manager: Entity<ContextServerManager>,
expanded_context_server_tools: HashMap<Arc<str>, bool>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
_registry_subscription: Subscription,
}
@ -37,7 +37,7 @@ impl AssistantConfiguration {
pub fn new(
fs: Arc<dyn Fs>,
context_server_manager: Entity<ContextServerManager>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -226,7 +226,7 @@ impl AssistantConfiguration {
fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_servers = self.context_server_manager.read(cx).all_servers().clone();
let tools_by_source = self.tools.tools_by_source(cx);
let tools_by_source = self.tools.read(cx).tools_by_source(cx);
let empty = Vec::new();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";

View file

@ -84,7 +84,7 @@ pub struct NewProfileMode {
pub struct ManageProfilesModal {
fs: Arc<dyn Fs>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
mode: Mode,
@ -117,7 +117,7 @@ impl ManageProfilesModal {
pub fn new(
fs: Arc<dyn Fs>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut Context<Self>,

View file

@ -60,7 +60,7 @@ pub struct ToolPickerDelegate {
impl ToolPickerDelegate {
pub fn new(
fs: Arc<dyn Fs>,
tool_set: Arc<ToolWorkingSet>,
tool_set: Entity<ToolWorkingSet>,
thread_store: WeakEntity<ThreadStore>,
profile_id: AgentProfileId,
profile: AgentProfile,
@ -68,7 +68,7 @@ impl ToolPickerDelegate {
) -> Self {
let mut tool_entries = Vec::new();
for (source, tools) in tool_set.tools_by_source(cx) {
for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
name: tool.name().into(),
source: source.clone(),
@ -192,7 +192,7 @@ impl PickerDelegate for ToolPickerDelegate {
if active_profile_id == &self.profile_id {
self.thread_store
.update(cx, |this, cx| {
this.load_profile(&self.profile, cx);
this.load_profile(self.profile.clone(), cx);
})
.log_err();
}

View file

@ -203,7 +203,7 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
let tools = Arc::new(ToolWorkingSet::default());
let tools = cx.new(|_| ToolWorkingSet::default())?;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();

View file

@ -86,7 +86,7 @@ impl ProfileSelector {
thread_store
.update(cx, |this, cx| {
this.load_profile_by_id(&profile_id, cx);
this.load_profile_by_id(profile_id.clone(), cx);
})
.log_err();
}

View file

@ -254,7 +254,7 @@ pub struct Thread {
pending_completions: Vec<PendingCompletion>,
project: Entity<Project>,
prompt_builder: Arc<PromptBuilder>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
@ -278,7 +278,7 @@ pub struct ExceededWindowError {
impl Thread {
pub fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext,
cx: &mut Context<Self>,
@ -322,7 +322,7 @@ impl Thread {
id: ThreadId,
serialized: SerializedThread,
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
cx: &mut Context<Self>,
@ -458,7 +458,7 @@ impl Thread {
!self.pending_completions.is_empty() || !self.all_tools_finished()
}
pub fn tools(&self) -> &Arc<ToolWorkingSet> {
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
&self.tools
}
@ -846,6 +846,7 @@ impl Thread {
let mut tools = Vec::new();
tools.extend(
self.tools()
.read(cx)
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
@ -1354,7 +1355,7 @@ impl Thread {
.collect::<Vec<_>>();
for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
if tool.needs_confirmation(&tool_use.input, cx)
&& !AssistantSettings::get_global(cx).always_allow_tool_actions
{
@ -1406,7 +1407,7 @@ impl Thread {
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
let run_tool = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
} else {
tool.run(
@ -1521,6 +1522,7 @@ impl Thread {
let enabled_tool_names: Vec<String> = self
.tools()
.read(cx)
.enabled_tools(cx)
.iter()
.map(|tool| tool.name().to_string())
@ -2341,7 +2343,7 @@ fn main() {{
.update(|_, cx| {
ThreadStore::load(
project.clone(),
Arc::default(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)

View file

@ -56,7 +56,7 @@ impl SharedProjectContext {
pub struct ThreadStore {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
@ -74,7 +74,7 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore {
pub fn load(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Task<Entity<Self>> {
@ -88,7 +88,7 @@ impl ThreadStore {
fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
) -> Self {
@ -248,7 +248,7 @@ impl ThreadStore {
self.context_server_manager.clone()
}
pub fn tools(&self) -> Arc<ToolWorkingSet> {
pub fn tools(&self) -> Entity<ToolWorkingSet> {
self.tools.clone()
}
@ -355,52 +355,60 @@ impl ThreadStore {
})
}
fn load_default_profile(&self, cx: &Context<Self>) {
fn load_default_profile(&self, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
self.load_profile_by_id(&assistant_settings.default_profile, cx);
self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
}
pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
let assistant_settings = AssistantSettings::get_global(cx);
if let Some(profile) = assistant_settings.profiles.get(profile_id) {
self.load_profile(profile, cx);
if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
self.load_profile(profile.clone(), cx);
}
}
pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
self.tools.disable_all_tools();
self.tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
);
pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
self.tools.update(cx, |tools, cx| {
tools.disable_all_tools(cx);
tools.enable(
ToolSource::Native,
&profile
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
);
});
if profile.enable_all_context_servers {
for context_server in self.context_server_manager.read(cx).all_servers() {
self.tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
self.tools.update(cx, |tools, cx| {
tools.enable_source(
ToolSource::ContextServer {
id: context_server.id().into(),
},
cx,
);
});
}
} else {
for (context_server_id, preset) in &profile.context_servers {
self.tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
)
self.tools.update(cx, |tools, cx| {
tools.enable(
ToolSource::ContextServer {
id: context_server_id.clone().into(),
},
&preset
.tools
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
)
})
}
}
}
@ -434,29 +442,36 @@ impl ThreadStore {
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
.log_err();
this.update(cx, |this, cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
if let Some(tool_ids) = tool_ids {
this.update(cx, |this, cx| {
this.context_server_tool_ids
.insert(server_id, tool_ids);
this.load_default_profile(cx);
})
.log_err();
}
}
}
}
@ -466,7 +481,9 @@ impl ThreadStore {
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
self.load_default_profile(cx);
}
}

View file

@ -5,7 +5,7 @@ use assistant_tool::{Tool, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, SharedString, Task};
use gpui::{App, Entity, SharedString, Task};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
@ -49,7 +49,7 @@ impl ToolUseStatus {
}
pub struct ToolUseState {
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
@ -59,7 +59,7 @@ pub struct ToolUseState {
pub const USING_TOOL_MARKER: &str = "<using_tool>";
impl ToolUseState {
pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
@ -73,7 +73,7 @@ impl ToolUseState {
///
/// Accepts a function to filter the tools that should be used to populate the state.
pub fn from_serialized_messages(
tools: Arc<ToolWorkingSet>,
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
mut filter_by_tool_name: impl FnMut(&str) -> bool,
) -> Self {
@ -199,12 +199,12 @@ impl ToolUseState {
}
})();
let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
{
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
let (icon, needs_confirmation) =
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
@ -226,7 +226,7 @@ impl ToolUseState {
input: &serde_json::Value,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.tool(tool_name, cx) {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
tool.ui_text(input).into()
} else {
format!("Unknown tool {tool_name:?}").into()

View file

@ -1,8 +1,7 @@
use std::sync::Arc;
use collections::{HashMap, HashSet, IndexMap};
use gpui::App;
use parking_lot::Mutex;
use gpui::{App, Context, EventEmitter};
use crate::{Tool, ToolRegistry, ToolSource};
@ -12,11 +11,6 @@ pub struct ToolId(usize);
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
state: Mutex<WorkingSetState>,
}
#[derive(Default)]
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
enabled_sources: HashSet<ToolSource>,
@ -24,99 +18,27 @@ struct WorkingSetState {
next_tool_id: ToolId,
}
pub enum ToolWorkingSetEvent {
EnabledToolsChanged,
}
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
impl ToolWorkingSet {
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
self.state
.lock()
.context_server_tools_by_name
self.context_server_tools_by_name
.get(name)
.cloned()
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().tools(cx)
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
self.state.lock().tools_by_source(cx)
}
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
self.state.lock().enabled_tools(cx)
}
pub fn disable_all_tools(&self) {
let mut state = self.state.lock();
state.disable_all_tools();
}
pub fn enable_source(&self, source: ToolSource, cx: &App) {
let mut state = self.state.lock();
state.enable_source(source, cx);
}
pub fn disable_source(&self, source: &ToolSource) {
let mut state = self.state.lock();
state.disable_source(source);
}
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock();
let tool_id = state.next_tool_id;
state.next_tool_id.0 += 1;
state
.context_server_tools_by_id
.insert(tool_id, tool.clone());
state.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_enabled(source, name)
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.state.lock().is_disabled(source, name)
}
pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
let mut state = self.state.lock();
state.enable(source, tools_to_enable);
}
pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
let mut state = self.state.lock();
state.disable(source, tools_to_disable);
}
pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock();
state
.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
state.tools_changed();
}
}
impl WorkingSetState {
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
}
fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default();
for tool in self.tools(cx) {
@ -135,7 +57,7 @@ impl WorkingSetState {
tools_by_source
}
fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let all_tools = self.tools(cx);
all_tools
@ -144,31 +66,12 @@ impl WorkingSetState {
.collect()
}
fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
self.enabled_tools_by_source.clear();
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
}
fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
}
fn enable_source(&mut self, source: ToolSource, cx: &App) {
pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
self.enabled_sources.insert(source.clone());
let tools_by_source = self.tools_by_source(cx);
@ -181,14 +84,72 @@ impl WorkingSetState {
.collect::<HashSet<_>>(),
);
}
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn disable_source(&mut self, source: &ToolSource) {
pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
self.enabled_sources.remove(source);
self.enabled_tools_by_source.remove(source);
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
fn disable_all_tools(&mut self) {
self.enabled_tools_by_source.clear();
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
.insert(tool_id, tool.clone());
self.tools_changed();
tool_id
}
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
self.enabled_tools_by_source
.get(source)
.map_or(false, |enabled_tools| enabled_tools.contains(name))
}
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
!self.is_enabled(source, name)
}
pub fn enable(
&mut self,
source: ToolSource,
tools_to_enable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.extend(tools_to_enable.into_iter().cloned());
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn disable(
&mut self,
source: ToolSource,
tools_to_disable: &[Arc<str>],
cx: &mut Context<Self>,
) {
self.enabled_tools_by_source
.entry(source)
.or_default()
.retain(|name| !tools_to_disable.contains(name));
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
}
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed();
}
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.map(|tool| (tool.name(), tool.clone())),
);
}
}

View file

@ -6,7 +6,7 @@ use collections::HashMap;
use dap::DapRegistry;
use futures::channel::mpsc;
use futures::{FutureExt, StreamExt as _, select_biased};
use gpui::{App, AsyncApp, Entity, Task};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
use handlebars::Handlebars;
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{
@ -181,7 +181,7 @@ impl Example {
project.create_worktree(&worktree_path, true, cx)
});
let tools = Arc::new(ToolWorkingSet::default());
let tools = cx.new(|_| ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
let this = self.clone();