Load Profile state from Thread and tie visibility to the thread's model (#30090)
When deciding if a model supports tools or not, we weren't reading from the configured model in a given thread. This also stores the profile on the thread, which matches the behavior of the Model and Max Mode, which we also already store per thread. Hopefully this helps alleviate some confusion. Release Notes: - agent: Save profile selection per-Agent thread
This commit is contained in:
parent
02ed4aefb8
commit
3615d6d96c
4 changed files with 67 additions and 23 deletions
|
@ -199,6 +199,10 @@ impl MessageEditor {
|
|||
)
|
||||
});
|
||||
|
||||
let profile_selector = cx.new(|cx| {
|
||||
ProfileSelector::new(thread.clone(), thread_store, editor.focus_handle(cx), cx)
|
||||
});
|
||||
|
||||
Self {
|
||||
editor: editor.clone(),
|
||||
project: thread.read(cx).project().clone(),
|
||||
|
@ -215,8 +219,7 @@ impl MessageEditor {
|
|||
model_selector,
|
||||
edits_expanded: false,
|
||||
editor_is_expanded: false,
|
||||
profile_selector: cx
|
||||
.new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)),
|
||||
profile_selector,
|
||||
last_estimated_token_count: None,
|
||||
update_token_count_task: None,
|
||||
_subscriptions: subscriptions,
|
||||
|
|
|
@ -1,24 +1,21 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use assistant_settings::{
|
||||
AgentProfile, AgentProfileId, AssistantDockPosition, AssistantSettings, GroupedAgentProfiles,
|
||||
builtin_profiles,
|
||||
};
|
||||
use fs::Fs;
|
||||
use gpui::{Action, Entity, FocusHandle, Subscription, WeakEntity, prelude::*};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use settings::{Settings as _, SettingsStore, update_settings_file};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use ui::{
|
||||
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip,
|
||||
prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{ManageProfiles, ThreadStore, ToggleProfileSelector};
|
||||
use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector};
|
||||
|
||||
pub struct ProfileSelector {
|
||||
profiles: GroupedAgentProfiles,
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
focus_handle: FocusHandle,
|
||||
|
@ -27,7 +24,7 @@ pub struct ProfileSelector {
|
|||
|
||||
impl ProfileSelector {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
focus_handle: FocusHandle,
|
||||
cx: &mut Context<Self>,
|
||||
|
@ -38,7 +35,7 @@ impl ProfileSelector {
|
|||
|
||||
Self {
|
||||
profiles: GroupedAgentProfiles::from_settings(AssistantSettings::get_global(cx)),
|
||||
fs,
|
||||
thread,
|
||||
thread_store,
|
||||
menu_handle: PopoverMenuHandle::default(),
|
||||
focus_handle,
|
||||
|
@ -113,15 +110,15 @@ impl ProfileSelector {
|
|||
};
|
||||
|
||||
entry.handler({
|
||||
let fs = self.fs.clone();
|
||||
let thread_store = self.thread_store.clone();
|
||||
let profile_id = profile_id.clone();
|
||||
let profile = profile.clone();
|
||||
|
||||
let thread = self.thread.clone();
|
||||
|
||||
move |_window, cx| {
|
||||
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
|
||||
let profile_id = profile_id.clone();
|
||||
move |settings, _cx| {
|
||||
settings.set_profile(profile_id.clone());
|
||||
}
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_configured_profile(Some(profile.clone()), cx);
|
||||
});
|
||||
|
||||
thread_store
|
||||
|
@ -137,17 +134,28 @@ impl ProfileSelector {
|
|||
impl Render for ProfileSelector {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
let profile_id = &settings.default_profile;
|
||||
let profile = settings.profiles.get(profile_id);
|
||||
let profile = self
|
||||
.thread
|
||||
.read_with(cx, |thread, _cx| thread.configured_profile())
|
||||
.or_else(|| {
|
||||
let profile_id = &settings.default_profile;
|
||||
let profile = settings.profiles.get(profile_id);
|
||||
profile.cloned()
|
||||
});
|
||||
|
||||
let selected_profile = profile
|
||||
.map(|profile| profile.name.clone())
|
||||
.unwrap_or_else(|| "Unknown".into());
|
||||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let supports_tools = model_registry
|
||||
.default_model()
|
||||
.map_or(false, |default| default.model.supports_tools());
|
||||
let configured_model = self
|
||||
.thread
|
||||
.read_with(cx, |thread, _cx| thread.configured_model())
|
||||
.or_else(|| {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
model_registry.default_model()
|
||||
});
|
||||
let supports_tools =
|
||||
configured_model.map_or(false, |default| default.model.supports_tools());
|
||||
|
||||
if supports_tools {
|
||||
let this = cx.entity().clone();
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::sync::Arc;
|
|||
use std::time::Instant;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_settings::{AssistantSettings, CompletionMode};
|
||||
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
|
@ -359,6 +359,7 @@ pub struct Thread {
|
|||
>,
|
||||
remaining_turns: u32,
|
||||
configured_model: Option<ConfiguredModel>,
|
||||
configured_profile: Option<AgentProfile>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
@ -379,6 +380,9 @@ impl Thread {
|
|||
) -> Self {
|
||||
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
|
||||
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
let assistant_settings = AssistantSettings::get_global(cx);
|
||||
let profile_id = &assistant_settings.default_profile;
|
||||
let configured_profile = assistant_settings.profiles.get(profile_id).cloned();
|
||||
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
|
@ -421,6 +425,7 @@ impl Thread {
|
|||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
configured_profile,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -468,6 +473,13 @@ impl Thread {
|
|||
.completion_mode
|
||||
.unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);
|
||||
|
||||
let configured_profile = serialized.profile.and_then(|profile| {
|
||||
AssistantSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&profile)
|
||||
.cloned()
|
||||
});
|
||||
|
||||
Self {
|
||||
id,
|
||||
updated_at: serialized.updated_at,
|
||||
|
@ -541,6 +553,7 @@ impl Thread {
|
|||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
configured_profile,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -596,6 +609,19 @@ impl Thread {
|
|||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn configured_profile(&self) -> Option<AgentProfile> {
|
||||
self.configured_profile.clone()
|
||||
}
|
||||
|
||||
pub fn set_configured_profile(
|
||||
&mut self,
|
||||
profile: Option<AgentProfile>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.configured_profile = profile;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
|
||||
|
||||
pub fn summary_or_default(&self) -> SharedString {
|
||||
|
@ -1100,6 +1126,10 @@ impl Thread {
|
|||
provider: model.provider.id().0.to_string(),
|
||||
model: model.model.id().0.to_string(),
|
||||
}),
|
||||
profile: this
|
||||
.configured_profile
|
||||
.as_ref()
|
||||
.map(|profile| AgentProfileId(profile.name.clone().into())),
|
||||
completion_mode: Some(this.completion_mode),
|
||||
})
|
||||
})
|
||||
|
|
|
@ -657,6 +657,8 @@ pub struct SerializedThread {
|
|||
pub model: Option<SerializedLanguageModel>,
|
||||
#[serde(default)]
|
||||
pub completion_mode: Option<CompletionMode>,
|
||||
#[serde(default)]
|
||||
pub profile: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
|
@ -802,6 +804,7 @@ impl LegacySerializedThread {
|
|||
exceeded_window_error: None,
|
||||
model: None,
|
||||
completion_mode: None,
|
||||
profile: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue