Associate each thread with a model (#29573)

This PR makes it possible to use different LLM models in the agent
panels of two different projects, simultaneously. It also properly
restores a thread's original model when restoring it from the history,
rather than having it use the default model. As before, newly-created
threads will use the current default model.

Release Notes:

- Enabled different project windows to use different models in the agent
panel
- Enhanced the agent panel so that when revisiting old threads, their
original model will be used.

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Max Brunsfeld 2025-04-28 16:43:16 -07:00 committed by GitHub
parent 5102c4c002
commit 17903a0999
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 168 additions and 114 deletions

View file

@ -22,8 +22,8 @@ use language_model::{
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::Project;
@ -41,8 +41,8 @@ use zed_llm_client::CompletionMode;
use crate::ThreadStore;
use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse, SharedProjectContext,
SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
SerializedToolResult, SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
@ -332,6 +332,7 @@ pub struct Thread {
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
remaining_turns: u32,
configured_model: Option<ConfiguredModel>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -351,6 +352,8 @@ impl Thread {
cx: &mut Context<Self>,
) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
@ -388,6 +391,7 @@ impl Thread {
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
}
}
@ -411,6 +415,19 @@ impl Thread {
let (detailed_summary_tx, detailed_summary_rx) =
postage::watch::channel_with(serialized.detailed_summary_state);
let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
serialized
.model
.and_then(|model| {
let model = SelectedModel {
provider: model.provider.clone().into(),
model: model.model.clone().into(),
};
registry.select_model(&model, cx)
})
.or_else(|| registry.default_model())
});
Self {
id,
updated_at: serialized.updated_at,
@ -468,6 +485,7 @@ impl Thread {
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
}
}
@ -507,6 +525,22 @@ impl Thread {
self.project_context.clone()
}
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
if self.configured_model.is_none() {
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
}
self.configured_model.clone()
}
pub fn configured_model(&self) -> Option<ConfiguredModel> {
self.configured_model.clone()
}
pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
self.configured_model = model;
cx.notify();
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
@ -952,6 +986,13 @@ impl Thread {
request_token_usage: this.request_token_usage.clone(),
detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
model: this
.configured_model
.as_ref()
.map(|model| SerializedLanguageModel {
provider: model.provider.id().0.to_string(),
model: model.model.id().0.to_string(),
}),
})
})
}
@ -1733,7 +1774,7 @@ impl Thread {
tool_use_id.clone(),
tool_name,
Err(anyhow!("Error parsing input JSON: {error}")),
cx,
self.configured_model.as_ref(),
);
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone()
@ -1808,7 +1849,7 @@ impl Thread {
tool_use_id.clone(),
tool_name,
output,
cx,
thread.configured_model.as_ref(),
);
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
})
@ -1826,10 +1867,9 @@ impl Thread {
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
if !canceled {
self.send_to_model(model, window, cx);
self.send_to_model(model.clone(), window, cx);
}
self.auto_capture_telemetry(cx);
}
@ -2254,8 +2294,8 @@ impl Thread {
self.cumulative_token_usage
}
pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
let Some(model) = self.configured_model.as_ref() else {
return TotalTokenUsage::default();
};
@ -2283,9 +2323,8 @@ impl Thread {
}
}
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.default_model() else {
pub fn total_token_usage(&self) -> TotalTokenUsage {
let Some(model) = self.configured_model.as_ref() else {
return TotalTokenUsage::default();
};
@ -2336,8 +2375,12 @@ impl Thread {
"Permission to run tool action denied by user"
));
self.tool_use
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
err,
self.configured_model.as_ref(),
);
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
}
}
@ -2769,6 +2812,7 @@ fn main() {{
prompt_store::init(cx);
thread_store::init(cx);
workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);