Associate each thread with a model (#29573)

This PR makes it possible to use different LLM models in the agent
panels of two different projects, simultaneously. It also properly
restores a thread's original model when restoring it from the history,
rather than having it use the default model. As before, newly-created
threads will use the current default model.

Release Notes:

- Enabled different project windows to use different models in the agent
panel
- Enhanced the agent panel so that revisiting an old thread restores that
thread's original model instead of the default.

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Max Brunsfeld 2025-04-28 16:43:16 -07:00 committed by GitHub
parent 5102c4c002
commit 17903a0999
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 168 additions and 114 deletions

View file

@ -25,8 +25,8 @@ use gpui::{
}; };
use language::{Buffer, LanguageRegistry}; use language::{Buffer, LanguageRegistry};
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, RequestUsage, Role,
RequestUsage, Role, StopReason, StopReason,
}; };
use markdown::parser::{CodeBlockKind, CodeBlockMetadata}; use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown}; use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@ -1252,7 +1252,7 @@ impl ActiveThread {
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged); cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
state._update_token_count_task.take(); state._update_token_count_task.take();
let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else { let Some(configured_model) = self.thread.read(cx).configured_model() else {
state.last_estimated_token_count.take(); state.last_estimated_token_count.take();
return; return;
}; };
@ -1305,7 +1305,7 @@ impl ActiveThread {
temperature: None, temperature: None,
}; };
Some(default_model.model.count_tokens(request, cx)) Some(configured_model.model.count_tokens(request, cx))
})? { })? {
task.await? task.await?
} else { } else {
@ -1338,7 +1338,7 @@ impl ActiveThread {
return; return;
}; };
let edited_text = state.editor.read(cx).text(cx); let edited_text = state.editor.read(cx).text(cx);
self.thread.update(cx, |thread, cx| { let thread_model = self.thread.update(cx, |thread, cx| {
thread.edit_message( thread.edit_message(
message_id, message_id,
Role::User, Role::User,
@ -1348,9 +1348,10 @@ impl ActiveThread {
for message_id in self.messages_after(message_id) { for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx); thread.delete_message(*message_id, cx);
} }
thread.get_or_init_configured_model(cx)
}); });
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { let Some(model) = thread_model else {
return; return;
}; };

View file

@ -951,6 +951,7 @@ mod tests {
ThemeSettings::register(cx); ThemeSettings::register(cx);
ContextServerSettings::register(cx); ContextServerSettings::register(cx);
EditorSettings::register(cx); EditorSettings::register(cx);
language_model::init_settings(cx);
}); });
let fs = FakeFs::new(cx.executor()); let fs = FakeFs::new(cx.executor());

View file

@ -2,6 +2,8 @@ use assistant_settings::AssistantSettings;
use fs::Fs; use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString}; use gpui::{Entity, FocusHandle, SharedString};
use crate::Thread;
use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::{ use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector, LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
}; };
@ -9,7 +11,11 @@ use settings::update_settings_file;
use std::sync::Arc; use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*}; use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
pub use language_model_selector::ModelType; #[derive(Clone)]
pub enum ModelType {
Default(Entity<Thread>),
InlineAssistant,
}
pub struct AssistantModelSelector { pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>, selector: Entity<LanguageModelSelector>,
@ -24,18 +30,39 @@ impl AssistantModelSelector {
focus_handle: FocusHandle, focus_handle: FocusHandle,
model_type: ModelType, model_type: ModelType,
window: &mut Window, window: &mut Window,
cx: &mut App, cx: &mut Context<Self>,
) -> Self { ) -> Self {
Self { Self {
selector: cx.new(|cx| { selector: cx.new(move |cx| {
let fs = fs.clone(); let fs = fs.clone();
LanguageModelSelector::new( LanguageModelSelector::new(
{
let model_type = model_type.clone();
move |cx| match &model_type {
ModelType::Default(thread) => thread.read(cx).configured_model(),
ModelType::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
},
move |model, cx| { move |model, cx| {
let provider = model.provider_id().0.to_string(); let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string(); let model_id = model.id().0.to_string();
match &model_type {
match model_type { ModelType::Default(thread) => {
ModelType::Default => { thread.update(cx, |thread, cx| {
let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id())
{
thread.set_configured_model(
Some(ConfiguredModel {
provider,
model: model.clone(),
}),
cx,
);
}
});
update_settings_file::<AssistantSettings>( update_settings_file::<AssistantSettings>(
fs.clone(), fs.clone(),
cx, cx,
@ -58,7 +85,6 @@ impl AssistantModelSelector {
} }
} }
}, },
model_type,
window, window,
cx, cx,
) )

View file

@ -1274,12 +1274,12 @@ impl AssistantPanel {
let is_generating = thread.is_generating(); let is_generating = thread.is_generating();
let message_editor = self.message_editor.read(cx); let message_editor = self.message_editor.read(cx);
let conversation_token_usage = thread.total_token_usage(cx); let conversation_token_usage = thread.total_token_usage();
let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) = let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
self.thread.read(cx).editing_message_id() self.thread.read(cx).editing_message_id()
{ {
let combined = thread let combined = thread
.token_usage_up_to_message(editing_message_id, cx) .token_usage_up_to_message(editing_message_id)
.add(unsent_tokens); .add(unsent_tokens);
(combined, unsent_tokens > 0) (combined, unsent_tokens > 0)

View file

@ -1,4 +1,4 @@
use crate::assistant_model_selector::AssistantModelSelector; use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen; use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker; use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore; use crate::context_store::ContextStore;
@ -20,7 +20,7 @@ use gpui::{
Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point, Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
}; };
use language_model::{LanguageModel, LanguageModelRegistry}; use language_model::{LanguageModel, LanguageModelRegistry};
use language_model_selector::{ModelType, ToggleModelSelector}; use language_model_selector::ToggleModelSelector;
use parking_lot::Mutex; use parking_lot::Mutex;
use settings::Settings; use settings::Settings;
use std::cmp; use std::cmp;

View file

@ -1,7 +1,7 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc; use std::sync::Arc;
use crate::assistant_model_selector::ModelType; use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::context::{ContextLoadResult, load_context}; use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip}; use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff; use buffer_diff::BufferDiff;
@ -21,9 +21,7 @@ use gpui::{
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between, Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
}; };
use language::{Buffer, Language}; use language::{Buffer, Language};
use language_model::{ use language_model::{ConfiguredModel, LanguageModelRequestMessage, MessageContent};
ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent,
};
use language_model_selector::ToggleModelSelector; use language_model_selector::ToggleModelSelector;
use multi_buffer; use multi_buffer;
use project::Project; use project::Project;
@ -36,7 +34,6 @@ use util::ResultExt as _;
use workspace::Workspace; use workspace::Workspace;
use zed_llm_client::CompletionMode; use zed_llm_client::CompletionMode;
use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider}; use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_store::ContextStore; use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@ -153,6 +150,17 @@ impl MessageEditor {
}), }),
]; ];
let model_selector = cx.new(|cx| {
AssistantModelSelector::new(
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelType::Default(thread.clone()),
window,
cx,
)
});
Self { Self {
editor: editor.clone(), editor: editor.clone(),
project: thread.read(cx).project().clone(), project: thread.read(cx).project().clone(),
@ -165,16 +173,7 @@ impl MessageEditor {
context_picker_menu_handle, context_picker_menu_handle,
load_context_task: None, load_context_task: None,
last_loaded_context: None, last_loaded_context: None,
model_selector: cx.new(|cx| { model_selector,
AssistantModelSelector::new(
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelType::Default,
window,
cx,
)
}),
edits_expanded: false, edits_expanded: false,
editor_is_expanded: false, editor_is_expanded: false,
profile_selector: cx profile_selector: cx
@ -263,15 +262,11 @@ impl MessageEditor {
self.editor.read(cx).text(cx).trim().is_empty() self.editor.read(cx).text(cx).trim().is_empty()
} }
fn is_model_selected(&self, cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
.default_model()
.is_some()
}
fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) { fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let model_registry = LanguageModelRegistry::read_global(cx); let Some(ConfiguredModel { model, provider }) = self
let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else { .thread
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
else {
return; return;
}; };
@ -408,14 +403,13 @@ impl MessageEditor {
return None; return None;
} }
let model = LanguageModelRegistry::read_global(cx) let thread = self.thread.read(cx);
.default_model() let model = thread.configured_model();
.map(|default| default.model.clone())?; if !model?.model.supports_max_mode() {
if !model.supports_max_mode() {
return None; return None;
} }
let active_completion_mode = self.thread.read(cx).completion_mode(); let active_completion_mode = thread.completion_mode();
Some( Some(
IconButton::new("max-mode", IconName::SquarePlus) IconButton::new("max-mode", IconName::SquarePlus)
@ -442,24 +436,21 @@ impl MessageEditor {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Div { ) -> Div {
let thread = self.thread.read(cx); let thread = self.thread.read(cx);
let model = thread.configured_model();
let editor_bg_color = cx.theme().colors().editor_background; let editor_bg_color = cx.theme().colors().editor_background;
let is_generating = thread.is_generating(); let is_generating = thread.is_generating();
let focus_handle = self.editor.focus_handle(cx); let focus_handle = self.editor.focus_handle(cx);
let is_model_selected = self.is_model_selected(cx); let is_model_selected = model.is_some();
let is_editor_empty = self.is_editor_empty(cx); let is_editor_empty = self.is_editor_empty(cx);
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model.clone());
let incompatible_tools = model let incompatible_tools = model
.as_ref() .as_ref()
.map(|model| { .map(|model| {
self.incompatible_tools_state.update(cx, |state, cx| { self.incompatible_tools_state.update(cx, |state, cx| {
state state
.incompatible_tools(model, cx) .incompatible_tools(&model.model, cx)
.iter() .iter()
.cloned() .cloned()
.collect::<Vec<_>>() .collect::<Vec<_>>()
@ -1058,7 +1049,7 @@ impl MessageEditor {
cx.emit(MessageEditorEvent::Changed); cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take(); self.update_token_count_task.take();
let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else { let Some(model) = self.thread.read(cx).configured_model() else {
self.last_estimated_token_count.take(); self.last_estimated_token_count.take();
return; return;
}; };
@ -1111,7 +1102,7 @@ impl MessageEditor {
temperature: None, temperature: None,
}; };
Some(default_model.model.count_tokens(request, cx)) Some(model.model.count_tokens(request, cx))
})? { })? {
task.await? task.await?
} else { } else {
@ -1143,7 +1134,7 @@ impl Focusable for MessageEditor {
impl Render for MessageEditor { impl Render for MessageEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let thread = self.thread.read(cx); let thread = self.thread.read(cx);
let total_token_usage = thread.total_token_usage(cx); let total_token_usage = thread.total_token_usage();
let token_usage_ratio = total_token_usage.ratio(); let token_usage_ratio = total_token_usage.ratio();
let action_log = self.thread.read(cx).action_log(); let action_log = self.thread.read(cx).action_log();

View file

@ -22,8 +22,8 @@ use language_model::{
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest, LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason, ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
TokenUsage, StopReason, TokenUsage,
}; };
use postage::stream::Stream as _; use postage::stream::Stream as _;
use project::Project; use project::Project;
@ -41,8 +41,8 @@ use zed_llm_client::CompletionMode;
use crate::ThreadStore; use crate::ThreadStore;
use crate::context::{AgentContext, ContextLoadResult, LoadedContext}; use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
use crate::thread_store::{ use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
SerializedToolUse, SharedProjectContext, SerializedToolResult, SerializedToolUse, SharedProjectContext,
}; };
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}; use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
@ -332,6 +332,7 @@ pub struct Thread {
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>, Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>, >,
remaining_turns: u32, remaining_turns: u32,
configured_model: Option<ConfiguredModel>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -351,6 +352,8 @@ impl Thread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
Self { Self {
id: ThreadId::new(), id: ThreadId::new(),
updated_at: Utc::now(), updated_at: Utc::now(),
@ -388,6 +391,7 @@ impl Thread {
last_auto_capture_at: None, last_auto_capture_at: None,
request_callback: None, request_callback: None,
remaining_turns: u32::MAX, remaining_turns: u32::MAX,
configured_model,
} }
} }
@ -411,6 +415,19 @@ impl Thread {
let (detailed_summary_tx, detailed_summary_rx) = let (detailed_summary_tx, detailed_summary_rx) =
postage::watch::channel_with(serialized.detailed_summary_state); postage::watch::channel_with(serialized.detailed_summary_state);
let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
serialized
.model
.and_then(|model| {
let model = SelectedModel {
provider: model.provider.clone().into(),
model: model.model.clone().into(),
};
registry.select_model(&model, cx)
})
.or_else(|| registry.default_model())
});
Self { Self {
id, id,
updated_at: serialized.updated_at, updated_at: serialized.updated_at,
@ -468,6 +485,7 @@ impl Thread {
last_auto_capture_at: None, last_auto_capture_at: None,
request_callback: None, request_callback: None,
remaining_turns: u32::MAX, remaining_turns: u32::MAX,
configured_model,
} }
} }
@ -507,6 +525,22 @@ impl Thread {
self.project_context.clone() self.project_context.clone()
} }
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
if self.configured_model.is_none() {
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
}
self.configured_model.clone()
}
pub fn configured_model(&self) -> Option<ConfiguredModel> {
self.configured_model.clone()
}
pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
self.configured_model = model;
cx.notify();
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread"); pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString { pub fn summary_or_default(&self) -> SharedString {
@ -952,6 +986,13 @@ impl Thread {
request_token_usage: this.request_token_usage.clone(), request_token_usage: this.request_token_usage.clone(),
detailed_summary_state: this.detailed_summary_rx.borrow().clone(), detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
exceeded_window_error: this.exceeded_window_error.clone(), exceeded_window_error: this.exceeded_window_error.clone(),
model: this
.configured_model
.as_ref()
.map(|model| SerializedLanguageModel {
provider: model.provider.id().0.to_string(),
model: model.model.id().0.to_string(),
}),
}) })
}) })
} }
@ -1733,7 +1774,7 @@ impl Thread {
tool_use_id.clone(), tool_use_id.clone(),
tool_name, tool_name,
Err(anyhow!("Error parsing input JSON: {error}")), Err(anyhow!("Error parsing input JSON: {error}")),
cx, self.configured_model.as_ref(),
); );
let ui_text = if let Some(pending_tool_use) = &pending_tool_use { let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone() pending_tool_use.ui_text.clone()
@ -1808,7 +1849,7 @@ impl Thread {
tool_use_id.clone(), tool_use_id.clone(),
tool_name, tool_name,
output, output,
cx, thread.configured_model.as_ref(),
); );
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx); thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
}) })
@ -1826,10 +1867,9 @@ impl Thread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if self.all_tools_finished() { if self.all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx); if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
if !canceled { if !canceled {
self.send_to_model(model, window, cx); self.send_to_model(model.clone(), window, cx);
} }
self.auto_capture_telemetry(cx); self.auto_capture_telemetry(cx);
} }
@ -2254,8 +2294,8 @@ impl Thread {
self.cumulative_token_usage self.cumulative_token_usage
} }
pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage { pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else { let Some(model) = self.configured_model.as_ref() else {
return TotalTokenUsage::default(); return TotalTokenUsage::default();
}; };
@ -2283,9 +2323,8 @@ impl Thread {
} }
} }
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage { pub fn total_token_usage(&self) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx); let Some(model) = self.configured_model.as_ref() else {
let Some(model) = model_registry.default_model() else {
return TotalTokenUsage::default(); return TotalTokenUsage::default();
}; };
@ -2336,8 +2375,12 @@ impl Thread {
"Permission to run tool action denied by user" "Permission to run tool action denied by user"
)); ));
self.tool_use self.tool_use.insert_tool_output(
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx); tool_use_id.clone(),
tool_name,
err,
self.configured_model.as_ref(),
);
self.tool_finished(tool_use_id.clone(), None, true, window, cx); self.tool_finished(tool_use_id.clone(), None, true, window, cx);
} }
} }
@ -2769,6 +2812,7 @@ fn main() {{
prompt_store::init(cx); prompt_store::init(cx);
thread_store::init(cx); thread_store::init(cx);
workspace::init_settings(cx); workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx); ThemeSettings::register(cx);
ContextServerSettings::register(cx); ContextServerSettings::register(cx);
EditorSettings::register(cx); EditorSettings::register(cx);

View file

@ -640,6 +640,14 @@ pub struct SerializedThread {
pub detailed_summary_state: DetailedSummaryState, pub detailed_summary_state: DetailedSummaryState,
#[serde(default)] #[serde(default)]
pub exceeded_window_error: Option<ExceededWindowError>, pub exceeded_window_error: Option<ExceededWindowError>,
#[serde(default)]
pub model: Option<SerializedLanguageModel>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
pub provider: String,
pub model: String,
} }
impl SerializedThread { impl SerializedThread {
@ -774,6 +782,7 @@ impl LegacySerializedThread {
request_token_usage: Vec::new(), request_token_usage: Vec::new(),
detailed_summary_state: DetailedSummaryState::default(), detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None, exceeded_window_error: None,
model: None,
} }
} }
} }

View file

@ -7,7 +7,7 @@ use futures::FutureExt as _;
use futures::future::Shared; use futures::future::Shared;
use gpui::{App, Entity, SharedString, Task}; use gpui::{App, Entity, SharedString, Task};
use language_model::{ use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult, ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
}; };
use ui::IconName; use ui::IconName;
@ -353,7 +353,7 @@ impl ToolUseState {
tool_use_id: LanguageModelToolUseId, tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>, tool_name: Arc<str>,
output: Result<String>, output: Result<String>,
cx: &App, configured_model: Option<&ConfiguredModel>,
) -> Option<PendingToolUse> { ) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id); let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
@ -373,13 +373,10 @@ impl ToolUseState {
match output { match output {
Ok(tool_result) => { Ok(tool_result) => {
let model_registry = LanguageModelRegistry::read_global(cx);
const BYTES_PER_TOKEN_ESTIMATE: usize = 3; const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
// Protect from clearly large output // Protect from clearly large output
let tool_output_limit = model_registry let tool_output_limit = configured_model
.default_model()
.map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE) .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX); .unwrap_or(usize::MAX);

View file

@ -37,7 +37,7 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest, ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event, LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
}; };
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType}; use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use parking_lot::Mutex; use parking_lot::Mutex;
use project::{CodeAction, LspAction, ProjectTransaction}; use project::{CodeAction, LspAction, ProjectTransaction};
@ -1759,6 +1759,7 @@ impl PromptEditor {
language_model_selector: cx.new(|cx| { language_model_selector: cx.new(|cx| {
let fs = fs.clone(); let fs = fs.clone();
LanguageModelSelector::new( LanguageModelSelector::new(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| { move |model, cx| {
update_settings_file::<AssistantSettings>( update_settings_file::<AssistantSettings>(
fs.clone(), fs.clone(),
@ -1766,7 +1767,6 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()), move |settings, _| settings.set_model(model.clone()),
); );
}, },
ModelType::Default,
window, window,
cx, cx,
) )

View file

@ -19,7 +19,7 @@ use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event, Role, report_assistant_event,
}; };
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType}; use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use prompt_store::PromptBuilder; use prompt_store::PromptBuilder;
use settings::{Settings, update_settings_file}; use settings::{Settings, update_settings_file};
use std::{ use std::{
@ -749,6 +749,7 @@ impl PromptEditor {
language_model_selector: cx.new(|cx| { language_model_selector: cx.new(|cx| {
let fs = fs.clone(); let fs = fs.clone();
LanguageModelSelector::new( LanguageModelSelector::new(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| { move |model, cx| {
update_settings_file::<AssistantSettings>( update_settings_file::<AssistantSettings>(
fs.clone(), fs.clone(),
@ -756,7 +757,6 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()), move |settings, _| settings.set_model(model.clone()),
); );
}, },
ModelType::Default,
window, window,
cx, cx,
) )

View file

@ -39,7 +39,7 @@ use language_model::{
Role, Role,
}; };
use language_model_selector::{ use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector, LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
}; };
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use picker::Picker; use picker::Picker;
@ -291,6 +291,7 @@ impl ContextEditor {
dragged_file_worktrees: Vec::new(), dragged_file_worktrees: Vec::new(),
language_model_selector: cx.new(|cx| { language_model_selector: cx.new(|cx| {
LanguageModelSelector::new( LanguageModelSelector::new(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| { move |model, cx| {
update_settings_file::<AssistantSettings>( update_settings_file::<AssistantSettings>(
fs.clone(), fs.clone(),
@ -298,7 +299,6 @@ impl ContextEditor {
move |settings, _| settings.set_model(model.clone()), move |settings, _| settings.set_model(model.clone()),
); );
}, },
ModelType::Default,
window, window,
cx, cx,
) )

View file

@ -39,10 +39,14 @@ pub use crate::telemetry::*;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev"; pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
pub fn init(client: Arc<Client>, cx: &mut App) { pub fn init(client: Arc<Client>, cx: &mut App) {
registry::init(cx); init_settings(cx);
RefreshLlmTokenListener::register(client.clone(), cx); RefreshLlmTokenListener::register(client.clone(), cx);
} }
pub fn init_settings(cx: &mut App) {
registry::init(cx);
}
/// The availability of a [`LanguageModel`]. /// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability { pub enum LanguageModelAvailability {

View file

@ -188,7 +188,7 @@ impl LanguageModelRegistry {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
} }
fn select_model( pub fn select_model(
&mut self, &mut self,
selected_model: &SelectedModel, selected_model: &SelectedModel,
cx: &mut Context<Self>, cx: &mut Context<Self>,

View file

@ -22,7 +22,8 @@ action_with_deprecated_aliases!(
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>; type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
pub struct LanguageModelSelector { pub struct LanguageModelSelector {
picker: Entity<Picker<LanguageModelPickerDelegate>>, picker: Entity<Picker<LanguageModelPickerDelegate>>,
@ -30,16 +31,10 @@ pub struct LanguageModelSelector {
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
impl LanguageModelSelector { impl LanguageModelSelector {
pub fn new( pub fn new(
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static, get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
model_type: ModelType, on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -52,9 +47,9 @@ impl LanguageModelSelector {
language_model_selector: cx.entity().downgrade(), language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(), on_model_changed: on_model_changed.clone(),
all_models: Arc::new(all_models), all_models: Arc::new(all_models),
selected_index: Self::get_active_model_index(&entries, model_type, cx), selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries, filtered_entries: entries,
model_type, get_active_model: Arc::new(get_active_model),
}; };
let picker = cx.new(|cx| { let picker = cx.new(|cx| {
@ -204,26 +199,13 @@ impl LanguageModelSelector {
} }
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> { pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
let model_type = self.picker.read(cx).delegate.model_type; (self.picker.read(cx).delegate.get_active_model)(cx)
Self::active_model_by_type(model_type, cx)
}
fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
match model_type {
ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
ModelType::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
} }
fn get_active_model_index( fn get_active_model_index(
entries: &[LanguageModelPickerEntry], entries: &[LanguageModelPickerEntry],
model_type: ModelType, active_model: Option<ConfiguredModel>,
cx: &App,
) -> usize { ) -> usize {
let active_model = Self::active_model_by_type(model_type, cx);
entries entries
.iter() .iter()
.position(|entry| { .position(|entry| {
@ -232,7 +214,7 @@ impl LanguageModelSelector {
.as_ref() .as_ref()
.map(|active_model| { .map(|active_model| {
active_model.model.id() == model.model.id() active_model.model.id() == model.model.id()
&& active_model.model.provider_id() == model.model.provider_id() && active_model.provider.id() == model.model.provider_id()
}) })
.unwrap_or_default() .unwrap_or_default()
} else { } else {
@ -325,10 +307,10 @@ struct ModelInfo {
pub struct LanguageModelPickerDelegate { pub struct LanguageModelPickerDelegate {
language_model_selector: WeakEntity<LanguageModelSelector>, language_model_selector: WeakEntity<LanguageModelSelector>,
on_model_changed: OnModelChanged, on_model_changed: OnModelChanged,
get_active_model: GetActiveModel,
all_models: Arc<GroupedModels>, all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>, filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize, selected_index: usize,
model_type: ModelType,
} }
struct GroupedModels { struct GroupedModels {
@ -522,8 +504,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.into_any_element(), .into_any_element(),
), ),
LanguageModelPickerEntry::Model(model_info) => { LanguageModelPickerEntry::Model(model_info) => {
let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx); let active_model = (self.get_active_model)(cx);
let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let active_model_id = active_model.map(|m| m.model.id()); let active_model_id = active_model.map(|m| m.model.id());