Store profile per thread (#31907)
This allows storing the profile per thread, as well as moving the logic of which tools are enabled or not to the profile itself. This makes it much easier to switch between profiles, means there is less global state being changed on every profile change. Release Notes: - agent panel: allow saving the profile per thread --------- Co-authored-by: Ben Kunkle <ben.kunkle@gmail.com>
This commit is contained in:
parent
7afee64119
commit
709523bf36
21 changed files with 556 additions and 369 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -59,6 +59,7 @@ dependencies = [
|
|||
"assistant_slash_command",
|
||||
"assistant_slash_commands",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"async-watch",
|
||||
"audio",
|
||||
"buffer_diff",
|
||||
|
@ -147,7 +148,6 @@ dependencies = [
|
|||
"deepseek",
|
||||
"fs",
|
||||
"gpui",
|
||||
"indexmap",
|
||||
"language_model",
|
||||
"lmstudio",
|
||||
"log",
|
||||
|
@ -2987,7 +2987,6 @@ dependencies = [
|
|||
"anyhow",
|
||||
"assistant_context_editor",
|
||||
"assistant_slash_command",
|
||||
"assistant_tool",
|
||||
"async-stripe",
|
||||
"async-trait",
|
||||
"async-tungstenite",
|
||||
|
|
|
@ -771,7 +771,6 @@
|
|||
"tools": {
|
||||
"copy_path": true,
|
||||
"create_directory": true,
|
||||
"create_file": true,
|
||||
"delete_path": true,
|
||||
"diagnostics": true,
|
||||
"edit_file": true,
|
||||
|
|
|
@ -102,6 +102,7 @@ zed_llm_client.workspace = true
|
|||
zstd.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
assistant_tools.workspace = true
|
||||
buffer_diff = { workspace = true, features = ["test-support"] }
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
|
|
|
@ -1144,6 +1144,10 @@ impl ActiveThread {
|
|||
cx,
|
||||
);
|
||||
}
|
||||
ThreadEvent::ProfileChanged => {
|
||||
self.save_thread(cx);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ mod agent_configuration;
|
|||
mod agent_diff;
|
||||
mod agent_model_selector;
|
||||
mod agent_panel;
|
||||
mod agent_profile;
|
||||
mod buffer_codegen;
|
||||
mod context;
|
||||
mod context_picker;
|
||||
|
|
|
@ -2,25 +2,21 @@ mod profile_modal_header;
|
|||
|
||||
use std::sync::Arc;
|
||||
|
||||
use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, builtin_profiles};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, builtin_profiles};
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use convert_case::{Case, Casing as _};
|
||||
use editor::Editor;
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, WeakEntity,
|
||||
prelude::*,
|
||||
};
|
||||
use settings::{Settings as _, update_settings_file};
|
||||
use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, prelude::*};
|
||||
use settings::Settings as _;
|
||||
use ui::{
|
||||
KeyBinding, ListItem, ListItemSpacing, ListSeparator, Navigable, NavigableEntry, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
use crate::agent_configuration::manage_profiles_modal::profile_modal_header::ProfileModalHeader;
|
||||
use crate::agent_configuration::tool_picker::{ToolPicker, ToolPickerDelegate};
|
||||
use crate::{AgentPanel, ManageProfiles, ThreadStore};
|
||||
use crate::agent_profile::AgentProfile;
|
||||
use crate::{AgentPanel, ManageProfiles};
|
||||
|
||||
use super::tool_picker::ToolPickerMode;
|
||||
|
||||
|
@ -103,7 +99,6 @@ pub struct NewProfileMode {
|
|||
pub struct ManageProfilesModal {
|
||||
fs: Arc<dyn Fs>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
focus_handle: FocusHandle,
|
||||
mode: Mode,
|
||||
}
|
||||
|
@ -119,9 +114,8 @@ impl ManageProfilesModal {
|
|||
let fs = workspace.app_state().fs.clone();
|
||||
let thread_store = panel.read(cx).thread_store();
|
||||
let tools = thread_store.read(cx).tools();
|
||||
let thread_store = thread_store.downgrade();
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
let mut this = Self::new(fs, tools, thread_store, window, cx);
|
||||
let mut this = Self::new(fs, tools, window, cx);
|
||||
|
||||
if let Some(profile_id) = action.customize_tools.clone() {
|
||||
this.configure_builtin_tools(profile_id, window, cx);
|
||||
|
@ -136,7 +130,6 @@ impl ManageProfilesModal {
|
|||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
|
@ -145,7 +138,6 @@ impl ManageProfilesModal {
|
|||
Self {
|
||||
fs,
|
||||
tools,
|
||||
thread_store,
|
||||
focus_handle,
|
||||
mode: Mode::choose_profile(window, cx),
|
||||
}
|
||||
|
@ -206,7 +198,6 @@ impl ManageProfilesModal {
|
|||
ToolPickerMode::McpTools,
|
||||
self.fs.clone(),
|
||||
self.tools.clone(),
|
||||
self.thread_store.clone(),
|
||||
profile_id.clone(),
|
||||
profile,
|
||||
cx,
|
||||
|
@ -244,7 +235,6 @@ impl ManageProfilesModal {
|
|||
ToolPickerMode::BuiltinTools,
|
||||
self.fs.clone(),
|
||||
self.tools.clone(),
|
||||
self.thread_store.clone(),
|
||||
profile_id.clone(),
|
||||
profile,
|
||||
cx,
|
||||
|
@ -270,32 +260,10 @@ impl ManageProfilesModal {
|
|||
match &self.mode {
|
||||
Mode::ChooseProfile { .. } => {}
|
||||
Mode::NewProfile(mode) => {
|
||||
let settings = AgentSettings::get_global(cx);
|
||||
|
||||
let base_profile = mode
|
||||
.base_profile_id
|
||||
.as_ref()
|
||||
.and_then(|profile_id| settings.profiles.get(profile_id).cloned());
|
||||
|
||||
let name = mode.name_editor.read(cx).text(cx);
|
||||
let profile_id = AgentProfileId(name.to_case(Case::Kebab).into());
|
||||
|
||||
let profile = AgentProfile {
|
||||
name: name.into(),
|
||||
tools: base_profile
|
||||
.as_ref()
|
||||
.map(|profile| profile.tools.clone())
|
||||
.unwrap_or_default(),
|
||||
enable_all_context_servers: base_profile
|
||||
.as_ref()
|
||||
.map(|profile| profile.enable_all_context_servers)
|
||||
.unwrap_or_default(),
|
||||
context_servers: base_profile
|
||||
.map(|profile| profile.context_servers)
|
||||
.unwrap_or_default(),
|
||||
};
|
||||
|
||||
self.create_profile(profile_id.clone(), profile, cx);
|
||||
let profile_id =
|
||||
AgentProfile::create(name, mode.base_profile_id.clone(), self.fs.clone(), cx);
|
||||
self.view_profile(profile_id, window, cx);
|
||||
}
|
||||
Mode::ViewProfile(_) => {}
|
||||
|
@ -325,19 +293,6 @@ impl ManageProfilesModal {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_profile(
|
||||
&self,
|
||||
profile_id: AgentProfileId,
|
||||
profile: AgentProfile,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
|
||||
move |settings, _cx| {
|
||||
settings.create_profile(profile_id, profile).log_err();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl ModalView for ManageProfilesModal {}
|
||||
|
@ -520,14 +475,13 @@ impl ManageProfilesModal {
|
|||
) -> impl IntoElement {
|
||||
let settings = AgentSettings::get_global(cx);
|
||||
|
||||
let profile_id = &settings.default_profile;
|
||||
let profile_name = settings
|
||||
.profiles
|
||||
.get(&mode.profile_id)
|
||||
.map(|profile| profile.name.clone())
|
||||
.unwrap_or_else(|| "Unknown".into());
|
||||
|
||||
let icon = match profile_id.as_str() {
|
||||
let icon = match mode.profile_id.as_str() {
|
||||
"write" => IconName::Pencil,
|
||||
"ask" => IconName::MessageBubbles,
|
||||
_ => IconName::UserRoundPen,
|
||||
|
|
|
@ -1,19 +1,17 @@
|
|||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
use agent_settings::{
|
||||
AgentProfile, AgentProfileContent, AgentProfileId, AgentSettings, AgentSettingsContent,
|
||||
AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent,
|
||||
ContextServerPresetContent,
|
||||
};
|
||||
use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use fs::Fs;
|
||||
use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use settings::{Settings as _, update_settings_file};
|
||||
use settings::update_settings_file;
|
||||
use ui::{ListItem, ListItemSpacing, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::ThreadStore;
|
||||
|
||||
pub struct ToolPicker {
|
||||
picker: Entity<Picker<ToolPickerDelegate>>,
|
||||
}
|
||||
|
@ -71,11 +69,10 @@ pub enum PickerItem {
|
|||
|
||||
pub struct ToolPickerDelegate {
|
||||
tool_picker: WeakEntity<ToolPicker>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
fs: Arc<dyn Fs>,
|
||||
items: Arc<Vec<PickerItem>>,
|
||||
profile_id: AgentProfileId,
|
||||
profile: AgentProfile,
|
||||
profile_settings: AgentProfileSettings,
|
||||
filtered_items: Vec<PickerItem>,
|
||||
selected_index: usize,
|
||||
mode: ToolPickerMode,
|
||||
|
@ -86,20 +83,18 @@ impl ToolPickerDelegate {
|
|||
mode: ToolPickerMode,
|
||||
fs: Arc<dyn Fs>,
|
||||
tool_set: Entity<ToolWorkingSet>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
profile_id: AgentProfileId,
|
||||
profile: AgentProfile,
|
||||
profile_settings: AgentProfileSettings,
|
||||
cx: &mut Context<ToolPicker>,
|
||||
) -> Self {
|
||||
let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
|
||||
|
||||
Self {
|
||||
tool_picker: cx.entity().downgrade(),
|
||||
thread_store,
|
||||
fs,
|
||||
items,
|
||||
profile_id,
|
||||
profile,
|
||||
profile_settings,
|
||||
filtered_items: Vec::new(),
|
||||
selected_index: 0,
|
||||
mode,
|
||||
|
@ -249,28 +244,31 @@ impl PickerDelegate for ToolPickerDelegate {
|
|||
};
|
||||
|
||||
let is_currently_enabled = if let Some(server_id) = server_id.clone() {
|
||||
let preset = self.profile.context_servers.entry(server_id).or_default();
|
||||
let preset = self
|
||||
.profile_settings
|
||||
.context_servers
|
||||
.entry(server_id)
|
||||
.or_default();
|
||||
let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
|
||||
*preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
|
||||
is_enabled
|
||||
} else {
|
||||
let is_enabled = *self.profile.tools.entry(tool_name.clone()).or_default();
|
||||
*self.profile.tools.entry(tool_name.clone()).or_default() = !is_enabled;
|
||||
let is_enabled = *self
|
||||
.profile_settings
|
||||
.tools
|
||||
.entry(tool_name.clone())
|
||||
.or_default();
|
||||
*self
|
||||
.profile_settings
|
||||
.tools
|
||||
.entry(tool_name.clone())
|
||||
.or_default() = !is_enabled;
|
||||
is_enabled
|
||||
};
|
||||
|
||||
let active_profile_id = &AgentSettings::get_global(cx).default_profile;
|
||||
if active_profile_id == &self.profile_id {
|
||||
self.thread_store
|
||||
.update(cx, |this, cx| {
|
||||
this.load_profile(self.profile.clone(), cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
|
||||
let profile_id = self.profile_id.clone();
|
||||
let default_profile = self.profile.clone();
|
||||
let default_profile = self.profile_settings.clone();
|
||||
let server_id = server_id.clone();
|
||||
let tool_name = tool_name.clone();
|
||||
move |settings: &mut AgentSettingsContent, _cx| {
|
||||
|
@ -348,14 +346,18 @@ impl PickerDelegate for ToolPickerDelegate {
|
|||
),
|
||||
PickerItem::Tool { name, server_id } => {
|
||||
let is_enabled = if let Some(server_id) = server_id {
|
||||
self.profile
|
||||
self.profile_settings
|
||||
.context_servers
|
||||
.get(server_id.as_ref())
|
||||
.and_then(|preset| preset.tools.get(name))
|
||||
.copied()
|
||||
.unwrap_or(self.profile.enable_all_context_servers)
|
||||
.unwrap_or(self.profile_settings.enable_all_context_servers)
|
||||
} else {
|
||||
self.profile.tools.get(name).copied().unwrap_or(false)
|
||||
self.profile_settings
|
||||
.tools
|
||||
.get(name)
|
||||
.copied()
|
||||
.unwrap_or(false)
|
||||
};
|
||||
|
||||
Some(
|
||||
|
|
|
@ -1378,7 +1378,8 @@ impl AgentDiff {
|
|||
| ThreadEvent::CheckpointChanged
|
||||
| ThreadEvent::ToolConfirmationNeeded
|
||||
| ThreadEvent::ToolUseLimitReached
|
||||
| ThreadEvent::CancelEditing => {}
|
||||
| ThreadEvent::CancelEditing
|
||||
| ThreadEvent::ProfileChanged => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
334
crates/agent/src/agent_profile.rs
Normal file
334
crates/agent/src/agent_profile.rs
Normal file
|
@ -0,0 +1,334 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
|
||||
use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
|
||||
use collections::IndexMap;
|
||||
use convert_case::{Case, Casing};
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::SharedString;
|
||||
use util::ResultExt;
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct AgentProfile {
|
||||
id: AgentProfileId,
|
||||
tool_set: Entity<ToolWorkingSet>,
|
||||
}
|
||||
|
||||
pub type AvailableProfiles = IndexMap<AgentProfileId, SharedString>;
|
||||
|
||||
impl AgentProfile {
|
||||
pub fn new(id: AgentProfileId, tool_set: Entity<ToolWorkingSet>) -> Self {
|
||||
Self { id, tool_set }
|
||||
}
|
||||
|
||||
/// Saves a new profile to the settings.
|
||||
pub fn create(
|
||||
name: String,
|
||||
base_profile_id: Option<AgentProfileId>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &App,
|
||||
) -> AgentProfileId {
|
||||
let id = AgentProfileId(name.to_case(Case::Kebab).into());
|
||||
|
||||
let base_profile =
|
||||
base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned());
|
||||
|
||||
let profile_settings = AgentProfileSettings {
|
||||
name: name.into(),
|
||||
tools: base_profile
|
||||
.as_ref()
|
||||
.map(|profile| profile.tools.clone())
|
||||
.unwrap_or_default(),
|
||||
enable_all_context_servers: base_profile
|
||||
.as_ref()
|
||||
.map(|profile| profile.enable_all_context_servers)
|
||||
.unwrap_or_default(),
|
||||
context_servers: base_profile
|
||||
.map(|profile| profile.context_servers)
|
||||
.unwrap_or_default(),
|
||||
};
|
||||
|
||||
update_settings_file::<AgentSettings>(fs, cx, {
|
||||
let id = id.clone();
|
||||
move |settings, _cx| {
|
||||
settings.create_profile(id, profile_settings).log_err();
|
||||
}
|
||||
});
|
||||
|
||||
id
|
||||
}
|
||||
|
||||
/// Returns a map of AgentProfileIds to their names
|
||||
pub fn available_profiles(cx: &App) -> AvailableProfiles {
|
||||
let mut profiles = AvailableProfiles::default();
|
||||
for (id, profile) in AgentSettings::get_global(cx).profiles.iter() {
|
||||
profiles.insert(id.clone(), profile.name.clone());
|
||||
}
|
||||
profiles
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &AgentProfileId {
|
||||
&self.id
|
||||
}
|
||||
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
self.tool_set
|
||||
.read(cx)
|
||||
.tools(cx)
|
||||
.into_iter()
|
||||
.filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
|
||||
match source {
|
||||
ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
|
||||
ToolSource::ContextServer { id } => {
|
||||
if settings.enable_all_context_servers {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(preset) = settings.context_servers.get(id.as_ref()) else {
|
||||
return false;
|
||||
};
|
||||
*preset.tools.get(name.as_str()).unwrap_or(&false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use agent_settings::ContextServerPreset;
|
||||
use assistant_tool::ToolRegistry;
|
||||
use collections::IndexMap;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use http_client::FakeHttpClient;
|
||||
use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use ui::SharedString;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let id = AgentProfileId::default();
|
||||
let profile_settings = cx.read(|cx| {
|
||||
AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&id)
|
||||
.unwrap()
|
||||
.clone()
|
||||
});
|
||||
let tool_set = default_tool_set(cx);
|
||||
|
||||
let profile = AgentProfile::new(id.clone(), tool_set);
|
||||
|
||||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|tool| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
let mut expected_tools = profile_settings
|
||||
.tools
|
||||
.into_iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then_some(tool.to_string()))
|
||||
// Provider dependent
|
||||
.filter(|tool| tool != "web_search")
|
||||
.collect::<Vec<_>>();
|
||||
// Plus all registered MCP tools
|
||||
expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]);
|
||||
expected_tools.sort();
|
||||
|
||||
assert_eq!(enabled_tools, expected_tools);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_custom_mcp_settings(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let id = AgentProfileId("custom_mcp".into());
|
||||
let profile_settings = cx.read(|cx| {
|
||||
AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&id)
|
||||
.unwrap()
|
||||
.clone()
|
||||
});
|
||||
let tool_set = default_tool_set(cx);
|
||||
|
||||
let profile = AgentProfile::new(id.clone(), tool_set);
|
||||
|
||||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|tool| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
let mut expected_tools = profile_settings.context_servers["mcp"]
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(key, enabled)| enabled.then(|| key.to_string()))
|
||||
.collect::<Vec<_>>();
|
||||
expected_tools.sort();
|
||||
|
||||
assert_eq!(enabled_tools, expected_tools);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_only_built_in(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let id = AgentProfileId("write_minus_mcp".into());
|
||||
let profile_settings = cx.read(|cx| {
|
||||
AgentSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&id)
|
||||
.unwrap()
|
||||
.clone()
|
||||
});
|
||||
let tool_set = default_tool_set(cx);
|
||||
|
||||
let profile = AgentProfile::new(id.clone(), tool_set);
|
||||
|
||||
let mut enabled_tools = cx
|
||||
.read(|cx| profile.enabled_tools(cx))
|
||||
.into_iter()
|
||||
.map(|tool| tool.name())
|
||||
.collect::<Vec<_>>();
|
||||
enabled_tools.sort();
|
||||
|
||||
let mut expected_tools = profile_settings
|
||||
.tools
|
||||
.into_iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then_some(tool.to_string()))
|
||||
// Provider dependent
|
||||
.filter(|tool| tool != "web_search")
|
||||
.collect::<Vec<_>>();
|
||||
expected_tools.sort();
|
||||
|
||||
assert_eq!(enabled_tools, expected_tools);
|
||||
}
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
AgentSettings::register(cx);
|
||||
language_model::init_settings(cx);
|
||||
ToolRegistry::default_global(cx);
|
||||
assistant_tools::init(FakeHttpClient::with_404_response(), cx);
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
let mut agent_settings = AgentSettings::get_global(cx).clone();
|
||||
agent_settings.profiles.insert(
|
||||
AgentProfileId("write_minus_mcp".into()),
|
||||
AgentProfileSettings {
|
||||
name: "write_minus_mcp".into(),
|
||||
enable_all_context_servers: false,
|
||||
..agent_settings.profiles[&AgentProfileId::default()].clone()
|
||||
},
|
||||
);
|
||||
agent_settings.profiles.insert(
|
||||
AgentProfileId("custom_mcp".into()),
|
||||
AgentProfileSettings {
|
||||
name: "mcp".into(),
|
||||
tools: IndexMap::default(),
|
||||
enable_all_context_servers: false,
|
||||
context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]),
|
||||
},
|
||||
);
|
||||
AgentSettings::override_global(agent_settings, cx);
|
||||
})
|
||||
}
|
||||
|
||||
fn context_server_preset() -> ContextServerPreset {
|
||||
ContextServerPreset {
|
||||
tools: IndexMap::from_iter([
|
||||
("enabled_mcp_tool".into(), true),
|
||||
("disabled_mcp_tool".into(), false),
|
||||
]),
|
||||
}
|
||||
}
|
||||
|
||||
fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
|
||||
cx.new(|_| {
|
||||
let mut tool_set = ToolWorkingSet::default();
|
||||
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
|
||||
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
|
||||
tool_set
|
||||
})
|
||||
}
|
||||
|
||||
struct FakeTool {
|
||||
name: String,
|
||||
source: SharedString,
|
||||
}
|
||||
|
||||
impl FakeTool {
|
||||
fn new(name: impl Into<String>, source: impl Into<SharedString>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
source: source.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for FakeTool {
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn source(&self) -> ToolSource {
|
||||
ToolSource::ContextServer {
|
||||
id: self.source.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn icon(&self) -> ui::IconName {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
_input: serde_json::Value,
|
||||
_request: Arc<language_model::LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<assistant_tool::ActionLog>,
|
||||
_model: Arc<dyn language_model::LanguageModel>,
|
||||
_window: Option<gpui::AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> assistant_tool::ToolResult {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn may_perform_edits(&self) -> bool {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -175,8 +175,7 @@ impl MessageEditor {
|
|||
)
|
||||
});
|
||||
|
||||
let incompatible_tools =
|
||||
cx.new(|cx| IncompatibleToolsState::new(thread.read(cx).tools().clone(), cx));
|
||||
let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.clone(), cx));
|
||||
|
||||
let subscriptions = vec![
|
||||
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
|
||||
|
@ -204,15 +203,8 @@ impl MessageEditor {
|
|||
)
|
||||
});
|
||||
|
||||
let profile_selector = cx.new(|cx| {
|
||||
ProfileSelector::new(
|
||||
fs,
|
||||
thread.clone(),
|
||||
thread_store,
|
||||
editor.focus_handle(cx),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let profile_selector =
|
||||
cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx));
|
||||
|
||||
Self {
|
||||
editor: editor.clone(),
|
||||
|
|
|
@ -1,26 +1,24 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use agent_settings::{
|
||||
AgentDockPosition, AgentProfile, AgentProfileId, AgentSettings, GroupedAgentProfiles,
|
||||
builtin_profiles,
|
||||
};
|
||||
use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles};
|
||||
use fs::Fs;
|
||||
use gpui::{Action, Empty, Entity, FocusHandle, Subscription, WeakEntity, prelude::*};
|
||||
use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use settings::{Settings as _, SettingsStore, update_settings_file};
|
||||
use ui::{
|
||||
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip,
|
||||
prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector};
|
||||
use crate::{
|
||||
ManageProfiles, Thread, ToggleProfileSelector,
|
||||
agent_profile::{AgentProfile, AvailableProfiles},
|
||||
};
|
||||
|
||||
pub struct ProfileSelector {
|
||||
profiles: GroupedAgentProfiles,
|
||||
profiles: AvailableProfiles,
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
focus_handle: FocusHandle,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
|
@ -30,7 +28,6 @@ impl ProfileSelector {
|
|||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
focus_handle: FocusHandle,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
|
@ -39,10 +36,9 @@ impl ProfileSelector {
|
|||
});
|
||||
|
||||
Self {
|
||||
profiles: GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx)),
|
||||
profiles: AgentProfile::available_profiles(cx),
|
||||
fs,
|
||||
thread,
|
||||
thread_store,
|
||||
menu_handle: PopoverMenuHandle::default(),
|
||||
focus_handle,
|
||||
_subscriptions: vec![settings_subscription],
|
||||
|
@ -54,7 +50,7 @@ impl ProfileSelector {
|
|||
}
|
||||
|
||||
fn refresh_profiles(&mut self, cx: &mut Context<Self>) {
|
||||
self.profiles = GroupedAgentProfiles::from_settings(AgentSettings::get_global(cx));
|
||||
self.profiles = AgentProfile::available_profiles(cx);
|
||||
}
|
||||
|
||||
fn build_context_menu(
|
||||
|
@ -64,21 +60,30 @@ impl ProfileSelector {
|
|||
) -> Entity<ContextMenu> {
|
||||
ContextMenu::build(window, cx, |mut menu, _window, cx| {
|
||||
let settings = AgentSettings::get_global(cx);
|
||||
for (profile_id, profile) in self.profiles.builtin.iter() {
|
||||
|
||||
let mut found_non_builtin = false;
|
||||
for (profile_id, profile_name) in self.profiles.iter() {
|
||||
if !builtin_profiles::is_builtin(profile_id) {
|
||||
found_non_builtin = true;
|
||||
continue;
|
||||
}
|
||||
menu = menu.item(self.menu_entry_for_profile(
|
||||
profile_id.clone(),
|
||||
profile,
|
||||
profile_name,
|
||||
settings,
|
||||
cx,
|
||||
));
|
||||
}
|
||||
|
||||
if !self.profiles.custom.is_empty() {
|
||||
if found_non_builtin {
|
||||
menu = menu.separator().header("Custom Profiles");
|
||||
for (profile_id, profile) in self.profiles.custom.iter() {
|
||||
for (profile_id, profile_name) in self.profiles.iter() {
|
||||
if builtin_profiles::is_builtin(profile_id) {
|
||||
continue;
|
||||
}
|
||||
menu = menu.item(self.menu_entry_for_profile(
|
||||
profile_id.clone(),
|
||||
profile,
|
||||
profile_name,
|
||||
settings,
|
||||
cx,
|
||||
));
|
||||
|
@ -99,19 +104,20 @@ impl ProfileSelector {
|
|||
fn menu_entry_for_profile(
|
||||
&self,
|
||||
profile_id: AgentProfileId,
|
||||
profile: &AgentProfile,
|
||||
profile_name: &SharedString,
|
||||
settings: &AgentSettings,
|
||||
_cx: &App,
|
||||
cx: &App,
|
||||
) -> ContextMenuEntry {
|
||||
let documentation = match profile.name.to_lowercase().as_str() {
|
||||
let documentation = match profile_name.to_lowercase().as_str() {
|
||||
builtin_profiles::WRITE => Some("Get help to write anything."),
|
||||
builtin_profiles::ASK => Some("Chat about your codebase."),
|
||||
builtin_profiles::MINIMAL => Some("Chat about anything with no tools."),
|
||||
_ => None,
|
||||
};
|
||||
let thread_profile_id = self.thread.read(cx).profile().id();
|
||||
|
||||
let entry = ContextMenuEntry::new(profile.name.clone())
|
||||
.toggleable(IconPosition::End, profile_id == settings.default_profile);
|
||||
let entry = ContextMenuEntry::new(profile_name.clone())
|
||||
.toggleable(IconPosition::End, &profile_id == thread_profile_id);
|
||||
|
||||
let entry = if let Some(doc_text) = documentation {
|
||||
entry.documentation_aside(documentation_side(settings.dock), move |_| {
|
||||
|
@ -123,7 +129,7 @@ impl ProfileSelector {
|
|||
|
||||
entry.handler({
|
||||
let fs = self.fs.clone();
|
||||
let thread_store = self.thread_store.clone();
|
||||
let thread = self.thread.clone();
|
||||
let profile_id = profile_id.clone();
|
||||
move |_window, cx| {
|
||||
update_settings_file::<AgentSettings>(fs.clone(), cx, {
|
||||
|
@ -133,11 +139,9 @@ impl ProfileSelector {
|
|||
}
|
||||
});
|
||||
|
||||
thread_store
|
||||
.update(cx, |this, cx| {
|
||||
this.load_profile_by_id(profile_id.clone(), cx);
|
||||
})
|
||||
.log_err();
|
||||
thread.update(cx, |this, cx| {
|
||||
this.set_profile(profile_id.clone(), cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -146,7 +150,7 @@ impl ProfileSelector {
|
|||
impl Render for ProfileSelector {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let settings = AgentSettings::get_global(cx);
|
||||
let profile_id = &settings.default_profile;
|
||||
let profile_id = self.thread.read(cx).profile().id();
|
||||
let profile = settings.profiles.get(profile_id);
|
||||
|
||||
let selected_profile = profile
|
||||
|
|
|
@ -4,7 +4,7 @@ use std::ops::Range;
|
|||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use agent_settings::{AgentSettings, CompletionMode};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
|
@ -41,6 +41,7 @@ use uuid::Uuid;
|
|||
use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||
|
||||
use crate::ThreadStore;
|
||||
use crate::agent_profile::AgentProfile;
|
||||
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
|
||||
use crate::thread_store::{
|
||||
SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
|
||||
|
@ -360,6 +361,7 @@ pub struct Thread {
|
|||
>,
|
||||
remaining_turns: u32,
|
||||
configured_model: Option<ConfiguredModel>,
|
||||
profile: AgentProfile,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
|
@ -407,6 +409,7 @@ impl Thread {
|
|||
) -> Self {
|
||||
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
|
||||
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
|
@ -449,6 +452,7 @@ impl Thread {
|
|||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
profile: AgentProfile::new(profile_id, tools),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -495,6 +499,9 @@ impl Thread {
|
|||
let completion_mode = serialized
|
||||
.completion_mode
|
||||
.unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
|
||||
let profile_id = serialized
|
||||
.profile
|
||||
.unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
|
||||
|
||||
Self {
|
||||
id,
|
||||
|
@ -554,7 +561,7 @@ impl Thread {
|
|||
pending_checkpoint: None,
|
||||
project: project.clone(),
|
||||
prompt_builder,
|
||||
tools,
|
||||
tools: tools.clone(),
|
||||
tool_use,
|
||||
action_log: cx.new(|_| ActionLog::new(project)),
|
||||
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
|
||||
|
@ -570,6 +577,7 @@ impl Thread {
|
|||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
profile: AgentProfile::new(profile_id, tools),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -585,6 +593,17 @@ impl Thread {
|
|||
&self.id
|
||||
}
|
||||
|
||||
pub fn profile(&self) -> &AgentProfile {
|
||||
&self.profile
|
||||
}
|
||||
|
||||
pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
|
||||
if &id != self.profile.id() {
|
||||
self.profile = AgentProfile::new(id, self.tools.clone());
|
||||
cx.emit(ThreadEvent::ProfileChanged);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.messages.is_empty()
|
||||
}
|
||||
|
@ -919,8 +938,7 @@ impl Thread {
|
|||
model: Arc<dyn LanguageModel>,
|
||||
) -> Vec<LanguageModelRequestTool> {
|
||||
if model.supports_tools() {
|
||||
self.tools()
|
||||
.read(cx)
|
||||
self.profile
|
||||
.enabled_tools(cx)
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
|
@ -1180,6 +1198,7 @@ impl Thread {
|
|||
}),
|
||||
completion_mode: Some(this.completion_mode),
|
||||
tool_use_limit_reached: this.tool_use_limit_reached,
|
||||
profile: Some(this.profile.id().clone()),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -2121,7 +2140,7 @@ impl Thread {
|
|||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
) {
|
||||
let available_tools = self.tools.read(cx).enabled_tools(cx);
|
||||
let available_tools = self.profile.enabled_tools(cx);
|
||||
|
||||
let tool_list = available_tools
|
||||
.iter()
|
||||
|
@ -2213,19 +2232,15 @@ impl Thread {
|
|||
) -> Task<()> {
|
||||
let tool_name: Arc<str> = tool.name().into();
|
||||
|
||||
let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
|
||||
Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
|
||||
} else {
|
||||
tool.run(
|
||||
input,
|
||||
request,
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
model,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
};
|
||||
let tool_result = tool.run(
|
||||
input,
|
||||
request,
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
model,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
// Store the card separately if it exists
|
||||
if let Some(card) = tool_result.card.clone() {
|
||||
|
@ -2344,8 +2359,7 @@ impl Thread {
|
|||
let client = self.project.read(cx).client();
|
||||
|
||||
let enabled_tool_names: Vec<String> = self
|
||||
.tools()
|
||||
.read(cx)
|
||||
.profile
|
||||
.enabled_tools(cx)
|
||||
.iter()
|
||||
.map(|tool| tool.name())
|
||||
|
@ -2858,6 +2872,7 @@ pub enum ThreadEvent {
|
|||
ToolUseLimitReached,
|
||||
CancelEditing,
|
||||
CompletionCanceled,
|
||||
ProfileChanged,
|
||||
}
|
||||
|
||||
impl EventEmitter<ThreadEvent> for Thread {}
|
||||
|
@ -2872,7 +2887,7 @@ struct PendingCompletion {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
|
||||
use agent_settings::{AgentSettings, LanguageModelParameters};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
|
||||
use assistant_tool::ToolRegistry;
|
||||
use editor::EditorSettings;
|
||||
use gpui::TestAppContext;
|
||||
|
@ -3285,6 +3300,71 @@ fn main() {{
|
|||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (_workspace, thread_store, thread, _context_store, _model) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Check that we are starting with the default profile
|
||||
let profile = cx.read(|cx| thread.read(cx).profile.clone());
|
||||
let tool_set = cx.read(|cx| thread_store.read(cx).tools());
|
||||
assert_eq!(
|
||||
profile,
|
||||
AgentProfile::new(AgentProfileId::default(), tool_set)
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (_workspace, thread_store, thread, _context_store, _model) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Profile gets serialized with default values
|
||||
let serialized = thread
|
||||
.update(cx, |thread, cx| thread.serialize(cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(serialized.profile, Some(AgentProfileId::default()));
|
||||
|
||||
let deserialized = cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
Thread::deserialize(
|
||||
thread.id.clone(),
|
||||
serialized,
|
||||
thread.project.clone(),
|
||||
thread.tools.clone(),
|
||||
thread.prompt_builder.clone(),
|
||||
thread.project_context.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
let tool_set = cx.read(|cx| thread_store.read(cx).tools());
|
||||
|
||||
assert_eq!(
|
||||
deserialized.profile,
|
||||
AgentProfile::new(AgentProfileId::default(), tool_set)
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_temperature_setting(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
|
|
@ -3,9 +3,9 @@ use std::path::{Path, PathBuf};
|
|||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, CompletionMode};
|
||||
use agent_settings::{AgentProfileId, CompletionMode};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
|
||||
use assistant_tool::{ToolId, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerId;
|
||||
|
@ -25,7 +25,6 @@ use prompt_store::{
|
|||
UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use ui::Window;
|
||||
use util::ResultExt as _;
|
||||
|
||||
|
@ -147,12 +146,7 @@ impl ThreadStore {
|
|||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> (Self, oneshot::Receiver<()>) {
|
||||
let mut subscriptions = vec![
|
||||
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
|
||||
this.load_default_profile(cx);
|
||||
}),
|
||||
cx.subscribe(&project, Self::handle_project_event),
|
||||
];
|
||||
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
|
||||
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
subscriptions.push(cx.subscribe(
|
||||
|
@ -200,7 +194,6 @@ impl ThreadStore {
|
|||
_reload_system_prompt_task: reload_system_prompt_task,
|
||||
_subscriptions: subscriptions,
|
||||
};
|
||||
this.load_default_profile(cx);
|
||||
this.register_context_server_handlers(cx);
|
||||
this.reload(cx).detach_and_log_err(cx);
|
||||
(this, ready_rx)
|
||||
|
@ -520,86 +513,6 @@ impl ThreadStore {
|
|||
})
|
||||
}
|
||||
|
||||
fn load_default_profile(&self, cx: &mut Context<Self>) {
|
||||
let assistant_settings = AgentSettings::get_global(cx);
|
||||
|
||||
self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
|
||||
}
|
||||
|
||||
pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
|
||||
let assistant_settings = AgentSettings::get_global(cx);
|
||||
|
||||
if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
|
||||
self.load_profile(profile.clone(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.disable_all_tools(cx);
|
||||
tools.enable(
|
||||
ToolSource::Native,
|
||||
&profile
|
||||
.tools
|
||||
.into_iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
if profile.enable_all_context_servers {
|
||||
for context_server_id in self
|
||||
.project
|
||||
.read(cx)
|
||||
.context_server_store()
|
||||
.read(cx)
|
||||
.all_server_ids()
|
||||
{
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.enable_source(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.0.into(),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
// Enable all the tools from all context servers, but disable the ones that are explicitly disabled
|
||||
for (context_server_id, preset) in profile.context_servers {
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.disable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.into_iter()
|
||||
.filter_map(|(tool, enabled)| (!enabled).then(|| tool))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
for (context_server_id, preset) in profile.context_servers {
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.enable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.into_iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
|
||||
cx.subscribe(
|
||||
&self.project.read(cx).context_server_store(),
|
||||
|
@ -618,6 +531,7 @@ impl ThreadStore {
|
|||
match event {
|
||||
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
ContextServerStatus::Starting => {}
|
||||
ContextServerStatus::Running => {
|
||||
if let Some(server) =
|
||||
context_server_store.read(cx).get_running_server(server_id)
|
||||
|
@ -656,10 +570,9 @@ impl ThreadStore {
|
|||
.log_err();
|
||||
|
||||
if let Some(tool_ids) = tool_ids {
|
||||
this.update(cx, |this, cx| {
|
||||
this.update(cx, |this, _| {
|
||||
this.context_server_tool_ids
|
||||
.insert(server_id, tool_ids);
|
||||
this.load_default_profile(cx);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
@ -675,10 +588,8 @@ impl ThreadStore {
|
|||
tool_working_set.update(cx, |tool_working_set, _| {
|
||||
tool_working_set.remove(&tool_ids);
|
||||
});
|
||||
self.load_default_profile(cx);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -714,6 +625,8 @@ pub struct SerializedThread {
|
|||
pub completion_mode: Option<CompletionMode>,
|
||||
#[serde(default)]
|
||||
pub tool_use_limit_reached: bool,
|
||||
#[serde(default)]
|
||||
pub profile: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
|
@ -856,6 +769,7 @@ impl LegacySerializedThread {
|
|||
model: None,
|
||||
completion_mode: None,
|
||||
tool_use_limit_reached: false,
|
||||
profile: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,30 +1,33 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use assistant_tool::{Tool, ToolSource, ToolWorkingSet, ToolWorkingSetEvent};
|
||||
use assistant_tool::{Tool, ToolSource};
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
|
||||
use language_model::{LanguageModel, LanguageModelToolSchemaFormat};
|
||||
use ui::prelude::*;
|
||||
|
||||
use crate::{Thread, ThreadEvent};
|
||||
|
||||
pub struct IncompatibleToolsState {
|
||||
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
|
||||
tool_working_set: Entity<ToolWorkingSet>,
|
||||
_tool_working_set_subscription: Subscription,
|
||||
thread: Entity<Thread>,
|
||||
_thread_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl IncompatibleToolsState {
|
||||
pub fn new(tool_working_set: Entity<ToolWorkingSet>, cx: &mut Context<Self>) -> Self {
|
||||
pub fn new(thread: Entity<Thread>, cx: &mut Context<Self>) -> Self {
|
||||
let _tool_working_set_subscription =
|
||||
cx.subscribe(&tool_working_set, |this, _, event, _| match event {
|
||||
ToolWorkingSetEvent::EnabledToolsChanged => {
|
||||
cx.subscribe(&thread, |this, _, event, _| match event {
|
||||
ThreadEvent::ProfileChanged => {
|
||||
this.cache.clear();
|
||||
}
|
||||
_ => {}
|
||||
});
|
||||
|
||||
Self {
|
||||
cache: HashMap::default(),
|
||||
tool_working_set,
|
||||
_tool_working_set_subscription,
|
||||
thread,
|
||||
_thread_subscription: _tool_working_set_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,8 +39,9 @@ impl IncompatibleToolsState {
|
|||
self.cache
|
||||
.entry(model.tool_input_format())
|
||||
.or_insert_with(|| {
|
||||
self.tool_working_set
|
||||
self.thread
|
||||
.read(cx)
|
||||
.profile()
|
||||
.enabled_tools(cx)
|
||||
.iter()
|
||||
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
|
||||
|
|
|
@ -16,7 +16,6 @@ anthropic = { workspace = true, features = ["schemars"] }
|
|||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
indexmap.workspace = true
|
||||
language_model.workspace = true
|
||||
lmstudio = { workspace = true, features = ["schemars"] }
|
||||
log.workspace = true
|
||||
|
|
|
@ -17,29 +17,6 @@ pub mod builtin_profiles {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct GroupedAgentProfiles {
|
||||
pub builtin: IndexMap<AgentProfileId, AgentProfile>,
|
||||
pub custom: IndexMap<AgentProfileId, AgentProfile>,
|
||||
}
|
||||
|
||||
impl GroupedAgentProfiles {
|
||||
pub fn from_settings(settings: &crate::AgentSettings) -> Self {
|
||||
let mut builtin = IndexMap::default();
|
||||
let mut custom = IndexMap::default();
|
||||
|
||||
for (profile_id, profile) in settings.profiles.clone() {
|
||||
if builtin_profiles::is_builtin(&profile_id) {
|
||||
builtin.insert(profile_id, profile);
|
||||
} else {
|
||||
custom.insert(profile_id, profile);
|
||||
}
|
||||
}
|
||||
|
||||
Self { builtin, custom }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AgentProfileId(pub Arc<str>);
|
||||
|
||||
|
@ -63,7 +40,7 @@ impl Default for AgentProfileId {
|
|||
|
||||
/// A profile for the Zed Agent that controls its behavior.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentProfile {
|
||||
pub struct AgentProfileSettings {
|
||||
/// The name of the profile.
|
||||
pub name: SharedString,
|
||||
pub tools: IndexMap<Arc<str>, bool>,
|
||||
|
|
|
@ -102,7 +102,7 @@ pub struct AgentSettings {
|
|||
pub using_outdated_settings_version: bool,
|
||||
pub default_profile: AgentProfileId,
|
||||
pub default_view: DefaultView,
|
||||
pub profiles: IndexMap<AgentProfileId, AgentProfile>,
|
||||
pub profiles: IndexMap<AgentProfileId, AgentProfileSettings>,
|
||||
pub always_allow_tool_actions: bool,
|
||||
pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
|
||||
pub play_sound_when_agent_done: bool,
|
||||
|
@ -531,7 +531,7 @@ impl AgentSettingsContent {
|
|||
pub fn create_profile(
|
||||
&mut self,
|
||||
profile_id: AgentProfileId,
|
||||
profile: AgentProfile,
|
||||
profile_settings: AgentProfileSettings,
|
||||
) -> Result<()> {
|
||||
self.v2_setting(|settings| {
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
|
@ -542,10 +542,10 @@ impl AgentSettingsContent {
|
|||
profiles.insert(
|
||||
profile_id,
|
||||
AgentProfileContent {
|
||||
name: profile.name.into(),
|
||||
tools: profile.tools,
|
||||
enable_all_context_servers: Some(profile.enable_all_context_servers),
|
||||
context_servers: profile
|
||||
name: profile_settings.name.into(),
|
||||
tools: profile_settings.tools,
|
||||
enable_all_context_servers: Some(profile_settings.enable_all_context_servers),
|
||||
context_servers: profile_settings
|
||||
.context_servers
|
||||
.into_iter()
|
||||
.map(|(server_id, preset)| {
|
||||
|
@ -910,7 +910,7 @@ impl Settings for AgentSettings {
|
|||
.extend(profiles.into_iter().map(|(id, profile)| {
|
||||
(
|
||||
id,
|
||||
AgentProfile {
|
||||
AgentProfileSettings {
|
||||
name: profile.name.into(),
|
||||
tools: profile.tools,
|
||||
enable_all_context_servers: profile
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use gpui::{App, Context, EventEmitter};
|
||||
use collections::{HashMap, IndexMap};
|
||||
use gpui::App;
|
||||
|
||||
use crate::{Tool, ToolRegistry, ToolSource};
|
||||
|
||||
|
@ -13,17 +13,9 @@ pub struct ToolId(usize);
|
|||
pub struct ToolWorkingSet {
|
||||
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
|
||||
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
|
||||
enabled_sources: HashSet<ToolSource>,
|
||||
enabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
|
||||
next_tool_id: ToolId,
|
||||
}
|
||||
|
||||
pub enum ToolWorkingSetEvent {
|
||||
EnabledToolsChanged,
|
||||
}
|
||||
|
||||
impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
|
||||
|
||||
impl ToolWorkingSet {
|
||||
pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
|
||||
self.context_server_tools_by_name
|
||||
|
@ -57,42 +49,6 @@ impl ToolWorkingSet {
|
|||
tools_by_source
|
||||
}
|
||||
|
||||
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
|
||||
let all_tools = self.tools(cx);
|
||||
|
||||
all_tools
|
||||
.into_iter()
|
||||
.filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
|
||||
self.enabled_tools_by_source.clear();
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
|
||||
self.enabled_sources.insert(source.clone());
|
||||
|
||||
let tools_by_source = self.tools_by_source(cx);
|
||||
if let Some(tools) = tools_by_source.get(&source) {
|
||||
self.enabled_tools_by_source.insert(
|
||||
source,
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| tool.name().into())
|
||||
.collect::<HashSet<_>>(),
|
||||
);
|
||||
}
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
|
||||
self.enabled_sources.remove(source);
|
||||
self.enabled_tools_by_source.remove(source);
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
|
||||
let tool_id = self.next_tool_id;
|
||||
self.next_tool_id.0 += 1;
|
||||
|
@ -102,42 +58,6 @@ impl ToolWorkingSet {
|
|||
tool_id
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
self.enabled_tools_by_source
|
||||
.get(source)
|
||||
.map_or(false, |enabled_tools| enabled_tools.contains(name))
|
||||
}
|
||||
|
||||
pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
|
||||
!self.is_enabled(source, name)
|
||||
}
|
||||
|
||||
pub fn enable(
|
||||
&mut self,
|
||||
source: ToolSource,
|
||||
tools_to_enable: &[Arc<str>],
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.enabled_tools_by_source
|
||||
.entry(source)
|
||||
.or_default()
|
||||
.extend(tools_to_enable.into_iter().cloned());
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
pub fn disable(
|
||||
&mut self,
|
||||
source: ToolSource,
|
||||
tools_to_disable: &[Arc<str>],
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.enabled_tools_by_source
|
||||
.entry(source)
|
||||
.or_default()
|
||||
.retain(|name| !tools_to_disable.contains(name));
|
||||
cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
|
||||
self.context_server_tools_by_id
|
||||
.retain(|id, _| !tool_ids_to_remove.contains(id));
|
||||
|
|
|
@ -80,7 +80,6 @@ zed_llm_client.workspace = true
|
|||
agent_settings.workspace = true
|
||||
assistant_context_editor.workspace = true
|
||||
assistant_slash_command.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
async-trait.workspace = true
|
||||
audio.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
|
|
|
@ -294,6 +294,7 @@ impl ExampleContext {
|
|||
| ThreadEvent::MessageDeleted(_)
|
||||
| ThreadEvent::SummaryChanged
|
||||
| ThreadEvent::SummaryGenerated
|
||||
| ThreadEvent::ProfileChanged
|
||||
| ThreadEvent::ReceivedTextChunk
|
||||
| ThreadEvent::StreamedToolUse { .. }
|
||||
| ThreadEvent::CheckpointChanged
|
||||
|
|
|
@ -306,17 +306,19 @@ impl ExampleInstance {
|
|||
|
||||
let thread_store = thread_store.await?;
|
||||
|
||||
let profile_id = meta.profile_id.clone();
|
||||
thread_store.update(cx, |thread_store, cx| thread_store.load_profile_by_id(profile_id, cx)).expect("Failed to load profile");
|
||||
|
||||
let thread =
|
||||
thread_store.update(cx, |thread_store, cx| {
|
||||
if let Some(json) = &meta.existing_thread_json {
|
||||
let thread = if let Some(json) = &meta.existing_thread_json {
|
||||
let serialized = SerializedThread::from_json(json.as_bytes()).expect("Can't read serialized thread");
|
||||
thread_store.create_thread_from_serialized(serialized, cx)
|
||||
} else {
|
||||
thread_store.create_thread(cx)
|
||||
}
|
||||
};
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(meta.profile_id.clone(), cx);
|
||||
});
|
||||
thread
|
||||
})?;
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue