diff --git a/crates/agent/src/assistant_model_selector.rs b/crates/agent/src/assistant_model_selector.rs index 091071af29..f63eb588bc 100644 --- a/crates/agent/src/assistant_model_selector.rs +++ b/crates/agent/src/assistant_model_selector.rs @@ -1,7 +1,7 @@ use assistant_settings::AssistantSettings; use fs::Fs; use gpui::{Entity, FocusHandle, SharedString}; -use language_model::LanguageModelRegistry; + use language_model_selector::{ LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector, }; @@ -9,17 +9,12 @@ use settings::update_settings_file; use std::sync::Arc; use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*}; -#[derive(Clone, Copy)] -pub enum ModelType { - Default, - InlineAssistant, -} +pub use language_model_selector::ModelType; pub struct AssistantModelSelector { selector: Entity, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, - model_type: ModelType, } impl AssistantModelSelector { @@ -63,13 +58,13 @@ impl AssistantModelSelector { } } }, + model_type, window, cx, ) }), menu_handle, focus_handle, - model_type, } } @@ -82,11 +77,7 @@ impl Render for AssistantModelSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { let focus_handle = self.focus_handle.clone(); - let model_registry = LanguageModelRegistry::read_global(cx); - let model = match self.model_type { - ModelType::Default => model_registry.default_model(), - ModelType::InlineAssistant => model_registry.inline_assistant_model(), - }; + let model = self.selector.read(cx).active_model(cx); let (model_name, model_icon) = match model { Some(model) => (model.model.name().0, Some(model.provider.icon())), _ => (SharedString::from("No model selected"), None), diff --git a/crates/agent/src/buffer_codegen.rs b/crates/agent/src/buffer_codegen.rs index bd8bf29758..f323c0ccab 100644 --- a/crates/agent/src/buffer_codegen.rs +++ b/crates/agent/src/buffer_codegen.rs @@ -1,7 +1,7 @@ use crate::context::attach_context_to_message; use crate::context_store::ContextStore; use crate::inline_prompt_editor::CodegenStatus; -use anyhow::{Context as _, Result}; +use anyhow::Result; use client::telemetry::Telemetry; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; @@ -131,7 +131,12 @@ impl BufferCodegen { cx.notify(); } - pub fn start(&mut self, user_prompt: String, cx: &mut Context) -> Result<()> { + pub fn start( + &mut self, + primary_model: Arc, + user_prompt: String, + cx: &mut Context, + ) -> Result<()> { let alternative_models = LanguageModelRegistry::read_global(cx) .inline_alternative_models() .to_vec(); @@ -155,11 +160,6 @@ impl BufferCodegen { })); } - let primary_model = LanguageModelRegistry::read_global(cx) - .default_model() - .context("no active model")? - .model; - for (model, alternative) in iter::once(primary_model) .chain(alternative_models) .zip(&self.alternatives) diff --git a/crates/agent/src/inline_assistant.rs b/crates/agent/src/inline_assistant.rs index 451ae92893..cca40e06b1 100644 --- a/crates/agent/src/inline_assistant.rs +++ b/crates/agent/src/inline_assistant.rs @@ -24,6 +24,7 @@ use gpui::{ WeakEntity, Window, point, }; use language::{Buffer, Point, Selection, TransactionId}; +use language_model::ConfiguredModel; use language_model::{LanguageModelRegistry, report_assistant_event}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; @@ -1221,9 +1222,15 @@ impl InlineAssistant { self.prompt_history.pop_front(); } + let Some(ConfiguredModel { model, .. }) = + LanguageModelRegistry::read_global(cx).inline_assistant_model() + else { + return; + }; + assist .codegen - .update(cx, |codegen, cx| codegen.start(user_prompt, cx)) + .update(cx, |codegen, cx| codegen.start(model, user_prompt, cx)) .log_err(); } diff --git a/crates/agent/src/inline_prompt_editor.rs b/crates/agent/src/inline_prompt_editor.rs index 913368d0e4..09d867db8b 100644 --- a/crates/agent/src/inline_prompt_editor.rs +++ b/crates/agent/src/inline_prompt_editor.rs @@ -1,4 +1,4 @@ -use crate::assistant_model_selector::{AssistantModelSelector, ModelType}; +use crate::assistant_model_selector::AssistantModelSelector; use crate::buffer_codegen::BufferCodegen; use crate::context_picker::ContextPicker; use crate::context_store::ContextStore; @@ -20,7 +20,7 @@ use gpui::{ Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point, }; use language_model::{LanguageModel, LanguageModelRegistry}; -use language_model_selector::ToggleModelSelector; +use language_model_selector::{ModelType, ToggleModelSelector}; use parking_lot::Mutex; use settings::Settings; use std::cmp; diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index b7ba199d83..ea395c5dd4 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -37,7 +37,7 @@ use language_model::{ ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event, }; -use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; +use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use project::{CodeAction, LspAction, ProjectTransaction}; @@ -1766,6 +1766,7 @@ impl PromptEditor { move |settings, _| settings.set_model(model.clone()), ); }, + ModelType::Default, window, cx, ) diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index db50193d2e..b7ac8fb43d 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -19,7 +19,7 @@ use language_model::{ ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, report_assistant_event, }; -use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; +use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType}; use prompt_store::PromptBuilder; use settings::{Settings, update_settings_file}; use std::{ @@ -755,6 +755,7 @@ impl PromptEditor { move |settings, _| settings.set_model(model.clone()), ); }, + ModelType::Default, window, cx, ) diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index e26e2816d2..f21fee3f4c 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -39,7 +39,7 @@ use language_model::{ Role, }; use language_model_selector::{ - LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector, + LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector, }; use multi_buffer::MultiBufferRow; use picker::Picker; @@ -298,6 +298,7 @@ impl ContextEditor { move |settings, _| settings.set_model(model.clone()), ); }, + ModelType::Default, window, cx, ) diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs index e8b10e3e6d..f4710e26b9 100644 --- a/crates/language_model_selector/src/language_model_selector.rs +++ b/crates/language_model_selector/src/language_model_selector.rs @@ -7,7 +7,8 @@ use gpui::{ Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases, }; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, + AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId, + LanguageModelRegistry, }; use picker::{Picker, PickerDelegate}; use proto::Plan; @@ -29,9 +30,16 @@ pub struct LanguageModelSelector { _subscriptions: Vec, } +#[derive(Clone, Copy)] +pub enum ModelType { + Default, + InlineAssistant, +} + impl LanguageModelSelector { pub fn new( on_model_changed: impl Fn(Arc, &App) + 'static, + model_type: ModelType, window: &mut Window, cx: &mut Context, ) -> Self { @@ -44,8 +52,9 @@ impl LanguageModelSelector { language_model_selector: cx.entity().downgrade(), on_model_changed: on_model_changed.clone(), all_models: Arc::new(all_models), - selected_index: Self::get_active_model_index(&entries, cx), + selected_index: Self::get_active_model_index(&entries, model_type, cx), filtered_entries: entries, + model_type, }; let picker = cx.new(|cx| { @@ -194,8 +203,27 @@ impl LanguageModelSelector { } } - fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize { - let active_model = LanguageModelRegistry::read_global(cx).default_model(); + pub fn active_model(&self, cx: &App) -> Option { + let model_type = self.picker.read(cx).delegate.model_type; + Self::active_model_by_type(model_type, cx) + } + + fn active_model_by_type(model_type: ModelType, cx: &App) -> Option { + match model_type { + ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(), + ModelType::InlineAssistant => { + LanguageModelRegistry::read_global(cx).inline_assistant_model() + } + } + } + + fn get_active_model_index( + entries: &[LanguageModelPickerEntry], + model_type: ModelType, + cx: &App, + ) -> usize { + let active_model = Self::active_model_by_type(model_type, cx); + entries .iter() .position(|entry| { @@ -300,6 +328,7 @@ pub struct LanguageModelPickerDelegate { all_models: Arc, filtered_entries: Vec, selected_index: usize, + model_type: ModelType, } struct GroupedModels { @@ -493,7 +522,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { .into_any_element(), ), LanguageModelPickerEntry::Model(model_info) => { - let active_model = LanguageModelRegistry::read_global(cx).default_model(); + let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx); let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); let active_model_id = active_model.map(|m| m.model.id());