From 709523bf363b48206eb30a8e9bef9a5ce4603c48 Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Fri, 6 Jun 2025 14:05:27 +0200 Subject: [PATCH] Store profile per thread (#31907) This allows storing the profile per thread, as well as moving the logic of which tools are enabled or not to the profile itself. This makes it much easier to switch between profiles, means there is less global state being changed on every profile change. Release Notes: - agent panel: allow saving the profile per thread --------- Co-authored-by: Ben Kunkle --- Cargo.lock | 3 +- assets/settings/default.json | 1 - crates/agent/Cargo.toml | 1 + crates/agent/src/active_thread.rs | 4 + crates/agent/src/agent.rs | 1 + .../manage_profiles_modal.rs | 64 +--- .../src/agent_configuration/tool_picker.rs | 54 +-- crates/agent/src/agent_diff.rs | 3 +- crates/agent/src/agent_profile.rs | 334 ++++++++++++++++++ crates/agent/src/message_editor.rs | 14 +- crates/agent/src/profile_selector.rs | 64 ++-- crates/agent/src/thread.rs | 122 +++++-- crates/agent/src/thread_store.rs | 102 +----- crates/agent/src/tool_compatibility.rs | 22 +- crates/agent_settings/Cargo.toml | 1 - crates/agent_settings/src/agent_profile.rs | 25 +- crates/agent_settings/src/agent_settings.rs | 14 +- crates/assistant_tool/src/tool_working_set.rs | 84 +---- crates/collab/Cargo.toml | 1 - crates/eval/src/example.rs | 1 + crates/eval/src/instance.rs | 10 +- 21 files changed, 556 insertions(+), 369 deletions(-) create mode 100644 crates/agent/src/agent_profile.rs diff --git a/Cargo.lock b/Cargo.lock index 9554c46aac..af14e42430 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,7 @@ dependencies = [ "assistant_slash_command", "assistant_slash_commands", "assistant_tool", + "assistant_tools", "async-watch", "audio", "buffer_diff", @@ -147,7 +148,6 @@ dependencies = [ "deepseek", "fs", "gpui", - "indexmap", "language_model", "lmstudio", "log", @@ -2987,7 +2987,6 @@ dependencies = [ "anyhow", "assistant_context_editor", "assistant_slash_command", - "assistant_tool", "async-stripe", "async-trait", "async-tungstenite", diff --git a/assets/settings/default.json b/assets/settings/default.json index 8d8c65884c..fbcde696c3 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -771,7 +771,6 @@ "tools": { "copy_path": true, "create_directory": true, - "create_file": true, "delete_path": true, "diagnostics": true, "edit_file": true, diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index c1f9d9a3fa..1b07d94605 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -102,6 +102,7 @@ zed_llm_client.workspace = true zstd.workspace = true [dev-dependencies] +assistant_tools.workspace = true buffer_diff = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index a983d43690..8eda04c60f 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1144,6 +1144,10 @@ impl ActiveThread { cx, ); } + ThreadEvent::ProfileChanged => { + self.save_thread(cx); + cx.notify(); + } } } diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index db458b771e..0ac7869920 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -3,6 +3,7 @@ mod agent_configuration; mod agent_diff; mod agent_model_selector; mod agent_panel; +mod agent_profile; mod buffer_codegen; mod context; mod context_picker; diff --git a/crates/agent/src/agent_configuration/manage_profiles_modal.rs b/crates/agent/src/agent_configuration/manage_profiles_modal.rs index 8cb7d4dfe2..feb0a8e53f 100644 --- a/crates/agent/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent/src/agent_configuration/manage_profiles_modal.rs @@ -2,25 +2,21 @@ mod profile_modal_header; use std::sync::Arc; -use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, builtin_profiles}; +use agent_settings::{AgentProfileId, AgentSettings, builtin_profiles}; use assistant_tool::ToolWorkingSet; -use convert_case::{Case, Casing as _}; use editor::Editor; use fs::Fs; -use gpui::{ - DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, WeakEntity, - prelude::*, -}; -use settings::{Settings as _, update_settings_file}; +use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, prelude::*}; +use settings::Settings as _; use ui::{ KeyBinding, ListItem, ListItemSpacing, ListSeparator, Navigable, NavigableEntry, prelude::*, }; -use util::ResultExt as _; use workspace::{ModalView, Workspace}; use crate::agent_configuration::manage_profiles_modal::profile_modal_header::ProfileModalHeader; use crate::agent_configuration::tool_picker::{ToolPicker, ToolPickerDelegate}; -use crate::{AgentPanel, ManageProfiles, ThreadStore}; +use crate::agent_profile::AgentProfile; +use crate::{AgentPanel, ManageProfiles}; use super::tool_picker::ToolPickerMode; @@ -103,7 +99,6 @@ pub struct NewProfileMode { pub struct ManageProfilesModal { fs: Arc, tools: Entity, - thread_store: WeakEntity, focus_handle: FocusHandle, mode: Mode, } @@ -119,9 +114,8 @@ impl ManageProfilesModal { let fs = workspace.app_state().fs.clone(); let thread_store = panel.read(cx).thread_store(); let tools = thread_store.read(cx).tools(); - let thread_store = thread_store.downgrade(); workspace.toggle_modal(window, cx, |window, cx| { - let mut this = Self::new(fs, tools, thread_store, window, cx); + let mut this = Self::new(fs, tools, window, cx); if let Some(profile_id) = action.customize_tools.clone() { this.configure_builtin_tools(profile_id, window, cx); @@ -136,7 +130,6 @@ impl ManageProfilesModal { pub fn new( fs: Arc, tools: Entity, - thread_store: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -145,7 +138,6 @@ impl ManageProfilesModal { Self { fs, tools, - thread_store, focus_handle, mode: Mode::choose_profile(window, cx), } @@ -206,7 +198,6 @@ impl ManageProfilesModal { ToolPickerMode::McpTools, self.fs.clone(), self.tools.clone(), - self.thread_store.clone(), profile_id.clone(), profile, cx, @@ -244,7 +235,6 @@ impl ManageProfilesModal { ToolPickerMode::BuiltinTools, self.fs.clone(), self.tools.clone(), - self.thread_store.clone(), profile_id.clone(), profile, cx, @@ -270,32 +260,10 @@ impl ManageProfilesModal { match &self.mode { Mode::ChooseProfile { .. } => {} Mode::NewProfile(mode) => { - let settings = AgentSettings::get_global(cx); - - let base_profile = mode - .base_profile_id - .as_ref() - .and_then(|profile_id| settings.profiles.get(profile_id).cloned()); - let name = mode.name_editor.read(cx).text(cx); - let profile_id = AgentProfileId(name.to_case(Case::Kebab).into()); - let profile = AgentProfile { - name: name.into(), - tools: base_profile - .as_ref() - .map(|profile| profile.tools.clone()) - .unwrap_or_default(), - enable_all_context_servers: base_profile - .as_ref() - .map(|profile| profile.enable_all_context_servers) - .unwrap_or_default(), - context_servers: base_profile - .map(|profile| profile.context_servers) - .unwrap_or_default(), - }; - - self.create_profile(profile_id.clone(), profile, cx); + let profile_id = + AgentProfile::create(name, mode.base_profile_id.clone(), self.fs.clone(), cx); self.view_profile(profile_id, window, cx); } Mode::ViewProfile(_) => {} @@ -325,19 +293,6 @@ impl ManageProfilesModal { } } } - - fn create_profile( - &self, - profile_id: AgentProfileId, - profile: AgentProfile, - cx: &mut Context, - ) { - update_settings_file::(self.fs.clone(), cx, { - move |settings, _cx| { - settings.create_profile(profile_id, profile).log_err(); - } - }); - } } impl ModalView for ManageProfilesModal {} @@ -520,14 +475,13 @@ impl ManageProfilesModal { ) -> impl IntoElement { let settings = AgentSettings::get_global(cx); - let profile_id = &settings.default_profile; let profile_name = settings .profiles .get(&mode.profile_id) .map(|profile| profile.name.clone()) .unwrap_or_else(|| "Unknown".into()); - let icon = match profile_id.as_str() { + let icon = match mode.profile_id.as_str() { "write" => IconName::Pencil, "ask" => IconName::MessageBubbles, _ => IconName::UserRoundPen, diff --git a/crates/agent/src/agent_configuration/tool_picker.rs b/crates/agent/src/agent_configuration/tool_picker.rs index 5ac2d4496b..7c3d20457e 100644 --- a/crates/agent/src/agent_configuration/tool_picker.rs +++ b/crates/agent/src/agent_configuration/tool_picker.rs @@ -1,19 +1,17 @@ use std::{collections::BTreeMap, sync::Arc}; use agent_settings::{ - AgentProfile, AgentProfileContent, AgentProfileId, AgentSettings, AgentSettingsContent, + AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent, ContextServerPresetContent, }; use assistant_tool::{ToolSource, ToolWorkingSet}; use fs::Fs; use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window}; use picker::{Picker, PickerDelegate}; -use settings::{Settings as _, update_settings_file}; +use settings::update_settings_file; use ui::{ListItem, ListItemSpacing, prelude::*}; use util::ResultExt as _; -use crate::ThreadStore; - pub struct ToolPicker { picker: Entity>, } @@ -71,11 +69,10 @@ pub enum PickerItem { pub struct ToolPickerDelegate { tool_picker: WeakEntity, - thread_store: WeakEntity, fs: Arc, items: Arc>, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, filtered_items: Vec, selected_index: usize, mode: ToolPickerMode, @@ -86,20 +83,18 @@ impl ToolPickerDelegate { mode: ToolPickerMode, fs: Arc, tool_set: Entity, - thread_store: WeakEntity, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, cx: &mut Context, ) -> Self { let items = Arc::new(Self::resolve_items(mode, &tool_set, cx)); Self { tool_picker: cx.entity().downgrade(), - thread_store, fs, items, profile_id, - profile, + profile_settings, filtered_items: Vec::new(), selected_index: 0, mode, @@ -249,28 +244,31 @@ impl PickerDelegate for ToolPickerDelegate { }; let is_currently_enabled = if let Some(server_id) = server_id.clone() { - let preset = self.profile.context_servers.entry(server_id).or_default(); + let preset = self + .profile_settings + .context_servers + .entry(server_id) + .or_default(); let is_enabled = *preset.tools.entry(tool_name.clone()).or_default(); *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled; is_enabled } else { - let is_enabled = *self.profile.tools.entry(tool_name.clone()).or_default(); - *self.profile.tools.entry(tool_name.clone()).or_default() = !is_enabled; + let is_enabled = *self + .profile_settings + .tools + .entry(tool_name.clone()) + .or_default(); + *self + .profile_settings + .tools + .entry(tool_name.clone()) + .or_default() = !is_enabled; is_enabled }; - let active_profile_id = &AgentSettings::get_global(cx).default_profile; - if active_profile_id == &self.profile_id { - self.thread_store - .update(cx, |this, cx| { - this.load_profile(self.profile.clone(), cx); - }) - .log_err(); - } - update_settings_file::(self.fs.clone(), cx, { let profile_id = self.profile_id.clone(); - let default_profile = self.profile.clone(); + let default_profile = self.profile_settings.clone(); let server_id = server_id.clone(); let tool_name = tool_name.clone(); move |settings: &mut AgentSettingsContent, _cx| { @@ -348,14 +346,18 @@ impl PickerDelegate for ToolPickerDelegate { ), PickerItem::Tool { name, server_id } => { let is_enabled = if let Some(server_id) = server_id { - self.profile + self.profile_settings .context_servers .get(server_id.as_ref()) .and_then(|preset| preset.tools.get(name)) .copied() - .unwrap_or(self.profile.enable_all_context_servers) + .unwrap_or(self.profile_settings.enable_all_context_servers) } else { - self.profile.tools.get(name).copied().unwrap_or(false) + self.profile_settings + .tools + .get(name) + .copied() + .unwrap_or(false) }; Some( diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index b620d53c78..34ff249e95 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -1378,7 +1378,8 @@ impl AgentDiff { | ThreadEvent::CheckpointChanged | ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolUseLimitReached - | ThreadEvent::CancelEditing => {} + | ThreadEvent::CancelEditing + | ThreadEvent::ProfileChanged => {} } } diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs new file mode 100644 index 0000000000..5cd69bd324 --- /dev/null +++ b/crates/agent/src/agent_profile.rs @@ -0,0 +1,334 @@ +use std::sync::Arc; + +use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings}; +use assistant_tool::{Tool, ToolSource, ToolWorkingSet}; +use collections::IndexMap; +use convert_case::{Case, Casing}; +use fs::Fs; +use gpui::{App, Entity}; +use settings::{Settings, update_settings_file}; +use ui::SharedString; +use util::ResultExt; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct AgentProfile { + id: AgentProfileId, + tool_set: Entity, +} + +pub type AvailableProfiles = IndexMap; + +impl AgentProfile { + pub fn new(id: AgentProfileId, tool_set: Entity) -> Self { + Self { id, tool_set } + } + + /// Saves a new profile to the settings. + pub fn create( + name: String, + base_profile_id: Option, + fs: Arc, + cx: &App, + ) -> AgentProfileId { + let id = AgentProfileId(name.to_case(Case::Kebab).into()); + + let base_profile = + base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned()); + + let profile_settings = AgentProfileSettings { + name: name.into(), + tools: base_profile + .as_ref() + .map(|profile| profile.tools.clone()) + .unwrap_or_default(), + enable_all_context_servers: base_profile + .as_ref() + .map(|profile| profile.enable_all_context_servers) + .unwrap_or_default(), + context_servers: base_profile + .map(|profile| profile.context_servers) + .unwrap_or_default(), + }; + + update_settings_file::(fs, cx, { + let id = id.clone(); + move |settings, _cx| { + settings.create_profile(id, profile_settings).log_err(); + } + }); + + id + } + + /// Returns a map of AgentProfileIds to their names + pub fn available_profiles(cx: &App) -> AvailableProfiles { + let mut profiles = AvailableProfiles::default(); + for (id, profile) in AgentSettings::get_global(cx).profiles.iter() { + profiles.insert(id.clone(), profile.name.clone()); + } + profiles + } + + pub fn id(&self) -> &AgentProfileId { + &self.id + } + + pub fn enabled_tools(&self, cx: &App) -> Vec> { + let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else { + return Vec::new(); + }; + + self.tool_set + .read(cx) + .tools(cx) + .into_iter() + .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name())) + .collect() + } + + fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool { + match source { + ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false), + ToolSource::ContextServer { id } => { + if settings.enable_all_context_servers { + return true; + } + + let Some(preset) = settings.context_servers.get(id.as_ref()) else { + return false; + }; + *preset.tools.get(name.as_str()).unwrap_or(&false) + } + } + } +} + +#[cfg(test)] +mod tests { + use agent_settings::ContextServerPreset; + use assistant_tool::ToolRegistry; + use collections::IndexMap; + use gpui::{AppContext, TestAppContext}; + use http_client::FakeHttpClient; + use project::Project; + use settings::{Settings, SettingsStore}; + use ui::SharedString; + + use super::*; + + #[gpui::test] + async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId::default(); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings + .tools + .into_iter() + .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string())) + // Provider dependent + .filter(|tool| tool != "web_search") + .collect::>(); + // Plus all registered MCP tools + expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + #[gpui::test] + async fn test_custom_mcp_settings(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId("custom_mcp".into()); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings.context_servers["mcp"] + .tools + .iter() + .filter_map(|(key, enabled)| enabled.then(|| key.to_string())) + .collect::>(); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + #[gpui::test] + async fn test_only_built_in(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId("write_minus_mcp".into()); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings + .tools + .into_iter() + .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string())) + // Provider dependent + .filter(|tool| tool != "web_search") + .collect::>(); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + fn init_test_settings(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + AgentSettings::register(cx); + language_model::init_settings(cx); + ToolRegistry::default_global(cx); + assistant_tools::init(FakeHttpClient::with_404_response(), cx); + }); + + cx.update(|cx| { + let mut agent_settings = AgentSettings::get_global(cx).clone(); + agent_settings.profiles.insert( + AgentProfileId("write_minus_mcp".into()), + AgentProfileSettings { + name: "write_minus_mcp".into(), + enable_all_context_servers: false, + ..agent_settings.profiles[&AgentProfileId::default()].clone() + }, + ); + agent_settings.profiles.insert( + AgentProfileId("custom_mcp".into()), + AgentProfileSettings { + name: "mcp".into(), + tools: IndexMap::default(), + enable_all_context_servers: false, + context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]), + }, + ); + AgentSettings::override_global(agent_settings, cx); + }) + } + + fn context_server_preset() -> ContextServerPreset { + ContextServerPreset { + tools: IndexMap::from_iter([ + ("enabled_mcp_tool".into(), true), + ("disabled_mcp_tool".into(), false), + ]), + } + } + + fn default_tool_set(cx: &mut TestAppContext) -> Entity { + cx.new(|_| { + let mut tool_set = ToolWorkingSet::default(); + tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp"))); + tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp"))); + tool_set + }) + } + + struct FakeTool { + name: String, + source: SharedString, + } + + impl FakeTool { + fn new(name: impl Into, source: impl Into) -> Self { + Self { + name: name.into(), + source: source.into(), + } + } + } + + impl Tool for FakeTool { + fn name(&self) -> String { + self.name.clone() + } + + fn source(&self) -> ToolSource { + ToolSource::ContextServer { + id: self.source.clone(), + } + } + + fn description(&self) -> String { + unimplemented!() + } + + fn icon(&self) -> ui::IconName { + unimplemented!() + } + + fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + unimplemented!() + } + + fn ui_text(&self, _input: &serde_json::Value) -> String { + unimplemented!() + } + + fn run( + self: Arc, + _input: serde_json::Value, + _request: Arc, + _project: Entity, + _action_log: Entity, + _model: Arc, + _window: Option, + _cx: &mut App, + ) -> assistant_tool::ToolResult { + unimplemented!() + } + + fn may_perform_edits(&self) -> bool { + unimplemented!() + } + } +} diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 0ae326bd44..a3958d9acb 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -175,8 +175,7 @@ impl MessageEditor { ) }); - let incompatible_tools = - cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx)); + let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.clone(), cx)); let subscriptions = vec![ cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), @@ -204,15 +203,8 @@ impl MessageEditor { ) }); - let profile_selector = cx.new(|cx| { - ProfileSelector::new( - fs, - thread.clone(), - thread_store, - editor.focus_handle(cx), - cx, - ) - }); + let profile_selector = + cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx)); Self { editor: editor.clone(), diff --git a/crates/agent/src/profile_selector.rs b/crates/agent/src/profile_selector.rs index a51440ddb9..7a42e45fa4 100644 --- a/crates/agent/src/profile_selector.rs +++ b/crates/agent/src/profile_selector.rs @@ -1,26 +1,24 @@ use std::sync::Arc; -use agent_settings::{ - AgentDockPosition, AgentProfile, AgentProfileId, AgentSettings, GroupedAgentProfiles, - builtin_profiles, -}; +use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles}; use fs::Fs; -use gpui::{Action, Empty, Entity, FocusHandle, Subscription, WeakEntity, prelude::*}; +use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*}; use language_model::LanguageModelRegistry; use settings::{Settings as _, SettingsStore, update_settings_file}; use ui::{ ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*, }; -use util::ResultExt as _; -use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector}; +use crate::{ + ManageProfiles, Thread, ToggleProfileSelector, + agent_profile::{AgentProfile, AvailableProfiles}, +}; pub struct ProfileSelector { - profiles: GroupedAgentProfiles, + profiles: AvailableProfiles, fs: Arc, thread: Entity, - thread_store: WeakEntity, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, _subscriptions: Vec, @@ -30,7 +28,6 @@ impl ProfileSelector { pub fn new( fs: Arc, thread: Entity, - thread_store: WeakEntity, focus_handle: FocusHandle, cx: &mut Context, ) -> Self { @@ -39,10 +36,9 @@ impl ProfileSelector { }); Self { - profiles: GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx)), + profiles: AgentProfile::available_profiles(cx), fs, thread, - thread_store, menu_handle: PopoverMenuHandle::default(), focus_handle, _subscriptions: vec![settings_subscription], @@ -54,7 +50,7 @@ impl ProfileSelector { } fn refresh_profiles(&mut self, cx: &mut Context) { - self.profiles = GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx)); + self.profiles = AgentProfile::available_profiles(cx); } fn build_context_menu( @@ -64,21 +60,30 @@ impl ProfileSelector { ) -> Entity { ContextMenu::build(window, cx, |mut menu, _window, cx| { let settings = AgentSettings::get_global(cx); - for (profile_id, profile) in self.profiles.builtin.iter() { + + let mut found_non_builtin = false; + for (profile_id, profile_name) in self.profiles.iter() { + if !builtin_profiles::is_builtin(profile_id) { + found_non_builtin = true; + continue; + } menu = menu.item(self.menu_entry_for_profile( profile_id.clone(), - profile, + profile_name, settings, cx, )); } - if !self.profiles.custom.is_empty() { + if found_non_builtin { menu = menu.separator().header("Custom Profiles"); - for (profile_id, profile) in self.profiles.custom.iter() { + for (profile_id, profile_name) in self.profiles.iter() { + if builtin_profiles::is_builtin(profile_id) { + continue; + } menu = menu.item(self.menu_entry_for_profile( profile_id.clone(), - profile, + profile_name, settings, cx, )); @@ -99,19 +104,20 @@ impl ProfileSelector { fn menu_entry_for_profile( &self, profile_id: AgentProfileId, - profile: &AgentProfile, + profile_name: &SharedString, settings: &AgentSettings, - _cx: &App, + cx: &App, ) -> ContextMenuEntry { - let documentation = match profile.name.to_lowercase().as_str() { + let documentation = match profile_name.to_lowercase().as_str() { builtin_profiles::WRITE => Some("Get help to write anything."), builtin_profiles::ASK => Some("Chat about your codebase."), builtin_profiles::MINIMAL => Some("Chat about anything with no tools."), _ => None, }; + let thread_profile_id = self.thread.read(cx).profile().id(); - let entry = ContextMenuEntry::new(profile.name.clone()) - .toggleable(IconPosition::End, profile_id == settings.default_profile); + let entry = ContextMenuEntry::new(profile_name.clone()) + .toggleable(IconPosition::End, &profile_id == thread_profile_id); let entry = if let Some(doc_text) = documentation { entry.documentation_aside(documentation_side(settings.dock), move |_| { @@ -123,7 +129,7 @@ impl ProfileSelector { entry.handler({ let fs = self.fs.clone(); - let thread_store = self.thread_store.clone(); + let thread = self.thread.clone(); let profile_id = profile_id.clone(); move |_window, cx| { update_settings_file::(fs.clone(), cx, { @@ -133,11 +139,9 @@ impl ProfileSelector { } }); - thread_store - .update(cx, |this, cx| { - this.load_profile_by_id(profile_id.clone(), cx); - }) - .log_err(); + thread.update(cx, |this, cx| { + this.set_profile(profile_id.clone(), cx); + }); } }) } @@ -146,7 +150,7 @@ impl ProfileSelector { impl Render for ProfileSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings = AgentSettings::get_global(cx); - let profile_id = &settings.default_profile; + let profile_id = self.thread.read(cx).profile().id(); let profile = settings.profiles.get(profile_id); let selected_profile = profile diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index f857557271..bb8cc706bb 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -4,7 +4,7 @@ use std::ops::Range; use std::sync::Arc; use std::time::Instant; -use agent_settings::{AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; @@ -41,6 +41,7 @@ use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus}; use crate::ThreadStore; +use crate::agent_profile::AgentProfile; use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext}; use crate::thread_store::{ SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, @@ -360,6 +361,7 @@ pub struct Thread { >, remaining_turns: u32, configured_model: Option, + profile: AgentProfile, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -407,6 +409,7 @@ impl Thread { ) -> Self { let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); let configured_model = LanguageModelRegistry::read_global(cx).default_model(); + let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { id: ThreadId::new(), @@ -449,6 +452,7 @@ impl Thread { request_callback: None, remaining_turns: u32::MAX, configured_model, + profile: AgentProfile::new(profile_id, tools), } } @@ -495,6 +499,9 @@ impl Thread { let completion_mode = serialized .completion_mode .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode); + let profile_id = serialized + .profile + .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); Self { id, @@ -554,7 +561,7 @@ impl Thread { pending_checkpoint: None, project: project.clone(), prompt_builder, - tools, + tools: tools.clone(), tool_use, action_log: cx.new(|_| ActionLog::new(project)), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), @@ -570,6 +577,7 @@ impl Thread { request_callback: None, remaining_turns: u32::MAX, configured_model, + profile: AgentProfile::new(profile_id, tools), } } @@ -585,6 +593,17 @@ impl Thread { &self.id } + pub fn profile(&self) -> &AgentProfile { + &self.profile + } + + pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context) { + if &id != self.profile.id() { + self.profile = AgentProfile::new(id, self.tools.clone()); + cx.emit(ThreadEvent::ProfileChanged); + } + } + pub fn is_empty(&self) -> bool { self.messages.is_empty() } @@ -919,8 +938,7 @@ impl Thread { model: Arc, ) -> Vec { if model.supports_tools() { - self.tools() - .read(cx) + self.profile .enabled_tools(cx) .into_iter() .filter_map(|tool| { @@ -1180,6 +1198,7 @@ impl Thread { }), completion_mode: Some(this.completion_mode), tool_use_limit_reached: this.tool_use_limit_reached, + profile: Some(this.profile.id().clone()), }) }) } @@ -2121,7 +2140,7 @@ impl Thread { window: Option, cx: &mut Context, ) { - let available_tools = self.tools.read(cx).enabled_tools(cx); + let available_tools = self.profile.enabled_tools(cx); let tool_list = available_tools .iter() @@ -2213,19 +2232,15 @@ impl Thread { ) -> Task<()> { let tool_name: Arc = tool.name().into(); - let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) { - Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into() - } else { - tool.run( - input, - request, - self.project.clone(), - self.action_log.clone(), - model, - window, - cx, - ) - }; + let tool_result = tool.run( + input, + request, + self.project.clone(), + self.action_log.clone(), + model, + window, + cx, + ); // Store the card separately if it exists if let Some(card) = tool_result.card.clone() { @@ -2344,8 +2359,7 @@ impl Thread { let client = self.project.read(cx).client(); let enabled_tool_names: Vec = self - .tools() - .read(cx) + .profile .enabled_tools(cx) .iter() .map(|tool| tool.name()) @@ -2858,6 +2872,7 @@ pub enum ThreadEvent { ToolUseLimitReached, CancelEditing, CompletionCanceled, + ProfileChanged, } impl EventEmitter for Thread {} @@ -2872,7 +2887,7 @@ struct PendingCompletion { mod tests { use super::*; use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store}; - use agent_settings::{AgentSettings, LanguageModelParameters}; + use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters}; use assistant_tool::ToolRegistry; use editor::EditorSettings; use gpui::TestAppContext; @@ -3285,6 +3300,71 @@ fn main() {{ ); } + #[gpui::test] + async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, thread_store, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Check that we are starting with the default profile + let profile = cx.read(|cx| thread.read(cx).profile.clone()); + let tool_set = cx.read(|cx| thread_store.read(cx).tools()); + assert_eq!( + profile, + AgentProfile::new(AgentProfileId::default(), tool_set) + ); + } + + #[gpui::test] + async fn test_serializing_thread_profile(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, thread_store, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Profile gets serialized with default values + let serialized = thread + .update(cx, |thread, cx| thread.serialize(cx)) + .await + .unwrap(); + + assert_eq!(serialized.profile, Some(AgentProfileId::default())); + + let deserialized = cx.update(|cx| { + thread.update(cx, |thread, cx| { + Thread::deserialize( + thread.id.clone(), + serialized, + thread.project.clone(), + thread.tools.clone(), + thread.prompt_builder.clone(), + thread.project_context.clone(), + None, + cx, + ) + }) + }); + let tool_set = cx.read(|cx| thread_store.read(cx).tools()); + + assert_eq!( + deserialized.profile, + AgentProfile::new(AgentProfileId::default(), tool_set) + ); + } + #[gpui::test] async fn test_temperature_setting(cx: &mut TestAppContext) { init_test_settings(cx); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 964cb8d75e..504280fac4 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -3,9 +3,9 @@ use std::path::{Path, PathBuf}; use std::rc::Rc; use std::sync::{Arc, Mutex}; -use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; +use assistant_tool::{ToolId, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; use context_server::ContextServerId; @@ -25,7 +25,6 @@ use prompt_store::{ UserRulesContext, WorktreeContext, }; use serde::{Deserialize, Serialize}; -use settings::{Settings as _, SettingsStore}; use ui::Window; use util::ResultExt as _; @@ -147,12 +146,7 @@ impl ThreadStore { prompt_store: Option>, cx: &mut Context, ) -> (Self, oneshot::Receiver<()>) { - let mut subscriptions = vec![ - cx.observe_global::(move |this: &mut Self, cx| { - this.load_default_profile(cx); - }), - cx.subscribe(&project, Self::handle_project_event), - ]; + let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; if let Some(prompt_store) = prompt_store.as_ref() { subscriptions.push(cx.subscribe( @@ -200,7 +194,6 @@ impl ThreadStore { _reload_system_prompt_task: reload_system_prompt_task, _subscriptions: subscriptions, }; - this.load_default_profile(cx); this.register_context_server_handlers(cx); this.reload(cx).detach_and_log_err(cx); (this, ready_rx) @@ -520,86 +513,6 @@ impl ThreadStore { }) } - fn load_default_profile(&self, cx: &mut Context) { - let assistant_settings = AgentSettings::get_global(cx); - - self.load_profile_by_id(assistant_settings.default_profile.clone(), cx); - } - - pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context) { - let assistant_settings = AgentSettings::get_global(cx); - - if let Some(profile) = assistant_settings.profiles.get(&profile_id) { - self.load_profile(profile.clone(), cx); - } - } - - pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context) { - self.tools.update(cx, |tools, cx| { - tools.disable_all_tools(cx); - tools.enable( - ToolSource::Native, - &profile - .tools - .into_iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool)) - .collect::>(), - cx, - ); - }); - - if profile.enable_all_context_servers { - for context_server_id in self - .project - .read(cx) - .context_server_store() - .read(cx) - .all_server_ids() - { - self.tools.update(cx, |tools, cx| { - tools.enable_source( - ToolSource::ContextServer { - id: context_server_id.0.into(), - }, - cx, - ); - }); - } - // Enable all the tools from all context servers, but disable the ones that are explicitly disabled - for (context_server_id, preset) in profile.context_servers { - self.tools.update(cx, |tools, cx| { - tools.disable( - ToolSource::ContextServer { - id: context_server_id.into(), - }, - &preset - .tools - .into_iter() - .filter_map(|(tool, enabled)| (!enabled).then(|| tool)) - .collect::>(), - cx, - ) - }) - } - } else { - for (context_server_id, preset) in profile.context_servers { - self.tools.update(cx, |tools, cx| { - tools.enable( - ToolSource::ContextServer { - id: context_server_id.into(), - }, - &preset - .tools - .into_iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool)) - .collect::>(), - cx, - ) - }) - } - } - } - fn register_context_server_handlers(&self, cx: &mut Context) { cx.subscribe( &self.project.read(cx).context_server_store(), @@ -618,6 +531,7 @@ impl ThreadStore { match event { project::context_server_store::Event::ServerStatusChanged { server_id, status } => { match status { + ContextServerStatus::Starting => {} ContextServerStatus::Running => { if let Some(server) = context_server_store.read(cx).get_running_server(server_id) @@ -656,10 +570,9 @@ impl ThreadStore { .log_err(); if let Some(tool_ids) = tool_ids { - this.update(cx, |this, cx| { + this.update(cx, |this, _| { this.context_server_tool_ids .insert(server_id, tool_ids); - this.load_default_profile(cx); }) .log_err(); } @@ -675,10 +588,8 @@ impl ThreadStore { tool_working_set.update(cx, |tool_working_set, _| { tool_working_set.remove(&tool_ids); }); - self.load_default_profile(cx); } } - _ => {} } } } @@ -714,6 +625,8 @@ pub struct SerializedThread { pub completion_mode: Option, #[serde(default)] pub tool_use_limit_reached: bool, + #[serde(default)] + pub profile: Option, } #[derive(Serialize, Deserialize, Debug)] @@ -856,6 +769,7 @@ impl LegacySerializedThread { model: None, completion_mode: None, tool_use_limit_reached: false, + profile: None, } } } diff --git a/crates/agent/src/tool_compatibility.rs b/crates/agent/src/tool_compatibility.rs index 141d87c96f..6193b0929d 100644 --- a/crates/agent/src/tool_compatibility.rs +++ b/crates/agent/src/tool_compatibility.rs @@ -1,30 +1,33 @@ use std::sync::Arc; -use assistant_tool::{Tool, ToolSource, ToolWorkingSet, ToolWorkingSetEvent}; +use assistant_tool::{Tool, ToolSource}; use collections::HashMap; use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window}; use language_model::{LanguageModel, LanguageModelToolSchemaFormat}; use ui::prelude::*; +use crate::{Thread, ThreadEvent}; + pub struct IncompatibleToolsState { cache: HashMap>>, - tool_working_set: Entity, - _tool_working_set_subscription: Subscription, + thread: Entity, + _thread_subscription: Subscription, } impl IncompatibleToolsState { - pub fn new(tool_working_set: Entity, cx: &mut Context) -> Self { + pub fn new(thread: Entity, cx: &mut Context) -> Self { let _tool_working_set_subscription = - cx.subscribe(&tool_working_set, |this, _, event, _| match event { - ToolWorkingSetEvent::EnabledToolsChanged => { + cx.subscribe(&thread, |this, _, event, _| match event { + ThreadEvent::ProfileChanged => { this.cache.clear(); } + _ => {} }); Self { cache: HashMap::default(), - tool_working_set, - _tool_working_set_subscription, + thread, + _thread_subscription: _tool_working_set_subscription, } } @@ -36,8 +39,9 @@ impl IncompatibleToolsState { self.cache .entry(model.tool_input_format()) .or_insert_with(|| { - self.tool_working_set + self.thread .read(cx) + .profile() .enabled_tools(cx) .iter() .filter(|tool| tool.input_schema(model.tool_input_format()).is_err()) diff --git a/crates/agent_settings/Cargo.toml b/crates/agent_settings/Cargo.toml index 200c531c3c..c6a4bedbb5 100644 --- a/crates/agent_settings/Cargo.toml +++ b/crates/agent_settings/Cargo.toml @@ -16,7 +16,6 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true collections.workspace = true gpui.workspace = true -indexmap.workspace = true language_model.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true diff --git a/crates/agent_settings/src/agent_profile.rs b/crates/agent_settings/src/agent_profile.rs index 599932114a..a6b8633b34 100644 --- a/crates/agent_settings/src/agent_profile.rs +++ b/crates/agent_settings/src/agent_profile.rs @@ -17,29 +17,6 @@ pub mod builtin_profiles { } } -#[derive(Default)] -pub struct GroupedAgentProfiles { - pub builtin: IndexMap, - pub custom: IndexMap, -} - -impl GroupedAgentProfiles { - pub fn from_settings(settings: &crate::AgentSettings) -> Self { - let mut builtin = IndexMap::default(); - let mut custom = IndexMap::default(); - - for (profile_id, profile) in settings.profiles.clone() { - if builtin_profiles::is_builtin(&profile_id) { - builtin.insert(profile_id, profile); - } else { - custom.insert(profile_id, profile); - } - } - - Self { builtin, custom } - } -} - #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, JsonSchema)] pub struct AgentProfileId(pub Arc); @@ -63,7 +40,7 @@ impl Default for AgentProfileId { /// A profile for the Zed Agent that controls its behavior. #[derive(Debug, Clone)] -pub struct AgentProfile { +pub struct AgentProfileSettings { /// The name of the profile. pub name: SharedString, pub tools: IndexMap, bool>, diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 36480f30d5..9e8fd0c699 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -102,7 +102,7 @@ pub struct AgentSettings { pub using_outdated_settings_version: bool, pub default_profile: AgentProfileId, pub default_view: DefaultView, - pub profiles: IndexMap, + pub profiles: IndexMap, pub always_allow_tool_actions: bool, pub notify_when_agent_waiting: NotifyWhenAgentWaiting, pub play_sound_when_agent_done: bool, @@ -531,7 +531,7 @@ impl AgentSettingsContent { pub fn create_profile( &mut self, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, ) -> Result<()> { self.v2_setting(|settings| { let profiles = settings.profiles.get_or_insert_default(); @@ -542,10 +542,10 @@ impl AgentSettingsContent { profiles.insert( profile_id, AgentProfileContent { - name: profile.name.into(), - tools: profile.tools, - enable_all_context_servers: Some(profile.enable_all_context_servers), - context_servers: profile + name: profile_settings.name.into(), + tools: profile_settings.tools, + enable_all_context_servers: Some(profile_settings.enable_all_context_servers), + context_servers: profile_settings .context_servers .into_iter() .map(|(server_id, preset)| { @@ -910,7 +910,7 @@ impl Settings for AgentSettings { .extend(profiles.into_iter().map(|(id, profile)| { ( id, - AgentProfile { + AgentProfileSettings { name: profile.name.into(), tools: profile.tools, enable_all_context_servers: profile diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index c7e20d3517..c72c52ba7a 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use collections::{HashMap, HashSet, IndexMap}; -use gpui::{App, Context, EventEmitter}; +use collections::{HashMap, IndexMap}; +use gpui::App; use crate::{Tool, ToolRegistry, ToolSource}; @@ -13,17 +13,9 @@ pub struct ToolId(usize); pub struct ToolWorkingSet { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, - enabled_sources: HashSet, - enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } -pub enum ToolWorkingSetEvent { - EnabledToolsChanged, -} - -impl EventEmitter for ToolWorkingSet {} - impl ToolWorkingSet { pub fn tool(&self, name: &str, cx: &App) -> Option> { self.context_server_tools_by_name @@ -57,42 +49,6 @@ impl ToolWorkingSet { tools_by_source } - pub fn enabled_tools(&self, cx: &App) -> Vec> { - let all_tools = self.tools(cx); - - all_tools - .into_iter() - .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into())) - .collect() - } - - pub fn disable_all_tools(&mut self, cx: &mut Context) { - self.enabled_tools_by_source.clear(); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context) { - self.enabled_sources.insert(source.clone()); - - let tools_by_source = self.tools_by_source(cx); - if let Some(tools) = tools_by_source.get(&source) { - self.enabled_tools_by_source.insert( - source, - tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(), - ); - } - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context) { - self.enabled_sources.remove(source); - self.enabled_tools_by_source.remove(source); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - pub fn insert(&mut self, tool: Arc) -> ToolId { let tool_id = self.next_tool_id; self.next_tool_id.0 += 1; @@ -102,42 +58,6 @@ impl ToolWorkingSet { tool_id } - pub fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.enabled_tools_by_source - .get(source) - .map_or(false, |enabled_tools| enabled_tools.contains(name)) - } - - pub fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - !self.is_enabled(source, name) - } - - pub fn enable( - &mut self, - source: ToolSource, - tools_to_enable: &[Arc], - cx: &mut Context, - ) { - self.enabled_tools_by_source - .entry(source) - .or_default() - .extend(tools_to_enable.into_iter().cloned()); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn disable( - &mut self, - source: ToolSource, - tools_to_disable: &[Arc], - cx: &mut Context, - ) { - self.enabled_tools_by_source - .entry(source) - .or_default() - .retain(|name| !tools_to_disable.contains(name)); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) { self.context_server_tools_by_id .retain(|id, _| !tool_ids_to_remove.contains(id)); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 020aedbc57..a91fdac992 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -80,7 +80,6 @@ zed_llm_client.workspace = true agent_settings.workspace = true assistant_context_editor.workspace = true assistant_slash_command.workspace = true -assistant_tool.workspace = true async-trait.workspace = true audio.workspace = true buffer_diff.workspace = true diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index dc384668c3..85af49e339 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -294,6 +294,7 @@ impl ExampleContext { | ThreadEvent::MessageDeleted(_) | ThreadEvent::SummaryChanged | ThreadEvent::SummaryGenerated + | ThreadEvent::ProfileChanged | ThreadEvent::ReceivedTextChunk | ThreadEvent::StreamedToolUse { .. } | ThreadEvent::CheckpointChanged diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 94fdaf90bf..f28165e859 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -306,17 +306,19 @@ impl ExampleInstance { let thread_store = thread_store.await?; - let profile_id = meta.profile_id.clone(); - thread_store.update(cx, |thread_store, cx| thread_store.load_profile_by_id(profile_id, cx)).expect("Failed to load profile"); let thread = thread_store.update(cx, |thread_store, cx| { - if let Some(json) = &meta.existing_thread_json { + let thread = if let Some(json) = &meta.existing_thread_json { let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread"); thread_store.create_thread_from_serialized(serialized, cx) } else { thread_store.create_thread(cx) - } + }; + thread.update(cx, |thread, cx| { + thread.set_profile(meta.profile_id.clone(), cx); + }); + thread })?;