Associate each thread with a model (#29573)
This PR makes it possible to use different language models in the agent panels of two different projects simultaneously. It also properly restores a thread's original model when the thread is restored from the history, rather than having it use the default model. As before, newly created threads will use the current default model.

Release Notes:

- Enabled different project windows to use different models in the agent panel.
- Enhanced the agent panel so that when revisiting old threads, their original model will be used.

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent
5102c4c002
commit
17903a0999
15 changed files with 168 additions and 114 deletions
|
@ -22,8 +22,8 @@ use language_model::{
|
|||
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
|
||||
TokenUsage,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
|
||||
StopReason, TokenUsage,
|
||||
};
|
||||
use postage::stream::Stream as _;
|
||||
use project::Project;
|
||||
|
@ -41,8 +41,8 @@ use zed_llm_client::CompletionMode;
|
|||
use crate::ThreadStore;
|
||||
use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
|
||||
use crate::thread_store::{
|
||||
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
|
||||
SerializedToolUse, SharedProjectContext,
|
||||
SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
|
||||
SerializedToolResult, SerializedToolUse, SharedProjectContext,
|
||||
};
|
||||
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
|
||||
|
||||
|
@ -332,6 +332,7 @@ pub struct Thread {
|
|||
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
|
||||
>,
|
||||
remaining_turns: u32,
|
||||
configured_model: Option<ConfiguredModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
@ -351,6 +352,8 @@ impl Thread {
|
|||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
|
||||
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
updated_at: Utc::now(),
|
||||
|
@ -388,6 +391,7 @@ impl Thread {
|
|||
last_auto_capture_at: None,
|
||||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -411,6 +415,19 @@ impl Thread {
|
|||
let (detailed_summary_tx, detailed_summary_rx) =
|
||||
postage::watch::channel_with(serialized.detailed_summary_state);
|
||||
|
||||
let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
serialized
|
||||
.model
|
||||
.and_then(|model| {
|
||||
let model = SelectedModel {
|
||||
provider: model.provider.clone().into(),
|
||||
model: model.model.clone().into(),
|
||||
};
|
||||
registry.select_model(&model, cx)
|
||||
})
|
||||
.or_else(|| registry.default_model())
|
||||
});
|
||||
|
||||
Self {
|
||||
id,
|
||||
updated_at: serialized.updated_at,
|
||||
|
@ -468,6 +485,7 @@ impl Thread {
|
|||
last_auto_capture_at: None,
|
||||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -507,6 +525,22 @@ impl Thread {
|
|||
self.project_context.clone()
|
||||
}
|
||||
|
||||
/// Returns this thread's configured model, initializing it first if unset.
///
/// When `self.configured_model` is `None`, it is assigned the result of
/// `LanguageModelRegistry::read_global(cx).default_model()` before returning.
/// The returned value is a clone of the stored option, so it is still `None`
/// when the registry lookup itself yields `None`.
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
|
||||
if self.configured_model.is_none() {
|
||||
// Fall back to the registry's default model only when nothing is stored yet;
// an already-set model is never overwritten here.
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
}
|
||||
self.configured_model.clone()
|
||||
}
|
||||
|
||||
/// Returns a clone of the model currently stored on this thread, or `None`
/// if no model has been set. Unlike `get_or_init_configured_model`, this
/// takes `&self` and never assigns the field.
pub fn configured_model(&self) -> Option<ConfiguredModel> {
|
||||
self.configured_model.clone()
|
||||
}
|
||||
|
||||
/// Replaces this thread's stored model with `model` (passing `None` clears it)
/// and then calls `cx.notify()`.
pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
|
||||
self.configured_model = model;
|
||||
// NOTE(review): presumably signals observers of this thread that its state
// changed — confirm against gpui's `Context::notify`.
cx.notify();
|
||||
}
|
||||
|
||||
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
|
||||
|
||||
pub fn summary_or_default(&self) -> SharedString {
|
||||
|
@ -952,6 +986,13 @@ impl Thread {
|
|||
request_token_usage: this.request_token_usage.clone(),
|
||||
detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
|
||||
exceeded_window_error: this.exceeded_window_error.clone(),
|
||||
model: this
|
||||
.configured_model
|
||||
.as_ref()
|
||||
.map(|model| SerializedLanguageModel {
|
||||
provider: model.provider.id().0.to_string(),
|
||||
model: model.model.id().0.to_string(),
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -1733,7 +1774,7 @@ impl Thread {
|
|||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
Err(anyhow!("Error parsing input JSON: {error}")),
|
||||
cx,
|
||||
self.configured_model.as_ref(),
|
||||
);
|
||||
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
|
||||
pending_tool_use.ui_text.clone()
|
||||
|
@ -1808,7 +1849,7 @@ impl Thread {
|
|||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
output,
|
||||
cx,
|
||||
thread.configured_model.as_ref(),
|
||||
);
|
||||
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
|
||||
})
|
||||
|
@ -1826,10 +1867,9 @@ impl Thread {
|
|||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self.all_tools_finished() {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
|
||||
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
|
||||
if !canceled {
|
||||
self.send_to_model(model, window, cx);
|
||||
self.send_to_model(model.clone(), window, cx);
|
||||
}
|
||||
self.auto_capture_telemetry(cx);
|
||||
}
|
||||
|
@ -2254,8 +2294,8 @@ impl Thread {
|
|||
self.cumulative_token_usage
|
||||
}
|
||||
|
||||
pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
|
||||
let Some(model) = self.configured_model.as_ref() else {
|
||||
return TotalTokenUsage::default();
|
||||
};
|
||||
|
||||
|
@ -2283,9 +2323,8 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let Some(model) = model_registry.default_model() else {
|
||||
pub fn total_token_usage(&self) -> TotalTokenUsage {
|
||||
let Some(model) = self.configured_model.as_ref() else {
|
||||
return TotalTokenUsage::default();
|
||||
};
|
||||
|
||||
|
@ -2336,8 +2375,12 @@ impl Thread {
|
|||
"Permission to run tool action denied by user"
|
||||
));
|
||||
|
||||
self.tool_use
|
||||
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
|
||||
self.tool_use.insert_tool_output(
|
||||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
err,
|
||||
self.configured_model.as_ref(),
|
||||
);
|
||||
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
|
||||
}
|
||||
}
|
||||
|
@ -2769,6 +2812,7 @@ fn main() {{
|
|||
prompt_store::init(cx);
|
||||
thread_store::init(cx);
|
||||
workspace::init_settings(cx);
|
||||
language_model::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
EditorSettings::register(cx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue