inline assistant: Fix model picker (#29136)

Release Notes:

- inline assistant: Fixed a bug where the default model would be used
even when a specific inline assistant model was configured
This commit is contained in:
Agus Zubiaga 2025-04-20 22:12:57 -03:00 committed by GitHub
parent ceeae790b7
commit 4473b45c3d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 61 additions and 31 deletions

View file

@ -1,7 +1,7 @@
use assistant_settings::AssistantSettings;
use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
use language_model::LanguageModelRegistry;
use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
};
@ -9,17 +9,12 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
pub use language_model_selector::ModelType;
pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_type: ModelType,
}
impl AssistantModelSelector {
@ -63,13 +58,13 @@ impl AssistantModelSelector {
}
}
},
model_type,
window,
cx,
)
}),
menu_handle,
focus_handle,
model_type,
}
}
@ -82,11 +77,7 @@ impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle.clone();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let model = self.selector.read(cx).active_model(cx);
let (model_name, model_icon) = match model {
Some(model) => (model.model.name().0, Some(model.provider.icon())),
_ => (SharedString::from("No model selected"), None),

View file

@ -1,7 +1,7 @@
use crate::context::attach_context_to_message;
use crate::context_store::ContextStore;
use crate::inline_prompt_editor::CodegenStatus;
use anyhow::{Context as _, Result};
use anyhow::Result;
use client::telemetry::Telemetry;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
@ -131,7 +131,12 @@ impl BufferCodegen {
cx.notify();
}
pub fn start(&mut self, user_prompt: String, cx: &mut Context<Self>) -> Result<()> {
pub fn start(
&mut self,
primary_model: Arc<dyn LanguageModel>,
user_prompt: String,
cx: &mut Context<Self>,
) -> Result<()> {
let alternative_models = LanguageModelRegistry::read_global(cx)
.inline_alternative_models()
.to_vec();
@ -155,11 +160,6 @@ impl BufferCodegen {
}));
}
let primary_model = LanguageModelRegistry::read_global(cx)
.default_model()
.context("no active model")?
.model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)
.zip(&self.alternatives)

View file

@ -24,6 +24,7 @@ use gpui::{
WeakEntity, Window, point,
};
use language::{Buffer, Point, Selection, TransactionId};
use language_model::ConfiguredModel;
use language_model::{LanguageModelRegistry, report_assistant_event};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
@ -1221,9 +1222,15 @@ impl InlineAssistant {
self.prompt_history.pop_front();
}
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};
assist
.codegen
.update(cx, |codegen, cx| codegen.start(user_prompt, cx))
.update(cx, |codegen, cx| codegen.start(model, user_prompt, cx))
.log_err();
}

View file

@ -1,4 +1,4 @@
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::assistant_model_selector::AssistantModelSelector;
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@ -20,7 +20,7 @@ use gpui::{
Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use language_model_selector::{ModelType, ToggleModelSelector};
use parking_lot::Mutex;
use settings::Settings;
use std::cmp;

View file

@ -37,7 +37,7 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::{CodeAction, LspAction, ProjectTransaction};
@ -1766,6 +1766,7 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)

View file

@ -19,7 +19,7 @@ use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
use prompt_store::PromptBuilder;
use settings::{Settings, update_settings_file};
use std::{
@ -755,6 +755,7 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)

View file

@ -39,7 +39,7 @@ use language_model::{
Role,
};
use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
};
use multi_buffer::MultiBufferRow;
use picker::Picker;
@ -298,6 +298,7 @@ impl ContextEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)

View file

@ -7,7 +7,8 @@ use gpui::{
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
};
use picker::{Picker, PickerDelegate};
use proto::Plan;
@ -29,9 +30,16 @@ pub struct LanguageModelSelector {
_subscriptions: Vec<Subscription>,
}
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
impl LanguageModelSelector {
pub fn new(
on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
model_type: ModelType,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@ -44,8 +52,9 @@ impl LanguageModelSelector {
language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: Arc::new(all_models),
selected_index: Self::get_active_model_index(&entries, cx),
selected_index: Self::get_active_model_index(&entries, model_type, cx),
filtered_entries: entries,
model_type,
};
let picker = cx.new(|cx| {
@ -194,8 +203,27 @@ impl LanguageModelSelector {
}
}
fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
let active_model = LanguageModelRegistry::read_global(cx).default_model();
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
let model_type = self.picker.read(cx).delegate.model_type;
Self::active_model_by_type(model_type, cx)
}
fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
match model_type {
ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
ModelType::InlineAssistant => {
LanguageModelRegistry::read_global(cx).inline_assistant_model()
}
}
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
model_type: ModelType,
cx: &App,
) -> usize {
let active_model = Self::active_model_by_type(model_type, cx);
entries
.iter()
.position(|entry| {
@ -300,6 +328,7 @@ pub struct LanguageModelPickerDelegate {
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
model_type: ModelType,
}
struct GroupedModels {
@ -493,7 +522,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.into_any_element(),
),
LanguageModelPickerEntry::Model(model_info) => {
let active_model = LanguageModelRegistry::read_global(cx).default_model();
let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let active_model_id = active_model.map(|m| m.model.id());