agent: Refine language model selector (#28597)
Release Notes: - agent: Show recommended models in the agent model selector and display the provider in the model selector's trigger. --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
This commit is contained in:
parent
dafe994eef
commit
b22faf96e0
13 changed files with 350 additions and 218 deletions
|
@ -12,6 +12,7 @@ workspace = true
|
|||
path = "src/language_model_selector.rs"
|
||||
|
||||
[dependencies]
|
||||
collections.workspace = true
|
||||
feature_flags.workspace = true
|
||||
gpui.workspace = true
|
||||
language_model.workspace = true
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use collections::{HashSet, IndexMap};
|
||||
use feature_flags::{Assistant2FeatureFlag, ZedPro};
|
||||
use gpui::{
|
||||
Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
|
||||
};
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
|
||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use proto::Plan;
|
||||
|
@ -24,9 +25,6 @@ type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
|
|||
|
||||
pub struct LanguageModelSelector {
|
||||
picker: Entity<Picker<LanguageModelPickerDelegate>>,
|
||||
/// The task used to update the picker's matches when there is a change to
|
||||
/// the language model registry.
|
||||
update_matches_task: Option<Task<()>>,
|
||||
_authenticate_all_providers_task: Task<()>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
@ -40,16 +38,18 @@ impl LanguageModelSelector {
|
|||
let on_model_changed = Arc::new(on_model_changed);
|
||||
|
||||
let all_models = Self::all_models(cx);
|
||||
let entries = all_models.entries();
|
||||
|
||||
let delegate = LanguageModelPickerDelegate {
|
||||
language_model_selector: cx.entity().downgrade(),
|
||||
on_model_changed: on_model_changed.clone(),
|
||||
all_models: all_models.clone(),
|
||||
filtered_models: all_models,
|
||||
selected_index: Self::get_active_model_index(cx),
|
||||
all_models: Arc::new(all_models),
|
||||
selected_index: Self::get_active_model_index(&entries, cx),
|
||||
filtered_entries: entries,
|
||||
};
|
||||
|
||||
let picker = cx.new(|cx| {
|
||||
Picker::uniform_list(delegate, window, cx)
|
||||
Picker::list(delegate, window, cx)
|
||||
.show_scrollbar(true)
|
||||
.width(rems(20.))
|
||||
.max_height(Some(rems(20.).into()))
|
||||
|
@ -59,7 +59,6 @@ impl LanguageModelSelector {
|
|||
|
||||
LanguageModelSelector {
|
||||
picker,
|
||||
update_matches_task: None,
|
||||
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
|
||||
_subscriptions: vec![
|
||||
cx.subscribe_in(
|
||||
|
@ -83,12 +82,13 @@ impl LanguageModelSelector {
|
|||
language_model::Event::ProviderStateChanged
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
let task = self.picker.update(cx, |this, cx| {
|
||||
self.picker.update(cx, |this, cx| {
|
||||
let query = this.query(cx);
|
||||
this.delegate.all_models = Self::all_models(cx);
|
||||
this.delegate.update_matches(query, window, cx)
|
||||
this.delegate.all_models = Arc::new(Self::all_models(cx));
|
||||
// Update matches will automatically drop the previous task
|
||||
// if we get a provider event again
|
||||
this.update_matches(query, window, cx)
|
||||
});
|
||||
self.update_matches_task = Some(task);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -144,34 +144,72 @@ impl LanguageModelSelector {
|
|||
})
|
||||
}
|
||||
|
||||
fn all_models(cx: &App) -> Vec<ModelInfo> {
|
||||
LanguageModelRegistry::global(cx)
|
||||
fn all_models(cx: &App) -> GroupedModels {
|
||||
let mut recommended = Vec::new();
|
||||
let mut recommended_set = HashSet::default();
|
||||
for provider in LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.flat_map(|provider| {
|
||||
let icon = provider.icon();
|
||||
|
||||
provider.provided_models(cx).into_iter().map(move |model| {
|
||||
let model = model.clone();
|
||||
let icon = model.icon().unwrap_or(icon);
|
||||
|
||||
ModelInfo {
|
||||
{
|
||||
let models = provider.recommended_models(cx);
|
||||
recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
|
||||
recommended.extend(
|
||||
provider
|
||||
.recommended_models(cx)
|
||||
.into_iter()
|
||||
.map(move |model| ModelInfo {
|
||||
model: model.clone(),
|
||||
icon,
|
||||
availability: model.availability(),
|
||||
}
|
||||
})
|
||||
icon: provider.icon(),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
let other_models = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.map(|provider| {
|
||||
(
|
||||
provider.id(),
|
||||
provider
|
||||
.provided_models(cx)
|
||||
.into_iter()
|
||||
.filter_map(|model| {
|
||||
let not_included =
|
||||
!recommended_set.contains(&(model.provider_id(), model.id()));
|
||||
not_included.then(|| ModelInfo {
|
||||
model: model.clone(),
|
||||
icon: provider.icon(),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.collect::<IndexMap<_, _>>();
|
||||
|
||||
GroupedModels {
|
||||
recommended,
|
||||
other: other_models,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_active_model_index(cx: &App) -> usize {
|
||||
fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
|
||||
let active_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
Self::all_models(cx)
|
||||
entries
|
||||
.iter()
|
||||
.position(|model_info| {
|
||||
Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
|
||||
.position(|entry| {
|
||||
if let LanguageModelPickerEntry::Model(model) = entry {
|
||||
active_model
|
||||
.as_ref()
|
||||
.map(|active_model| {
|
||||
active_model.model.id() == model.model.id()
|
||||
&& active_model.model.provider_id() == model.model.provider_id()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
@ -254,22 +292,61 @@ where
|
|||
struct ModelInfo {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
icon: IconName,
|
||||
availability: LanguageModelAvailability,
|
||||
}
|
||||
|
||||
pub struct LanguageModelPickerDelegate {
|
||||
language_model_selector: WeakEntity<LanguageModelSelector>,
|
||||
on_model_changed: OnModelChanged,
|
||||
all_models: Vec<ModelInfo>,
|
||||
filtered_models: Vec<ModelInfo>,
|
||||
all_models: Arc<GroupedModels>,
|
||||
filtered_entries: Vec<LanguageModelPickerEntry>,
|
||||
selected_index: usize,
|
||||
}
|
||||
|
||||
struct GroupedModels {
|
||||
recommended: Vec<ModelInfo>,
|
||||
other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
|
||||
}
|
||||
|
||||
impl GroupedModels {
|
||||
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
|
||||
let mut entries = Vec::new();
|
||||
|
||||
if !self.recommended.is_empty() {
|
||||
entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
|
||||
entries.extend(
|
||||
self.recommended
|
||||
.iter()
|
||||
.map(|info| LanguageModelPickerEntry::Model(info.clone())),
|
||||
);
|
||||
}
|
||||
|
||||
for models in self.other.values() {
|
||||
if models.is_empty() {
|
||||
continue;
|
||||
}
|
||||
entries.push(LanguageModelPickerEntry::Separator(
|
||||
models[0].model.provider_name().0,
|
||||
));
|
||||
entries.extend(
|
||||
models
|
||||
.iter()
|
||||
.map(|info| LanguageModelPickerEntry::Model(info.clone())),
|
||||
);
|
||||
}
|
||||
entries
|
||||
}
|
||||
}
|
||||
|
||||
enum LanguageModelPickerEntry {
|
||||
Model(ModelInfo),
|
||||
Separator(SharedString),
|
||||
}
|
||||
|
||||
impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
type ListItem = ListItem;
|
||||
type ListItem = AnyElement;
|
||||
|
||||
fn match_count(&self) -> usize {
|
||||
self.filtered_models.len()
|
||||
self.filtered_entries.len()
|
||||
}
|
||||
|
||||
fn selected_index(&self) -> usize {
|
||||
|
@ -277,12 +354,24 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
}
|
||||
|
||||
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
|
||||
self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn can_select(
|
||||
&mut self,
|
||||
ix: usize,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Picker<Self>>,
|
||||
) -> bool {
|
||||
match self.filtered_entries.get(ix) {
|
||||
Some(LanguageModelPickerEntry::Model(_)) => true,
|
||||
Some(LanguageModelPickerEntry::Separator(_)) | None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
|
||||
"Select a model...".into()
|
||||
"Select a model…".into()
|
||||
}
|
||||
|
||||
fn update_matches(
|
||||
|
@ -307,22 +396,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
cx.spawn_in(window, async move |this, cx| {
|
||||
let filtered_models = cx
|
||||
.background_spawn(async move {
|
||||
let displayed_models = if configured_providers.is_empty() {
|
||||
all_models
|
||||
} else {
|
||||
all_models
|
||||
.into_iter()
|
||||
.filter(|model_info| {
|
||||
configured_providers.contains(&model_info.model.provider_id())
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
if query.is_empty() {
|
||||
displayed_models
|
||||
} else {
|
||||
displayed_models
|
||||
.into_iter()
|
||||
let filter_models = |model_infos: &[ModelInfo]| {
|
||||
model_infos
|
||||
.iter()
|
||||
.filter(|model_info| {
|
||||
model_info
|
||||
.model
|
||||
|
@ -331,20 +407,33 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
.to_lowercase()
|
||||
.contains(&query.to_lowercase())
|
||||
})
|
||||
.collect()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let recommended_models = filter_models(&all_models.recommended);
|
||||
let mut other_models = IndexMap::default();
|
||||
for (provider_id, models) in &all_models.other {
|
||||
if configured_providers.contains(&provider_id) {
|
||||
other_models.insert(provider_id.clone(), filter_models(models));
|
||||
}
|
||||
}
|
||||
GroupedModels {
|
||||
recommended: recommended_models,
|
||||
other: other_models,
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate.filtered_models = filtered_models;
|
||||
this.delegate.filtered_entries = filtered_models.entries();
|
||||
// Preserve selection focus
|
||||
let new_index = if current_index >= this.delegate.filtered_models.len() {
|
||||
let new_index = if current_index >= this.delegate.filtered_entries.len() {
|
||||
0
|
||||
} else {
|
||||
current_index
|
||||
};
|
||||
this.delegate.set_selected_index(new_index, window, cx);
|
||||
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
|
@ -352,7 +441,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
}
|
||||
|
||||
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
if let Some(model_info) = self.filtered_models.get(self.selected_index) {
|
||||
if let Some(LanguageModelPickerEntry::Model(model_info)) =
|
||||
self.filtered_entries.get(self.selected_index)
|
||||
{
|
||||
let model = model_info.model.clone();
|
||||
(self.on_model_changed)(model.clone(), cx);
|
||||
|
||||
|
@ -369,29 +460,6 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
.ok();
|
||||
}
|
||||
|
||||
fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
|
||||
let configured_models_count = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.count();
|
||||
|
||||
if configured_models_count > 0 {
|
||||
Some(
|
||||
Label::new("Configured Models")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.mt_1()
|
||||
.mb_0p5()
|
||||
.ml_2()
|
||||
.into_any_element(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn render_match(
|
||||
&self,
|
||||
ix: usize,
|
||||
|
@ -399,77 +467,68 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
|||
_: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
let show_badges = cx.has_flag::<ZedPro>();
|
||||
match self.filtered_entries.get(ix)? {
|
||||
LanguageModelPickerEntry::Separator(title) => Some(
|
||||
div()
|
||||
.px_2()
|
||||
.pb_1()
|
||||
.when(ix > 1, |this| {
|
||||
this.mt_1()
|
||||
.pt_2()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
})
|
||||
.child(
|
||||
Label::new(title)
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any_element(),
|
||||
),
|
||||
LanguageModelPickerEntry::Model(model_info) => {
|
||||
let active_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
|
||||
let model_info = self.filtered_models.get(ix)?;
|
||||
let provider_name: String = model_info.model.provider_name().0.clone().into();
|
||||
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
|
||||
let active_model_id = active_model.map(|m| m.model.id());
|
||||
|
||||
let active_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
let is_selected = Some(model_info.model.provider_id()) == active_provider_id
|
||||
&& Some(model_info.model.id()) == active_model_id;
|
||||
|
||||
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
|
||||
let active_model_id = active_model.map(|m| m.model.id());
|
||||
let model_icon_color = if is_selected {
|
||||
Color::Accent
|
||||
} else {
|
||||
Color::Muted
|
||||
};
|
||||
|
||||
let is_selected = Some(model_info.model.provider_id()) == active_provider_id
|
||||
&& Some(model_info.model.id()) == active_model_id;
|
||||
|
||||
let model_icon_color = if is_selected {
|
||||
Color::Accent
|
||||
} else {
|
||||
Color::Muted
|
||||
};
|
||||
|
||||
Some(
|
||||
ListItem::new(ix)
|
||||
.inset(true)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.toggle_state(selected)
|
||||
.start_slot(
|
||||
Icon::new(model_info.icon)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.items_center()
|
||||
.gap_1p5()
|
||||
.pl_0p5()
|
||||
.w(px(240.))
|
||||
.child(
|
||||
div()
|
||||
.max_w_40()
|
||||
.child(Label::new(model_info.model.name().0.clone()).truncate()),
|
||||
Some(
|
||||
ListItem::new(ix)
|
||||
.inset(true)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.toggle_state(selected)
|
||||
.start_slot(
|
||||
Icon::new(model_info.icon)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
Label::new(provider_name)
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.children(match model_info.availability {
|
||||
LanguageModelAvailability::Public => None,
|
||||
LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
|
||||
LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
|
||||
show_badges.then(|| {
|
||||
Label::new("Pro")
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted)
|
||||
})
|
||||
}
|
||||
}),
|
||||
),
|
||||
.w_full()
|
||||
.pl_0p5()
|
||||
.gap_1p5()
|
||||
.w(px(240.))
|
||||
.child(Label::new(model_info.model.name().0.clone()).truncate()),
|
||||
)
|
||||
.end_slot(div().pr_3().when(is_selected, |this| {
|
||||
this.child(
|
||||
Icon::new(IconName::Check)
|
||||
.color(Color::Accent)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
}))
|
||||
.into_any_element(),
|
||||
)
|
||||
.end_slot(div().pr_3().when(is_selected, |this| {
|
||||
this.child(
|
||||
Icon::new(IconName::Check)
|
||||
.color(Color::Accent)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
})),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render_footer(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue