Associate each thread with a model (#29573)

This PR makes it possible to use different LLM models in the agent
panels of two different projects, simultaneously. It also properly
restores a thread's original model when restoring it from the history,
rather than having it use the default model. As before, newly-created
threads will use the current default model.

Release Notes:

- Enabled different project windows to use different models in the agent
panel
- Enhanced the agent panel so that when revisiting old threads, their
original model will be used.

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Max Brunsfeld 2025-04-28 16:43:16 -07:00 committed by GitHub
parent 5102c4c002
commit 17903a0999
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 168 additions and 114 deletions

View file

@ -25,8 +25,8 @@ use gpui::{
};
use language::{Buffer, LanguageRegistry};
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
RequestUsage, Role, StopReason,
LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, RequestUsage, Role,
StopReason,
};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@ -1252,7 +1252,7 @@ impl ActiveThread {
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
state._update_token_count_task.take();
let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
let Some(configured_model) = self.thread.read(cx).configured_model() else {
state.last_estimated_token_count.take();
return;
};
@ -1305,7 +1305,7 @@ impl ActiveThread {
temperature: None,
};
Some(default_model.model.count_tokens(request, cx))
Some(configured_model.model.count_tokens(request, cx))
})? {
task.await?
} else {
@ -1338,7 +1338,7 @@ impl ActiveThread {
return;
};
let edited_text = state.editor.read(cx).text(cx);
self.thread.update(cx, |thread, cx| {
let thread_model = self.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
@ -1348,9 +1348,10 @@ impl ActiveThread {
for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
thread.get_or_init_configured_model(cx)
});
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
let Some(model) = thread_model else {
return;
};

View file

@ -951,6 +951,7 @@ mod tests {
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
language_model::init_settings(cx);
});
let fs = FakeFs::new(cx.executor());

View file

@ -2,6 +2,8 @@ use assistant_settings::AssistantSettings;
use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
use crate::Thread;
use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
};
@ -9,7 +11,11 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
pub use language_model_selector::ModelType;
#[derive(Clone)]
pub enum ModelType {
Default(Entity<Thread>),
InlineAssistant,
}
pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>,
@ -24,18 +30,39 @@ impl AssistantModelSelector {
focus_handle: FocusHandle,
model_type: ModelType,
window: &mut Window,
cx: &mut App,
cx: &mut Context<Self>,
) -> Self {
Self {
selector: cx.new(|cx| {
selector: cx.new(move |cx| {
let fs = fs.clone();
LanguageModelSelector::new(
{
let model_type = model_type.clone();
move |cx| match &model_type {
ModelType::Default(thread) => thread.read(cx).configured_model(),
ModelType::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
},
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
match model_type {
ModelType::Default => {
match &model_type {
ModelType::Default(thread) => {
thread.update(cx, |thread, cx| {
let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id())
{
thread.set_configured_model(
Some(ConfiguredModel {
provider,
model: model.clone(),
}),
cx,
);
}
});
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
@ -58,7 +85,6 @@ impl AssistantModelSelector {
}
}
},
model_type,
window,
cx,
)

View file

@ -1274,12 +1274,12 @@ impl AssistantPanel {
let is_generating = thread.is_generating();
let message_editor = self.message_editor.read(cx);
let conversation_token_usage = thread.total_token_usage(cx);
let conversation_token_usage = thread.total_token_usage();
let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
self.thread.read(cx).editing_message_id()
{
let combined = thread
.token_usage_up_to_message(editing_message_id, cx)
.token_usage_up_to_message(editing_message_id)
.add(unsent_tokens);
(combined, unsent_tokens > 0)

View file

@ -1,4 +1,4 @@
use crate::assistant_model_selector::AssistantModelSelector;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@ -20,7 +20,7 @@ use gpui::{
Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use language_model_selector::{ModelType, ToggleModelSelector};
use language_model_selector::ToggleModelSelector;
use parking_lot::Mutex;
use settings::Settings;
use std::cmp;

View file

@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
@ -21,9 +21,7 @@ use gpui::{
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::{Buffer, Language};
use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent,
};
use language_model::{ConfiguredModel, LanguageModelRequestMessage, MessageContent};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
@ -36,7 +34,6 @@ use util::ResultExt as _;
use workspace::Workspace;
use zed_llm_client::CompletionMode;
use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@ -153,6 +150,17 @@ impl MessageEditor {
}),
];
let model_selector = cx.new(|cx| {
AssistantModelSelector::new(
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelType::Default(thread.clone()),
window,
cx,
)
});
Self {
editor: editor.clone(),
project: thread.read(cx).project().clone(),
@ -165,16 +173,7 @@ impl MessageEditor {
context_picker_menu_handle,
load_context_task: None,
last_loaded_context: None,
model_selector: cx.new(|cx| {
AssistantModelSelector::new(
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelType::Default,
window,
cx,
)
}),
model_selector,
edits_expanded: false,
editor_is_expanded: false,
profile_selector: cx
@ -263,15 +262,11 @@ impl MessageEditor {
self.editor.read(cx).text(cx).trim().is_empty()
}
fn is_model_selected(&self, cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
.default_model()
.is_some()
}
fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
let Some(ConfiguredModel { model, provider }) = self
.thread
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
else {
return;
};
@ -408,14 +403,13 @@ impl MessageEditor {
return None;
}
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model.clone())?;
if !model.supports_max_mode() {
let thread = self.thread.read(cx);
let model = thread.configured_model();
if !model?.model.supports_max_mode() {
return None;
}
let active_completion_mode = self.thread.read(cx).completion_mode();
let active_completion_mode = thread.completion_mode();
Some(
IconButton::new("max-mode", IconName::SquarePlus)
@ -442,24 +436,21 @@ impl MessageEditor {
cx: &mut Context<Self>,
) -> Div {
let thread = self.thread.read(cx);
let model = thread.configured_model();
let editor_bg_color = cx.theme().colors().editor_background;
let is_generating = thread.is_generating();
let focus_handle = self.editor.focus_handle(cx);
let is_model_selected = self.is_model_selected(cx);
let is_model_selected = model.is_some();
let is_editor_empty = self.is_editor_empty(cx);
let model = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.model.clone());
let incompatible_tools = model
.as_ref()
.map(|model| {
self.incompatible_tools_state.update(cx, |state, cx| {
state
.incompatible_tools(model, cx)
.incompatible_tools(&model.model, cx)
.iter()
.cloned()
.collect::<Vec<_>>()
@ -1058,7 +1049,7 @@ impl MessageEditor {
cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take();
let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
let Some(model) = self.thread.read(cx).configured_model() else {
self.last_estimated_token_count.take();
return;
};
@ -1111,7 +1102,7 @@ impl MessageEditor {
temperature: None,
};
Some(default_model.model.count_tokens(request, cx))
Some(model.model.count_tokens(request, cx))
})? {
task.await?
} else {
@ -1143,7 +1134,7 @@ impl Focusable for MessageEditor {
impl Render for MessageEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let thread = self.thread.read(cx);
let total_token_usage = thread.total_token_usage(cx);
let total_token_usage = thread.total_token_usage();
let token_usage_ratio = total_token_usage.ratio();
let action_log = self.thread.read(cx).action_log();

View file

@ -22,8 +22,8 @@ use language_model::{
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
TokenUsage,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::Project;
@ -41,8 +41,8 @@ use zed_llm_client::CompletionMode;
use crate::ThreadStore;
use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse, SharedProjectContext,
SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
SerializedToolResult, SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
@ -332,6 +332,7 @@ pub struct Thread {
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
remaining_turns: u32,
configured_model: Option<ConfiguredModel>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -351,6 +352,8 @@ impl Thread {
cx: &mut Context<Self>,
) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
@ -388,6 +391,7 @@ impl Thread {
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
}
}
@ -411,6 +415,19 @@ impl Thread {
let (detailed_summary_tx, detailed_summary_rx) =
postage::watch::channel_with(serialized.detailed_summary_state);
let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
serialized
.model
.and_then(|model| {
let model = SelectedModel {
provider: model.provider.clone().into(),
model: model.model.clone().into(),
};
registry.select_model(&model, cx)
})
.or_else(|| registry.default_model())
});
Self {
id,
updated_at: serialized.updated_at,
@ -468,6 +485,7 @@ impl Thread {
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
}
}
@ -507,6 +525,22 @@ impl Thread {
self.project_context.clone()
}
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
if self.configured_model.is_none() {
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
}
self.configured_model.clone()
}
pub fn configured_model(&self) -> Option<ConfiguredModel> {
self.configured_model.clone()
}
pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
self.configured_model = model;
cx.notify();
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
@ -952,6 +986,13 @@ impl Thread {
request_token_usage: this.request_token_usage.clone(),
detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
model: this
.configured_model
.as_ref()
.map(|model| SerializedLanguageModel {
provider: model.provider.id().0.to_string(),
model: model.model.id().0.to_string(),
}),
})
})
}
@ -1733,7 +1774,7 @@ impl Thread {
tool_use_id.clone(),
tool_name,
Err(anyhow!("Error parsing input JSON: {error}")),
cx,
self.configured_model.as_ref(),
);
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone()
@ -1808,7 +1849,7 @@ impl Thread {
tool_use_id.clone(),
tool_name,
output,
cx,
thread.configured_model.as_ref(),
);
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
})
@ -1826,10 +1867,9 @@ impl Thread {
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
if !canceled {
self.send_to_model(model, window, cx);
self.send_to_model(model.clone(), window, cx);
}
self.auto_capture_telemetry(cx);
}
@ -2254,8 +2294,8 @@ impl Thread {
self.cumulative_token_usage
}
pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
let Some(model) = self.configured_model.as_ref() else {
return TotalTokenUsage::default();
};
@ -2283,9 +2323,8 @@ impl Thread {
}
}
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.default_model() else {
pub fn total_token_usage(&self) -> TotalTokenUsage {
let Some(model) = self.configured_model.as_ref() else {
return TotalTokenUsage::default();
};
@ -2336,8 +2375,12 @@ impl Thread {
"Permission to run tool action denied by user"
));
self.tool_use
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
self.tool_use.insert_tool_output(
tool_use_id.clone(),
tool_name,
err,
self.configured_model.as_ref(),
);
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
}
}
@ -2769,6 +2812,7 @@ fn main() {{
prompt_store::init(cx);
thread_store::init(cx);
workspace::init_settings(cx);
language_model::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);

View file

@ -640,6 +640,14 @@ pub struct SerializedThread {
pub detailed_summary_state: DetailedSummaryState,
#[serde(default)]
pub exceeded_window_error: Option<ExceededWindowError>,
#[serde(default)]
pub model: Option<SerializedLanguageModel>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SerializedLanguageModel {
pub provider: String,
pub model: String,
}
impl SerializedThread {
@ -774,6 +782,7 @@ impl LegacySerializedThread {
request_token_usage: Vec::new(),
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
model: None,
}
}
}

View file

@ -7,7 +7,7 @@ use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
};
use ui::IconName;
@ -353,7 +353,7 @@ impl ToolUseState {
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<String>,
cx: &App,
configured_model: Option<&ConfiguredModel>,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
@ -373,13 +373,10 @@ impl ToolUseState {
match output {
Ok(tool_result) => {
let model_registry = LanguageModelRegistry::read_global(cx);
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
// Protect from clearly large output
let tool_output_limit = model_registry
.default_model()
let tool_output_limit = configured_model
.map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX);

View file

@ -37,7 +37,7 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::{CodeAction, LspAction, ProjectTransaction};
@ -1759,6 +1759,7 @@ impl PromptEditor {
language_model_selector: cx.new(|cx| {
let fs = fs.clone();
LanguageModelSelector::new(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
@ -1766,7 +1767,6 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)

View file

@ -19,7 +19,7 @@ use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use prompt_store::PromptBuilder;
use settings::{Settings, update_settings_file};
use std::{
@ -749,6 +749,7 @@ impl PromptEditor {
language_model_selector: cx.new(|cx| {
let fs = fs.clone();
LanguageModelSelector::new(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
@ -756,7 +757,6 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)

View file

@ -39,7 +39,7 @@ use language_model::{
Role,
};
use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
};
use multi_buffer::MultiBufferRow;
use picker::Picker;
@ -291,6 +291,7 @@ impl ContextEditor {
dragged_file_worktrees: Vec::new(),
language_model_selector: cx.new(|cx| {
LanguageModelSelector::new(
|cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
@ -298,7 +299,6 @@ impl ContextEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)

View file

@ -39,10 +39,14 @@ pub use crate::telemetry::*;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
pub fn init(client: Arc<Client>, cx: &mut App) {
registry::init(cx);
init_settings(cx);
RefreshLlmTokenListener::register(client.clone(), cx);
}
pub fn init_settings(cx: &mut App) {
registry::init(cx);
}
/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {

View file

@ -188,7 +188,7 @@ impl LanguageModelRegistry {
.collect::<Vec<_>>();
}
fn select_model(
pub fn select_model(
&mut self,
selected_model: &SelectedModel,
cx: &mut Context<Self>,

View file

@ -22,7 +22,8 @@ action_with_deprecated_aliases!(
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
pub struct LanguageModelSelector {
picker: Entity<Picker<LanguageModelPickerDelegate>>,
@ -30,16 +31,10 @@ pub struct LanguageModelSelector {
_subscriptions: Vec<Subscription>,
}
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
impl LanguageModelSelector {
pub fn new(
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
model_type: ModelType,
get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -52,9 +47,9 @@ impl LanguageModelSelector {
language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: Arc::new(all_models),
selected_index: Self::get_active_model_index(&entries, model_type, cx),
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
model_type,
get_active_model: Arc::new(get_active_model),
};
let picker = cx.new(|cx| {
@ -204,26 +199,13 @@ impl LanguageModelSelector {
}
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
let model_type = self.picker.read(cx).delegate.model_type;
Self::active_model_by_type(model_type, cx)
}
fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
match model_type {
ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
ModelType::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
(self.picker.read(cx).delegate.get_active_model)(cx)
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
model_type: ModelType,
cx: &App,
active_model: Option<ConfiguredModel>,
) -> usize {
let active_model = Self::active_model_by_type(model_type, cx);
entries
.iter()
.position(|entry| {
@ -232,7 +214,7 @@ impl LanguageModelSelector {
.as_ref()
.map(|active_model| {
active_model.model.id() == model.model.id()
&& active_model.model.provider_id() == model.model.provider_id()
&& active_model.provider.id() == model.model.provider_id()
})
.unwrap_or_default()
} else {
@ -325,10 +307,10 @@ struct ModelInfo {
pub struct LanguageModelPickerDelegate {
language_model_selector: WeakEntity<LanguageModelSelector>,
on_model_changed: OnModelChanged,
get_active_model: GetActiveModel,
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
model_type: ModelType,
}
struct GroupedModels {
@ -522,8 +504,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.into_any_element(),
),
LanguageModelPickerEntry::Model(model_info) => {
let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
let active_model = (self.get_active_model)(cx);
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let active_model_id = active_model.map(|m| m.model.id());