inline assistant: Fix model picker (#29136)

Release Notes:

- inline assistant: Fixed a bug where the default model would be used
even when a specific inline assistant model was configured
This commit is contained in:
Agus Zubiaga 2025-04-20 22:12:57 -03:00 committed by GitHub
parent ceeae790b7
commit 4473b45c3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 61 additions and 31 deletions

View file

@ -1,7 +1,7 @@
use assistant_settings::AssistantSettings; use assistant_settings::AssistantSettings;
use fs::Fs; use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString}; use gpui::{Entity, FocusHandle, SharedString};
use language_model::LanguageModelRegistry;
use language_model_selector::{ use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector, LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
}; };
@ -9,17 +9,12 @@ use settings::update_settings_file;
use std::sync::Arc; use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*}; use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
#[derive(Clone, Copy)] pub use language_model_selector::ModelType;
pub enum ModelType {
Default,
InlineAssistant,
}
pub struct AssistantModelSelector { pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>, selector: Entity<LanguageModelSelector>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>, menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle, focus_handle: FocusHandle,
model_type: ModelType,
} }
impl AssistantModelSelector { impl AssistantModelSelector {
@ -63,13 +58,13 @@ impl AssistantModelSelector {
} }
} }
}, },
model_type,
window, window,
cx, cx,
) )
}), }),
menu_handle, menu_handle,
focus_handle, focus_handle,
model_type,
} }
} }
@ -82,11 +77,7 @@ impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle.clone(); let focus_handle = self.focus_handle.clone();
let model_registry = LanguageModelRegistry::read_global(cx); let model = self.selector.read(cx).active_model(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let (model_name, model_icon) = match model { let (model_name, model_icon) = match model {
Some(model) => (model.model.name().0, Some(model.provider.icon())), Some(model) => (model.model.name().0, Some(model.provider.icon())),
_ => (SharedString::from("No model selected"), None), _ => (SharedString::from("No model selected"), None),

View file

@ -1,7 +1,7 @@
use crate::context::attach_context_to_message; use crate::context::attach_context_to_message;
use crate::context_store::ContextStore; use crate::context_store::ContextStore;
use crate::inline_prompt_editor::CodegenStatus; use crate::inline_prompt_editor::CodegenStatus;
use anyhow::{Context as _, Result}; use anyhow::Result;
use client::telemetry::Telemetry; use client::telemetry::Telemetry;
use collections::HashSet; use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
@ -131,7 +131,12 @@ impl BufferCodegen {
cx.notify(); cx.notify();
} }
pub fn start(&mut self, user_prompt: String, cx: &mut Context<Self>) -> Result<()> { pub fn start(
&mut self,
primary_model: Arc<dyn LanguageModel>,
user_prompt: String,
cx: &mut Context<Self>,
) -> Result<()> {
let alternative_models = LanguageModelRegistry::read_global(cx) let alternative_models = LanguageModelRegistry::read_global(cx)
.inline_alternative_models() .inline_alternative_models()
.to_vec(); .to_vec();
@ -155,11 +160,6 @@ impl BufferCodegen {
})); }));
} }
let primary_model = LanguageModelRegistry::read_global(cx)
.default_model()
.context("no active model")?
.model;
for (model, alternative) in iter::once(primary_model) for (model, alternative) in iter::once(primary_model)
.chain(alternative_models) .chain(alternative_models)
.zip(&self.alternatives) .zip(&self.alternatives)

View file

@ -24,6 +24,7 @@ use gpui::{
WeakEntity, Window, point, WeakEntity, Window, point,
}; };
use language::{Buffer, Point, Selection, TransactionId}; use language::{Buffer, Point, Selection, TransactionId};
use language_model::ConfiguredModel;
use language_model::{LanguageModelRegistry, report_assistant_event}; use language_model::{LanguageModelRegistry, report_assistant_event};
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use parking_lot::Mutex; use parking_lot::Mutex;
@ -1221,9 +1222,15 @@ impl InlineAssistant {
self.prompt_history.pop_front(); self.prompt_history.pop_front();
} }
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};
assist assist
.codegen .codegen
.update(cx, |codegen, cx| codegen.start(user_prompt, cx)) .update(cx, |codegen, cx| codegen.start(model, user_prompt, cx))
.log_err(); .log_err();
} }

View file

@ -1,4 +1,4 @@
use crate::assistant_model_selector::{AssistantModelSelector, ModelType}; use crate::assistant_model_selector::AssistantModelSelector;
use crate::buffer_codegen::BufferCodegen; use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker; use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore; use crate::context_store::ContextStore;
@ -20,7 +20,7 @@ use gpui::{
Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point, Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
}; };
use language_model::{LanguageModel, LanguageModelRegistry}; use language_model::{LanguageModel, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector; use language_model_selector::{ModelType, ToggleModelSelector};
use parking_lot::Mutex; use parking_lot::Mutex;
use settings::Settings; use settings::Settings;
use std::cmp; use std::cmp;

View file

@ -37,7 +37,7 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest, ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event, LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
}; };
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use parking_lot::Mutex; use parking_lot::Mutex;
use project::{CodeAction, LspAction, ProjectTransaction}; use project::{CodeAction, LspAction, ProjectTransaction};
@ -1766,6 +1766,7 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()), move |settings, _| settings.set_model(model.clone()),
); );
}, },
ModelType::Default,
window, window,
cx, cx,
) )

View file

@ -19,7 +19,7 @@ use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event, Role, report_assistant_event,
}; };
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu}; use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
use prompt_store::PromptBuilder; use prompt_store::PromptBuilder;
use settings::{Settings, update_settings_file}; use settings::{Settings, update_settings_file};
use std::{ use std::{
@ -755,6 +755,7 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()), move |settings, _| settings.set_model(model.clone()),
); );
}, },
ModelType::Default,
window, window,
cx, cx,
) )

View file

@ -39,7 +39,7 @@ use language_model::{
Role, Role,
}; };
use language_model_selector::{ use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector, LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
}; };
use multi_buffer::MultiBufferRow; use multi_buffer::MultiBufferRow;
use picker::Picker; use picker::Picker;
@ -298,6 +298,7 @@ impl ContextEditor {
move |settings, _| settings.set_model(model.clone()), move |settings, _| settings.set_model(model.clone()),
); );
}, },
ModelType::Default,
window, window,
cx, cx,
) )

View file

@ -7,7 +7,8 @@ use gpui::{
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases, Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
}; };
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
}; };
use picker::{Picker, PickerDelegate}; use picker::{Picker, PickerDelegate};
use proto::Plan; use proto::Plan;
@ -29,9 +30,16 @@ pub struct LanguageModelSelector {
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
impl LanguageModelSelector { impl LanguageModelSelector {
pub fn new( pub fn new(
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static, on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
model_type: ModelType,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
@ -44,8 +52,9 @@ impl LanguageModelSelector {
language_model_selector: cx.entity().downgrade(), language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(), on_model_changed: on_model_changed.clone(),
all_models: Arc::new(all_models), all_models: Arc::new(all_models),
selected_index: Self::get_active_model_index(&entries, cx), selected_index: Self::get_active_model_index(&entries, model_type, cx),
filtered_entries: entries, filtered_entries: entries,
model_type,
}; };
let picker = cx.new(|cx| { let picker = cx.new(|cx| {
@ -194,8 +203,27 @@ impl LanguageModelSelector {
} }
} }
fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize { pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
let active_model = LanguageModelRegistry::read_global(cx).default_model(); let model_type = self.picker.read(cx).delegate.model_type;
Self::active_model_by_type(model_type, cx)
}
fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
match model_type {
ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
ModelType::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
model_type: ModelType,
cx: &App,
) -> usize {
let active_model = Self::active_model_by_type(model_type, cx);
entries entries
.iter() .iter()
.position(|entry| { .position(|entry| {
@ -300,6 +328,7 @@ pub struct LanguageModelPickerDelegate {
all_models: Arc<GroupedModels>, all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>, filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize, selected_index: usize,
model_type: ModelType,
} }
struct GroupedModels { struct GroupedModels {
@ -493,7 +522,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.into_any_element(), .into_any_element(),
), ),
LanguageModelPickerEntry::Model(model_info) => { LanguageModelPickerEntry::Model(model_info) => {
let active_model = LanguageModelRegistry::read_global(cx).default_model(); let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let active_model_id = active_model.map(|m| m.model.id()); let active_model_id = active_model.map(|m| m.model.id());