Store profile per thread (#31907)

This allows storing the profile per thread, as well as moving the logic
of which tools are enabled or not to the profile itself.

This makes it much easier to switch between profiles, means there is
less global state being changed on every profile change.

Release Notes:

- agent panel: allow saving the profile per thread

---------

Co-authored-by: Ben Kunkle <ben.kunkle@gmail.com>
This commit is contained in:
Ben Brandt 2025-06-06 14:05:27 +02:00 committed by GitHub
parent 7afee64119
commit 709523bf36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 556 additions and 369 deletions

View file

@ -4,7 +4,7 @@ use std::ops::Range;
use std::sync::Arc;
use std::time::Instant;
use agent_settings::{AgentSettings, CompletionMode};
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
@ -41,6 +41,7 @@ use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus};
use crate::ThreadStore;
use crate::agent_profile::AgentProfile;
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
use crate::thread_store::{
SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment,
@ -360,6 +361,7 @@ pub struct Thread {
>,
remaining_turns: u32,
configured_model: Option<ConfiguredModel>,
profile: AgentProfile,
}
#[derive(Clone, Debug, PartialEq, Eq)]
@ -407,6 +409,7 @@ impl Thread {
) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
Self {
id: ThreadId::new(),
@ -449,6 +452,7 @@ impl Thread {
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
profile: AgentProfile::new(profile_id, tools),
}
}
@ -495,6 +499,9 @@ impl Thread {
let completion_mode = serialized
.completion_mode
.unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode);
let profile_id = serialized
.profile
.unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone());
Self {
id,
@ -554,7 +561,7 @@ impl Thread {
pending_checkpoint: None,
project: project.clone(),
prompt_builder,
tools,
tools: tools.clone(),
tool_use,
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
@ -570,6 +577,7 @@ impl Thread {
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
profile: AgentProfile::new(profile_id, tools),
}
}
@ -585,6 +593,17 @@ impl Thread {
&self.id
}
pub fn profile(&self) -> &AgentProfile {
&self.profile
}
pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context<Self>) {
if &id != self.profile.id() {
self.profile = AgentProfile::new(id, self.tools.clone());
cx.emit(ThreadEvent::ProfileChanged);
}
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
@ -919,8 +938,7 @@ impl Thread {
model: Arc<dyn LanguageModel>,
) -> Vec<LanguageModelRequestTool> {
if model.supports_tools() {
self.tools()
.read(cx)
self.profile
.enabled_tools(cx)
.into_iter()
.filter_map(|tool| {
@ -1180,6 +1198,7 @@ impl Thread {
}),
completion_mode: Some(this.completion_mode),
tool_use_limit_reached: this.tool_use_limit_reached,
profile: Some(this.profile.id().clone()),
})
})
}
@ -2121,7 +2140,7 @@ impl Thread {
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) {
let available_tools = self.tools.read(cx).enabled_tools(cx);
let available_tools = self.profile.enabled_tools(cx);
let tool_list = available_tools
.iter()
@ -2213,19 +2232,15 @@ impl Thread {
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
} else {
tool.run(
input,
request,
self.project.clone(),
self.action_log.clone(),
model,
window,
cx,
)
};
let tool_result = tool.run(
input,
request,
self.project.clone(),
self.action_log.clone(),
model,
window,
cx,
);
// Store the card separately if it exists
if let Some(card) = tool_result.card.clone() {
@ -2344,8 +2359,7 @@ impl Thread {
let client = self.project.read(cx).client();
let enabled_tool_names: Vec<String> = self
.tools()
.read(cx)
.profile
.enabled_tools(cx)
.iter()
.map(|tool| tool.name())
@ -2858,6 +2872,7 @@ pub enum ThreadEvent {
ToolUseLimitReached,
CancelEditing,
CompletionCanceled,
ProfileChanged,
}
impl EventEmitter<ThreadEvent> for Thread {}
@ -2872,7 +2887,7 @@ struct PendingCompletion {
mod tests {
use super::*;
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
use agent_settings::{AgentSettings, LanguageModelParameters};
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters};
use assistant_tool::ToolRegistry;
use editor::EditorSettings;
use gpui::TestAppContext;
@ -3285,6 +3300,71 @@ fn main() {{
);
}
#[gpui::test]
async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_workspace, thread_store, thread, _context_store, _model) =
setup_test_environment(cx, project.clone()).await;
// Check that we are starting with the default profile
let profile = cx.read(|cx| thread.read(cx).profile.clone());
let tool_set = cx.read(|cx| thread_store.read(cx).tools());
assert_eq!(
profile,
AgentProfile::new(AgentProfileId::default(), tool_set)
);
}
#[gpui::test]
async fn test_serializing_thread_profile(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(
cx,
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let (_workspace, thread_store, thread, _context_store, _model) =
setup_test_environment(cx, project.clone()).await;
// Profile gets serialized with default values
let serialized = thread
.update(cx, |thread, cx| thread.serialize(cx))
.await
.unwrap();
assert_eq!(serialized.profile, Some(AgentProfileId::default()));
let deserialized = cx.update(|cx| {
thread.update(cx, |thread, cx| {
Thread::deserialize(
thread.id.clone(),
serialized,
thread.project.clone(),
thread.tools.clone(),
thread.prompt_builder.clone(),
thread.project_context.clone(),
None,
cx,
)
})
});
let tool_set = cx.read(|cx| thread_store.read(cx).tools());
assert_eq!(
deserialized.profile,
AgentProfile::new(AgentProfileId::default(), tool_set)
);
}
#[gpui::test]
async fn test_temperature_setting(cx: &mut TestAppContext) {
init_test_settings(cx);