diff --git a/Cargo.lock b/Cargo.lock index 9554c46aac..af14e42430 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,7 @@ dependencies = [ "assistant_slash_command", "assistant_slash_commands", "assistant_tool", + "assistant_tools", "async-watch", "audio", "buffer_diff", @@ -147,7 +148,6 @@ dependencies = [ "deepseek", "fs", "gpui", - "indexmap", "language_model", "lmstudio", "log", @@ -2987,7 +2987,6 @@ dependencies = [ "anyhow", "assistant_context_editor", "assistant_slash_command", - "assistant_tool", "async-stripe", "async-trait", "async-tungstenite", diff --git a/assets/settings/default.json b/assets/settings/default.json index 8d8c65884c..fbcde696c3 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -771,7 +771,6 @@ "tools": { "copy_path": true, "create_directory": true, - "create_file": true, "delete_path": true, "diagnostics": true, "edit_file": true, diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index c1f9d9a3fa..1b07d94605 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -102,6 +102,7 @@ zed_llm_client.workspace = true zstd.workspace = true [dev-dependencies] +assistant_tools.workspace = true buffer_diff = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index a983d43690..8eda04c60f 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -1144,6 +1144,10 @@ impl ActiveThread { cx, ); } + ThreadEvent::ProfileChanged => { + self.save_thread(cx); + cx.notify(); + } } } diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index db458b771e..0ac7869920 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -3,6 +3,7 @@ mod agent_configuration; mod agent_diff; mod agent_model_selector; mod agent_panel; +mod agent_profile; mod buffer_codegen; mod context; mod context_picker; diff --git a/crates/agent/src/agent_configuration/manage_profiles_modal.rs b/crates/agent/src/agent_configuration/manage_profiles_modal.rs index 8cb7d4dfe2..feb0a8e53f 100644 --- a/crates/agent/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent/src/agent_configuration/manage_profiles_modal.rs @@ -2,25 +2,21 @@ mod profile_modal_header; use std::sync::Arc; -use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, builtin_profiles}; +use agent_settings::{AgentProfileId, AgentSettings, builtin_profiles}; use assistant_tool::ToolWorkingSet; -use convert_case::{Case, Casing as _}; use editor::Editor; use fs::Fs; -use gpui::{ - DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, WeakEntity, - prelude::*, -}; -use settings::{Settings as _, update_settings_file}; +use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, prelude::*}; +use settings::Settings as _; use ui::{ KeyBinding, ListItem, ListItemSpacing, ListSeparator, Navigable, NavigableEntry, prelude::*, }; -use util::ResultExt as _; use workspace::{ModalView, Workspace}; use crate::agent_configuration::manage_profiles_modal::profile_modal_header::ProfileModalHeader; use crate::agent_configuration::tool_picker::{ToolPicker, ToolPickerDelegate}; -use crate::{AgentPanel, ManageProfiles, ThreadStore}; +use crate::agent_profile::AgentProfile; +use crate::{AgentPanel, ManageProfiles}; use super::tool_picker::ToolPickerMode; @@ -103,7 +99,6 @@ pub struct NewProfileMode { pub struct ManageProfilesModal { fs: Arc, tools: Entity, - thread_store: WeakEntity, focus_handle: FocusHandle, mode: Mode, } @@ -119,9 +114,8 @@ impl ManageProfilesModal { let fs = workspace.app_state().fs.clone(); let thread_store = panel.read(cx).thread_store(); let tools = thread_store.read(cx).tools(); - let thread_store = thread_store.downgrade(); workspace.toggle_modal(window, cx, |window, cx| { - let mut this = Self::new(fs, tools, thread_store, window, cx); + let mut this = Self::new(fs, tools, window, cx); if let Some(profile_id) = action.customize_tools.clone() { this.configure_builtin_tools(profile_id, window, cx); @@ -136,7 +130,6 @@ impl ManageProfilesModal { pub fn new( fs: Arc, tools: Entity, - thread_store: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -145,7 +138,6 @@ impl ManageProfilesModal { Self { fs, tools, - thread_store, focus_handle, mode: Mode::choose_profile(window, cx), } @@ -206,7 +198,6 @@ impl ManageProfilesModal { ToolPickerMode::McpTools, self.fs.clone(), self.tools.clone(), - self.thread_store.clone(), profile_id.clone(), profile, cx, @@ -244,7 +235,6 @@ impl ManageProfilesModal { ToolPickerMode::BuiltinTools, self.fs.clone(), self.tools.clone(), - self.thread_store.clone(), profile_id.clone(), profile, cx, @@ -270,32 +260,10 @@ impl ManageProfilesModal { match &self.mode { Mode::ChooseProfile { .. } => {} Mode::NewProfile(mode) => { - let settings = AgentSettings::get_global(cx); - - let base_profile = mode - .base_profile_id - .as_ref() - .and_then(|profile_id| settings.profiles.get(profile_id).cloned()); - let name = mode.name_editor.read(cx).text(cx); - let profile_id = AgentProfileId(name.to_case(Case::Kebab).into()); - let profile = AgentProfile { - name: name.into(), - tools: base_profile - .as_ref() - .map(|profile| profile.tools.clone()) - .unwrap_or_default(), - enable_all_context_servers: base_profile - .as_ref() - .map(|profile| profile.enable_all_context_servers) - .unwrap_or_default(), - context_servers: base_profile - .map(|profile| profile.context_servers) - .unwrap_or_default(), - }; - - self.create_profile(profile_id.clone(), profile, cx); + let profile_id = + AgentProfile::create(name, mode.base_profile_id.clone(), self.fs.clone(), cx); self.view_profile(profile_id, window, cx); } Mode::ViewProfile(_) => {} @@ -325,19 +293,6 @@ impl ManageProfilesModal { } } } - - fn create_profile( - &self, - profile_id: AgentProfileId, - profile: AgentProfile, - cx: &mut Context, - ) { - update_settings_file::(self.fs.clone(), cx, { - move |settings, _cx| { - settings.create_profile(profile_id, profile).log_err(); - } - }); - } } impl ModalView for ManageProfilesModal {} @@ -520,14 +475,13 @@ impl ManageProfilesModal { ) -> impl IntoElement { let settings = AgentSettings::get_global(cx); - let profile_id = &settings.default_profile; let profile_name = settings .profiles .get(&mode.profile_id) .map(|profile| profile.name.clone()) .unwrap_or_else(|| "Unknown".into()); - let icon = match profile_id.as_str() { + let icon = match mode.profile_id.as_str() { "write" => IconName::Pencil, "ask" => IconName::MessageBubbles, _ => IconName::UserRoundPen, diff --git a/crates/agent/src/agent_configuration/tool_picker.rs b/crates/agent/src/agent_configuration/tool_picker.rs index 5ac2d4496b..7c3d20457e 100644 --- a/crates/agent/src/agent_configuration/tool_picker.rs +++ b/crates/agent/src/agent_configuration/tool_picker.rs @@ -1,19 +1,17 @@ use std::{collections::BTreeMap, sync::Arc}; use agent_settings::{ - AgentProfile, AgentProfileContent, AgentProfileId, AgentSettings, AgentSettingsContent, + AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent, ContextServerPresetContent, }; use assistant_tool::{ToolSource, ToolWorkingSet}; use fs::Fs; use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window}; use picker::{Picker, PickerDelegate}; -use settings::{Settings as _, update_settings_file}; +use settings::update_settings_file; use ui::{ListItem, ListItemSpacing, prelude::*}; use util::ResultExt as _; -use crate::ThreadStore; - pub struct ToolPicker { picker: Entity>, } @@ -71,11 +69,10 @@ pub enum PickerItem { pub struct ToolPickerDelegate { tool_picker: WeakEntity, - thread_store: WeakEntity, fs: Arc, items: Arc>, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, filtered_items: Vec, selected_index: usize, mode: ToolPickerMode, @@ -86,20 +83,18 @@ impl ToolPickerDelegate { mode: ToolPickerMode, fs: Arc, tool_set: Entity, - thread_store: WeakEntity, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, cx: &mut Context, ) -> Self { let items = Arc::new(Self::resolve_items(mode, &tool_set, cx)); Self { tool_picker: cx.entity().downgrade(), - thread_store, fs, items, profile_id, - profile, + profile_settings, filtered_items: Vec::new(), selected_index: 0, mode, @@ -249,28 +244,31 @@ impl PickerDelegate for ToolPickerDelegate { }; let is_currently_enabled = if let Some(server_id) = server_id.clone() { - let preset = self.profile.context_servers.entry(server_id).or_default(); + let preset = self + .profile_settings + .context_servers + .entry(server_id) + .or_default(); let is_enabled = *preset.tools.entry(tool_name.clone()).or_default(); *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled; is_enabled } else { - let is_enabled = *self.profile.tools.entry(tool_name.clone()).or_default(); - *self.profile.tools.entry(tool_name.clone()).or_default() = !is_enabled; + let is_enabled = *self + .profile_settings + .tools + .entry(tool_name.clone()) + .or_default(); + *self + .profile_settings + .tools + .entry(tool_name.clone()) + .or_default() = !is_enabled; is_enabled }; - let active_profile_id = &AgentSettings::get_global(cx).default_profile; - if active_profile_id == &self.profile_id { - self.thread_store - .update(cx, |this, cx| { - this.load_profile(self.profile.clone(), cx); - }) - .log_err(); - } - update_settings_file::(self.fs.clone(), cx, { let profile_id = self.profile_id.clone(); - let default_profile = self.profile.clone(); + let default_profile = self.profile_settings.clone(); let server_id = server_id.clone(); let tool_name = tool_name.clone(); move |settings: &mut AgentSettingsContent, _cx| { @@ -348,14 +346,18 @@ impl PickerDelegate for ToolPickerDelegate { ), PickerItem::Tool { name, server_id } => { let is_enabled = if let Some(server_id) = server_id { - self.profile + self.profile_settings .context_servers .get(server_id.as_ref()) .and_then(|preset| preset.tools.get(name)) .copied() - .unwrap_or(self.profile.enable_all_context_servers) + .unwrap_or(self.profile_settings.enable_all_context_servers) } else { - self.profile.tools.get(name).copied().unwrap_or(false) + self.profile_settings + .tools + .get(name) + .copied() + .unwrap_or(false) }; Some( diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index b620d53c78..34ff249e95 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -1378,7 +1378,8 @@ impl AgentDiff { | ThreadEvent::CheckpointChanged | ThreadEvent::ToolConfirmationNeeded | ThreadEvent::ToolUseLimitReached - | ThreadEvent::CancelEditing => {} + | ThreadEvent::CancelEditing + | ThreadEvent::ProfileChanged => {} } } diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs new file mode 100644 index 0000000000..5cd69bd324 --- /dev/null +++ b/crates/agent/src/agent_profile.rs @@ -0,0 +1,334 @@ +use std::sync::Arc; + +use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings}; +use assistant_tool::{Tool, ToolSource, ToolWorkingSet}; +use collections::IndexMap; +use convert_case::{Case, Casing}; +use fs::Fs; +use gpui::{App, Entity}; +use settings::{Settings, update_settings_file}; +use ui::SharedString; +use util::ResultExt; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct AgentProfile { + id: AgentProfileId, + tool_set: Entity, +} + +pub type AvailableProfiles = IndexMap; + +impl AgentProfile { + pub fn new(id: AgentProfileId, tool_set: Entity) -> Self { + Self { id, tool_set } + } + + /// Saves a new profile to the settings. + pub fn create( + name: String, + base_profile_id: Option, + fs: Arc, + cx: &App, + ) -> AgentProfileId { + let id = AgentProfileId(name.to_case(Case::Kebab).into()); + + let base_profile = + base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned()); + + let profile_settings = AgentProfileSettings { + name: name.into(), + tools: base_profile + .as_ref() + .map(|profile| profile.tools.clone()) + .unwrap_or_default(), + enable_all_context_servers: base_profile + .as_ref() + .map(|profile| profile.enable_all_context_servers) + .unwrap_or_default(), + context_servers: base_profile + .map(|profile| profile.context_servers) + .unwrap_or_default(), + }; + + update_settings_file::(fs, cx, { + let id = id.clone(); + move |settings, _cx| { + settings.create_profile(id, profile_settings).log_err(); + } + }); + + id + } + + /// Returns a map of AgentProfileIds to their names + pub fn available_profiles(cx: &App) -> AvailableProfiles { + let mut profiles = AvailableProfiles::default(); + for (id, profile) in AgentSettings::get_global(cx).profiles.iter() { + profiles.insert(id.clone(), profile.name.clone()); + } + profiles + } + + pub fn id(&self) -> &AgentProfileId { + &self.id + } + + pub fn enabled_tools(&self, cx: &App) -> Vec> { + let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else { + return Vec::new(); + }; + + self.tool_set + .read(cx) + .tools(cx) + .into_iter() + .filter(|tool| Self::is_enabled(settings, tool.source(), tool.name())) + .collect() + } + + fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool { + match source { + ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false), + ToolSource::ContextServer { id } => { + if settings.enable_all_context_servers { + return true; + } + + let Some(preset) = settings.context_servers.get(id.as_ref()) else { + return false; + }; + *preset.tools.get(name.as_str()).unwrap_or(&false) + } + } + } +} + +#[cfg(test)] +mod tests { + use agent_settings::ContextServerPreset; + use assistant_tool::ToolRegistry; + use collections::IndexMap; + use gpui::{AppContext, TestAppContext}; + use http_client::FakeHttpClient; + use project::Project; + use settings::{Settings, SettingsStore}; + use ui::SharedString; + + use super::*; + + #[gpui::test] + async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId::default(); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings + .tools + .into_iter() + .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string())) + // Provider dependent + .filter(|tool| tool != "web_search") + .collect::>(); + // Plus all registered MCP tools + expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + #[gpui::test] + async fn test_custom_mcp_settings(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId("custom_mcp".into()); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings.context_servers["mcp"] + .tools + .iter() + .filter_map(|(key, enabled)| enabled.then(|| key.to_string())) + .collect::>(); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + #[gpui::test] + async fn test_only_built_in(cx: &mut TestAppContext) { + init_test_settings(cx); + + let id = AgentProfileId("write_minus_mcp".into()); + let profile_settings = cx.read(|cx| { + AgentSettings::get_global(cx) + .profiles + .get(&id) + .unwrap() + .clone() + }); + let tool_set = default_tool_set(cx); + + let profile = AgentProfile::new(id.clone(), tool_set); + + let mut enabled_tools = cx + .read(|cx| profile.enabled_tools(cx)) + .into_iter() + .map(|tool| tool.name()) + .collect::>(); + enabled_tools.sort(); + + let mut expected_tools = profile_settings + .tools + .into_iter() + .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string())) + // Provider dependent + .filter(|tool| tool != "web_search") + .collect::>(); + expected_tools.sort(); + + assert_eq!(enabled_tools, expected_tools); + } + + fn init_test_settings(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + AgentSettings::register(cx); + language_model::init_settings(cx); + ToolRegistry::default_global(cx); + assistant_tools::init(FakeHttpClient::with_404_response(), cx); + }); + + cx.update(|cx| { + let mut agent_settings = AgentSettings::get_global(cx).clone(); + agent_settings.profiles.insert( + AgentProfileId("write_minus_mcp".into()), + AgentProfileSettings { + name: "write_minus_mcp".into(), + enable_all_context_servers: false, + ..agent_settings.profiles[&AgentProfileId::default()].clone() + }, + ); + agent_settings.profiles.insert( + AgentProfileId("custom_mcp".into()), + AgentProfileSettings { + name: "mcp".into(), + tools: IndexMap::default(), + enable_all_context_servers: false, + context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]), + }, + ); + AgentSettings::override_global(agent_settings, cx); + }) + } + + fn context_server_preset() -> ContextServerPreset { + ContextServerPreset { + tools: IndexMap::from_iter([ + ("enabled_mcp_tool".into(), true), + ("disabled_mcp_tool".into(), false), + ]), + } + } + + fn default_tool_set(cx: &mut TestAppContext) -> Entity { + cx.new(|_| { + let mut tool_set = ToolWorkingSet::default(); + tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp"))); + tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp"))); + tool_set + }) + } + + struct FakeTool { + name: String, + source: SharedString, + } + + impl FakeTool { + fn new(name: impl Into, source: impl Into) -> Self { + Self { + name: name.into(), + source: source.into(), + } + } + } + + impl Tool for FakeTool { + fn name(&self) -> String { + self.name.clone() + } + + fn source(&self) -> ToolSource { + ToolSource::ContextServer { + id: self.source.clone(), + } + } + + fn description(&self) -> String { + unimplemented!() + } + + fn icon(&self) -> ui::IconName { + unimplemented!() + } + + fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + unimplemented!() + } + + fn ui_text(&self, _input: &serde_json::Value) -> String { + unimplemented!() + } + + fn run( + self: Arc, + _input: serde_json::Value, + _request: Arc, + _project: Entity, + _action_log: Entity, + _model: Arc, + _window: Option, + _cx: &mut App, + ) -> assistant_tool::ToolResult { + unimplemented!() + } + + fn may_perform_edits(&self) -> bool { + unimplemented!() + } + } +} diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs index 0ae326bd44..a3958d9acb 100644 --- a/crates/agent/src/message_editor.rs +++ b/crates/agent/src/message_editor.rs @@ -175,8 +175,7 @@ impl MessageEditor { ) }); - let incompatible_tools = - cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx)); + let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.clone(), cx)); let subscriptions = vec![ cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event), @@ -204,15 +203,8 @@ impl MessageEditor { ) }); - let profile_selector = cx.new(|cx| { - ProfileSelector::new( - fs, - thread.clone(), - thread_store, - editor.focus_handle(cx), - cx, - ) - }); + let profile_selector = + cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx)); Self { editor: editor.clone(), diff --git a/crates/agent/src/profile_selector.rs b/crates/agent/src/profile_selector.rs index a51440ddb9..7a42e45fa4 100644 --- a/crates/agent/src/profile_selector.rs +++ b/crates/agent/src/profile_selector.rs @@ -1,26 +1,24 @@ use std::sync::Arc; -use agent_settings::{ - AgentDockPosition, AgentProfile, AgentProfileId, AgentSettings, GroupedAgentProfiles, - builtin_profiles, -}; +use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles}; use fs::Fs; -use gpui::{Action, Empty, Entity, FocusHandle, Subscription, WeakEntity, prelude::*}; +use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*}; use language_model::LanguageModelRegistry; use settings::{Settings as _, SettingsStore, update_settings_file}; use ui::{ ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*, }; -use util::ResultExt as _; -use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector}; +use crate::{ + ManageProfiles, Thread, ToggleProfileSelector, + agent_profile::{AgentProfile, AvailableProfiles}, +}; pub struct ProfileSelector { - profiles: GroupedAgentProfiles, + profiles: AvailableProfiles, fs: Arc, thread: Entity, - thread_store: WeakEntity, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, _subscriptions: Vec, @@ -30,7 +28,6 @@ impl ProfileSelector { pub fn new( fs: Arc, thread: Entity, - thread_store: WeakEntity, focus_handle: FocusHandle, cx: &mut Context, ) -> Self { @@ -39,10 +36,9 @@ impl ProfileSelector { }); Self { - profiles: GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx)), + profiles: AgentProfile::available_profiles(cx), fs, thread, - thread_store, menu_handle: PopoverMenuHandle::default(), focus_handle, _subscriptions: vec![settings_subscription], @@ -54,7 +50,7 @@ impl ProfileSelector { } fn refresh_profiles(&mut self, cx: &mut Context) { - self.profiles = GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx)); + self.profiles = AgentProfile::available_profiles(cx); } fn build_context_menu( @@ -64,21 +60,30 @@ impl ProfileSelector { ) -> Entity { ContextMenu::build(window, cx, |mut menu, _window, cx| { let settings = AgentSettings::get_global(cx); - for (profile_id, profile) in self.profiles.builtin.iter() { + + let mut found_non_builtin = false; + for (profile_id, profile_name) in self.profiles.iter() { + if !builtin_profiles::is_builtin(profile_id) { + found_non_builtin = true; + continue; + } menu = menu.item(self.menu_entry_for_profile( profile_id.clone(), - profile, + profile_name, settings, cx, )); } - if !self.profiles.custom.is_empty() { + if found_non_builtin { menu = menu.separator().header("Custom Profiles"); - for (profile_id, profile) in self.profiles.custom.iter() { + for (profile_id, profile_name) in self.profiles.iter() { + if builtin_profiles::is_builtin(profile_id) { + continue; + } menu = menu.item(self.menu_entry_for_profile( profile_id.clone(), - profile, + profile_name, settings, cx, )); @@ -99,19 +104,20 @@ impl ProfileSelector { fn menu_entry_for_profile( &self, profile_id: AgentProfileId, - profile: &AgentProfile, + profile_name: &SharedString, settings: &AgentSettings, - _cx: &App, + cx: &App, ) -> ContextMenuEntry { - let documentation = match profile.name.to_lowercase().as_str() { + let documentation = match profile_name.to_lowercase().as_str() { builtin_profiles::WRITE => Some("Get help to write anything."), builtin_profiles::ASK => Some("Chat about your codebase."), builtin_profiles::MINIMAL => Some("Chat about anything with no tools."), _ => None, }; + let thread_profile_id = self.thread.read(cx).profile().id(); - let entry = ContextMenuEntry::new(profile.name.clone()) - .toggleable(IconPosition::End, profile_id == settings.default_profile); + let entry = ContextMenuEntry::new(profile_name.clone()) + .toggleable(IconPosition::End, &profile_id == thread_profile_id); let entry = if let Some(doc_text) = documentation { entry.documentation_aside(documentation_side(settings.dock), move |_| { @@ -123,7 +129,7 @@ impl ProfileSelector { entry.handler({ let fs = self.fs.clone(); - let thread_store = self.thread_store.clone(); + let thread = self.thread.clone(); let profile_id = profile_id.clone(); move |_window, cx| { update_settings_file::(fs.clone(), cx, { @@ -133,11 +139,9 @@ impl ProfileSelector { } }); - thread_store - .update(cx, |this, cx| { - this.load_profile_by_id(profile_id.clone(), cx); - }) - .log_err(); + thread.update(cx, |this, cx| { + this.set_profile(profile_id.clone(), cx); + }); } }) } @@ -146,7 +150,7 @@ impl ProfileSelector { impl Render for ProfileSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings = AgentSettings::get_global(cx); - let profile_id = &settings.default_profile; + let profile_id = self.thread.read(cx).profile().id(); let profile = settings.profiles.get(profile_id); let selected_profile = profile diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index f857557271..bb8cc706bb 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -4,7 +4,7 @@ use std::ops::Range; use std::sync::Arc; use std::time::Instant; -use agent_settings::{AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; use chrono::{DateTime, Utc}; @@ -41,6 +41,7 @@ use uuid::Uuid; use zed_llm_client::{CompletionIntent, CompletionRequestStatus}; use crate::ThreadStore; +use crate::agent_profile::AgentProfile; use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext}; use crate::thread_store::{ SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, @@ -360,6 +361,7 @@ pub struct Thread { >, remaining_turns: u32, configured_model: Option, + profile: AgentProfile, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -407,6 +409,7 @@ impl Thread { ) -> Self { let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); let configured_model = LanguageModelRegistry::read_global(cx).default_model(); + let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { id: ThreadId::new(), @@ -449,6 +452,7 @@ impl Thread { request_callback: None, remaining_turns: u32::MAX, configured_model, + profile: AgentProfile::new(profile_id, tools), } } @@ -495,6 +499,9 @@ impl Thread { let completion_mode = serialized .completion_mode .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode); + let profile_id = serialized + .profile + .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); Self { id, @@ -554,7 +561,7 @@ impl Thread { pending_checkpoint: None, project: project.clone(), prompt_builder, - tools, + tools: tools.clone(), tool_use, action_log: cx.new(|_| ActionLog::new(project)), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), @@ -570,6 +577,7 @@ impl Thread { request_callback: None, remaining_turns: u32::MAX, configured_model, + profile: AgentProfile::new(profile_id, tools), } } @@ -585,6 +593,17 @@ impl Thread { &self.id } + pub fn profile(&self) -> &AgentProfile { + &self.profile + } + + pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context) { + if &id != self.profile.id() { + self.profile = AgentProfile::new(id, self.tools.clone()); + cx.emit(ThreadEvent::ProfileChanged); + } + } + pub fn is_empty(&self) -> bool { self.messages.is_empty() } @@ -919,8 +938,7 @@ impl Thread { model: Arc, ) -> Vec { if model.supports_tools() { - self.tools() - .read(cx) + self.profile .enabled_tools(cx) .into_iter() .filter_map(|tool| { @@ -1180,6 +1198,7 @@ impl Thread { }), completion_mode: Some(this.completion_mode), tool_use_limit_reached: this.tool_use_limit_reached, + profile: Some(this.profile.id().clone()), }) }) } @@ -2121,7 +2140,7 @@ impl Thread { window: Option, cx: &mut Context, ) { - let available_tools = self.tools.read(cx).enabled_tools(cx); + let available_tools = self.profile.enabled_tools(cx); let tool_list = available_tools .iter() @@ -2213,19 +2232,15 @@ impl Thread { ) -> Task<()> { let tool_name: Arc = tool.name().into(); - let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) { - Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into() - } else { - tool.run( - input, - request, - self.project.clone(), - self.action_log.clone(), - model, - window, - cx, - ) - }; + let tool_result = tool.run( + input, + request, + self.project.clone(), + self.action_log.clone(), + model, + window, + cx, + ); // Store the card separately if it exists if let Some(card) = tool_result.card.clone() { @@ -2344,8 +2359,7 @@ impl Thread { let client = self.project.read(cx).client(); let enabled_tool_names: Vec = self - .tools() - .read(cx) + .profile .enabled_tools(cx) .iter() .map(|tool| tool.name()) @@ -2858,6 +2872,7 @@ pub enum ThreadEvent { ToolUseLimitReached, CancelEditing, CompletionCanceled, + ProfileChanged, } impl EventEmitter for Thread {} @@ -2872,7 +2887,7 @@ struct PendingCompletion { mod tests { use super::*; use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store}; - use agent_settings::{AgentSettings, LanguageModelParameters}; + use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters}; use assistant_tool::ToolRegistry; use editor::EditorSettings; use gpui::TestAppContext; @@ -3285,6 +3300,71 @@ fn main() {{ ); } + #[gpui::test] + async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, thread_store, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Check that we are starting with the default profile + let profile = cx.read(|cx| thread.read(cx).profile.clone()); + let tool_set = cx.read(|cx| thread_store.read(cx).tools()); + assert_eq!( + profile, + AgentProfile::new(AgentProfileId::default(), tool_set) + ); + } + + #[gpui::test] + async fn test_serializing_thread_profile(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project( + cx, + json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), + ) + .await; + + let (_workspace, thread_store, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Profile gets serialized with default values + let serialized = thread + .update(cx, |thread, cx| thread.serialize(cx)) + .await + .unwrap(); + + assert_eq!(serialized.profile, Some(AgentProfileId::default())); + + let deserialized = cx.update(|cx| { + thread.update(cx, |thread, cx| { + Thread::deserialize( + thread.id.clone(), + serialized, + thread.project.clone(), + thread.tools.clone(), + thread.prompt_builder.clone(), + thread.project_context.clone(), + None, + cx, + ) + }) + }); + let tool_set = cx.read(|cx| thread_store.read(cx).tools()); + + assert_eq!( + deserialized.profile, + AgentProfile::new(AgentProfileId::default(), tool_set) + ); + } + #[gpui::test] async fn test_temperature_setting(cx: &mut TestAppContext) { init_test_settings(cx); diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 964cb8d75e..504280fac4 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -3,9 +3,9 @@ use std::path::{Path, PathBuf}; use std::rc::Rc; use std::sync::{Arc, Mutex}; -use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode}; +use agent_settings::{AgentProfileId, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ToolId, ToolSource, ToolWorkingSet}; +use assistant_tool::{ToolId, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; use context_server::ContextServerId; @@ -25,7 +25,6 @@ use prompt_store::{ UserRulesContext, WorktreeContext, }; use serde::{Deserialize, Serialize}; -use settings::{Settings as _, SettingsStore}; use ui::Window; use util::ResultExt as _; @@ -147,12 +146,7 @@ impl ThreadStore { prompt_store: Option>, cx: &mut Context, ) -> (Self, oneshot::Receiver<()>) { - let mut subscriptions = vec![ - cx.observe_global::(move |this: &mut Self, cx| { - this.load_default_profile(cx); - }), - cx.subscribe(&project, Self::handle_project_event), - ]; + let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; if let Some(prompt_store) = prompt_store.as_ref() { subscriptions.push(cx.subscribe( @@ -200,7 +194,6 @@ impl ThreadStore { _reload_system_prompt_task: reload_system_prompt_task, _subscriptions: subscriptions, }; - this.load_default_profile(cx); this.register_context_server_handlers(cx); this.reload(cx).detach_and_log_err(cx); (this, ready_rx) @@ -520,86 +513,6 @@ impl ThreadStore { }) } - fn load_default_profile(&self, cx: &mut Context) { - let assistant_settings = AgentSettings::get_global(cx); - - self.load_profile_by_id(assistant_settings.default_profile.clone(), cx); - } - - pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context) { - let assistant_settings = AgentSettings::get_global(cx); - - if let Some(profile) = assistant_settings.profiles.get(&profile_id) { - self.load_profile(profile.clone(), cx); - } - } - - pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context) { - self.tools.update(cx, |tools, cx| { - tools.disable_all_tools(cx); - tools.enable( - ToolSource::Native, - &profile - .tools - .into_iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool)) - .collect::>(), - cx, - ); - }); - - if profile.enable_all_context_servers { - for context_server_id in self - .project - .read(cx) - .context_server_store() - .read(cx) - .all_server_ids() - { - self.tools.update(cx, |tools, cx| { - tools.enable_source( - ToolSource::ContextServer { - id: context_server_id.0.into(), - }, - cx, - ); - }); - } - // Enable all the tools from all context servers, but disable the ones that are explicitly disabled - for (context_server_id, preset) in profile.context_servers { - self.tools.update(cx, |tools, cx| { - tools.disable( - ToolSource::ContextServer { - id: context_server_id.into(), - }, - &preset - .tools - .into_iter() - .filter_map(|(tool, enabled)| (!enabled).then(|| tool)) - .collect::>(), - cx, - ) - }) - } - } else { - for (context_server_id, preset) in profile.context_servers { - self.tools.update(cx, |tools, cx| { - tools.enable( - ToolSource::ContextServer { - id: context_server_id.into(), - }, - &preset - .tools - .into_iter() - .filter_map(|(tool, enabled)| enabled.then(|| tool)) - .collect::>(), - cx, - ) - }) - } - } - } - fn register_context_server_handlers(&self, cx: &mut Context) { cx.subscribe( &self.project.read(cx).context_server_store(), @@ -618,6 +531,7 @@ impl ThreadStore { match event { project::context_server_store::Event::ServerStatusChanged { server_id, status } => { match status { + ContextServerStatus::Starting => {} ContextServerStatus::Running => { if let Some(server) = context_server_store.read(cx).get_running_server(server_id) @@ -656,10 +570,9 @@ impl ThreadStore { .log_err(); if let Some(tool_ids) = tool_ids { - this.update(cx, |this, cx| { + this.update(cx, |this, _| { this.context_server_tool_ids .insert(server_id, tool_ids); - this.load_default_profile(cx); }) .log_err(); } @@ -675,10 +588,8 @@ impl ThreadStore { tool_working_set.update(cx, |tool_working_set, _| { tool_working_set.remove(&tool_ids); }); - self.load_default_profile(cx); } } - _ => {} } } } @@ -714,6 +625,8 @@ pub struct SerializedThread { pub completion_mode: Option, #[serde(default)] pub tool_use_limit_reached: bool, + #[serde(default)] + pub profile: Option, } #[derive(Serialize, Deserialize, Debug)] @@ -856,6 +769,7 @@ impl LegacySerializedThread { model: None, completion_mode: None, tool_use_limit_reached: false, + profile: None, } } } diff --git a/crates/agent/src/tool_compatibility.rs b/crates/agent/src/tool_compatibility.rs index 141d87c96f..6193b0929d 100644 --- a/crates/agent/src/tool_compatibility.rs +++ b/crates/agent/src/tool_compatibility.rs @@ -1,30 +1,33 @@ use std::sync::Arc; -use assistant_tool::{Tool, ToolSource, ToolWorkingSet, ToolWorkingSetEvent}; +use assistant_tool::{Tool, ToolSource}; use collections::HashMap; use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window}; use language_model::{LanguageModel, LanguageModelToolSchemaFormat}; use ui::prelude::*; +use crate::{Thread, ThreadEvent}; + pub struct IncompatibleToolsState { cache: HashMap>>, - tool_working_set: Entity, - _tool_working_set_subscription: Subscription, + thread: Entity, + _thread_subscription: Subscription, } impl IncompatibleToolsState { - pub fn new(tool_working_set: Entity, cx: &mut Context) -> Self { + pub fn new(thread: Entity, cx: &mut Context) -> Self { let _tool_working_set_subscription = - cx.subscribe(&tool_working_set, |this, _, event, _| match event { - ToolWorkingSetEvent::EnabledToolsChanged => { + cx.subscribe(&thread, |this, _, event, _| match event { + ThreadEvent::ProfileChanged => { this.cache.clear(); } + _ => {} }); Self { cache: HashMap::default(), - tool_working_set, - _tool_working_set_subscription, + thread, + _thread_subscription: _tool_working_set_subscription, } } @@ -36,8 +39,9 @@ impl IncompatibleToolsState { self.cache .entry(model.tool_input_format()) .or_insert_with(|| { - self.tool_working_set + self.thread .read(cx) + .profile() .enabled_tools(cx) .iter() .filter(|tool| tool.input_schema(model.tool_input_format()).is_err()) diff --git a/crates/agent_settings/Cargo.toml b/crates/agent_settings/Cargo.toml index 200c531c3c..c6a4bedbb5 100644 --- a/crates/agent_settings/Cargo.toml +++ b/crates/agent_settings/Cargo.toml @@ -16,7 +16,6 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true collections.workspace = true gpui.workspace = true -indexmap.workspace = true language_model.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true diff --git a/crates/agent_settings/src/agent_profile.rs b/crates/agent_settings/src/agent_profile.rs index 599932114a..a6b8633b34 100644 --- a/crates/agent_settings/src/agent_profile.rs +++ b/crates/agent_settings/src/agent_profile.rs @@ -17,29 +17,6 @@ pub mod builtin_profiles { } } -#[derive(Default)] -pub struct GroupedAgentProfiles { - pub builtin: IndexMap, - pub custom: IndexMap, -} - -impl GroupedAgentProfiles { - pub fn from_settings(settings: &crate::AgentSettings) -> Self { - let mut builtin = IndexMap::default(); - let mut custom = IndexMap::default(); - - for (profile_id, profile) in settings.profiles.clone() { - if builtin_profiles::is_builtin(&profile_id) { - builtin.insert(profile_id, profile); - } else { - custom.insert(profile_id, profile); - } - } - - Self { builtin, custom } - } -} - #[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, JsonSchema)] pub struct AgentProfileId(pub Arc); @@ -63,7 +40,7 @@ impl Default for AgentProfileId { /// A profile for the Zed Agent that controls its behavior. #[derive(Debug, Clone)] -pub struct AgentProfile { +pub struct AgentProfileSettings { /// The name of the profile. pub name: SharedString, pub tools: IndexMap, bool>, diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index 36480f30d5..9e8fd0c699 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -102,7 +102,7 @@ pub struct AgentSettings { pub using_outdated_settings_version: bool, pub default_profile: AgentProfileId, pub default_view: DefaultView, - pub profiles: IndexMap, + pub profiles: IndexMap, pub always_allow_tool_actions: bool, pub notify_when_agent_waiting: NotifyWhenAgentWaiting, pub play_sound_when_agent_done: bool, @@ -531,7 +531,7 @@ impl AgentSettingsContent { pub fn create_profile( &mut self, profile_id: AgentProfileId, - profile: AgentProfile, + profile_settings: AgentProfileSettings, ) -> Result<()> { self.v2_setting(|settings| { let profiles = settings.profiles.get_or_insert_default(); @@ -542,10 +542,10 @@ impl AgentSettingsContent { profiles.insert( profile_id, AgentProfileContent { - name: profile.name.into(), - tools: profile.tools, - enable_all_context_servers: Some(profile.enable_all_context_servers), - context_servers: profile + name: profile_settings.name.into(), + tools: profile_settings.tools, + enable_all_context_servers: Some(profile_settings.enable_all_context_servers), + context_servers: profile_settings .context_servers .into_iter() .map(|(server_id, preset)| { @@ -910,7 +910,7 @@ impl Settings for AgentSettings { .extend(profiles.into_iter().map(|(id, profile)| { ( id, - AgentProfile { + AgentProfileSettings { name: profile.name.into(), tools: profile.tools, enable_all_context_servers: profile diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index c7e20d3517..c72c52ba7a 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use collections::{HashMap, HashSet, IndexMap}; -use gpui::{App, Context, EventEmitter}; +use collections::{HashMap, IndexMap}; +use gpui::App; use crate::{Tool, ToolRegistry, ToolSource}; @@ -13,17 +13,9 @@ pub struct ToolId(usize); pub struct ToolWorkingSet { context_server_tools_by_id: HashMap>, context_server_tools_by_name: HashMap>, - enabled_sources: HashSet, - enabled_tools_by_source: HashMap>>, next_tool_id: ToolId, } -pub enum ToolWorkingSetEvent { - EnabledToolsChanged, -} - -impl EventEmitter for ToolWorkingSet {} - impl ToolWorkingSet { pub fn tool(&self, name: &str, cx: &App) -> Option> { self.context_server_tools_by_name @@ -57,42 +49,6 @@ impl ToolWorkingSet { tools_by_source } - pub fn enabled_tools(&self, cx: &App) -> Vec> { - let all_tools = self.tools(cx); - - all_tools - .into_iter() - .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into())) - .collect() - } - - pub fn disable_all_tools(&mut self, cx: &mut Context) { - self.enabled_tools_by_source.clear(); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context) { - self.enabled_sources.insert(source.clone()); - - let tools_by_source = self.tools_by_source(cx); - if let Some(tools) = tools_by_source.get(&source) { - self.enabled_tools_by_source.insert( - source, - tools - .into_iter() - .map(|tool| tool.name().into()) - .collect::>(), - ); - } - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context) { - self.enabled_sources.remove(source); - self.enabled_tools_by_source.remove(source); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - pub fn insert(&mut self, tool: Arc) -> ToolId { let tool_id = self.next_tool_id; self.next_tool_id.0 += 1; @@ -102,42 +58,6 @@ impl ToolWorkingSet { tool_id } - pub fn is_enabled(&self, source: &ToolSource, name: &Arc) -> bool { - self.enabled_tools_by_source - .get(source) - .map_or(false, |enabled_tools| enabled_tools.contains(name)) - } - - pub fn is_disabled(&self, source: &ToolSource, name: &Arc) -> bool { - !self.is_enabled(source, name) - } - - pub fn enable( - &mut self, - source: ToolSource, - tools_to_enable: &[Arc], - cx: &mut Context, - ) { - self.enabled_tools_by_source - .entry(source) - .or_default() - .extend(tools_to_enable.into_iter().cloned()); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - - pub fn disable( - &mut self, - source: ToolSource, - tools_to_disable: &[Arc], - cx: &mut Context, - ) { - self.enabled_tools_by_source - .entry(source) - .or_default() - .retain(|name| !tools_to_disable.contains(name)); - cx.emit(ToolWorkingSetEvent::EnabledToolsChanged); - } - pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) { self.context_server_tools_by_id .retain(|id, _| !tool_ids_to_remove.contains(id)); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 020aedbc57..a91fdac992 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -80,7 +80,6 @@ zed_llm_client.workspace = true agent_settings.workspace = true assistant_context_editor.workspace = true assistant_slash_command.workspace = true -assistant_tool.workspace = true async-trait.workspace = true audio.workspace = true buffer_diff.workspace = true diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index dc384668c3..85af49e339 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -294,6 +294,7 @@ impl ExampleContext { | ThreadEvent::MessageDeleted(_) | ThreadEvent::SummaryChanged | ThreadEvent::SummaryGenerated + | ThreadEvent::ProfileChanged | ThreadEvent::ReceivedTextChunk | ThreadEvent::StreamedToolUse { .. } | ThreadEvent::CheckpointChanged diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 94fdaf90bf..f28165e859 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -306,17 +306,19 @@ impl ExampleInstance { let thread_store = thread_store.await?; - let profile_id = meta.profile_id.clone(); - thread_store.update(cx, |thread_store, cx| thread_store.load_profile_by_id(profile_id, cx)).expect("Failed to load profile"); let thread = thread_store.update(cx, |thread_store, cx| { - if let Some(json) = &meta.existing_thread_json { + let thread = if let Some(json) = &meta.existing_thread_json { let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread"); thread_store.create_thread_from_serialized(serialized, cx) } else { thread_store.create_thread(cx) - } + }; + thread.update(cx, |thread, cx| { + thread.set_profile(meta.profile_id.clone(), cx); + }); + thread })?;