agent: Refine language model selector (#28597)

Release Notes:

- agent: Show recommended models in the agent model selector and display
the provider in the model selector's trigger.

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>
This commit is contained in:
Bennet Bo Fenner 2025-04-11 17:02:50 -06:00 committed by GitHub
parent dafe994eef
commit b22faf96e0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 350 additions and 218 deletions

1
Cargo.lock generated
View file

@ -7657,6 +7657,7 @@ dependencies = [
name = "language_model_selector" name = "language_model_selector"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"collections",
"feature_flags", "feature_flags",
"gpui", "gpui",
"language_model", "language_model",

View file

@ -1,12 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="16" height="16" rx="2" fill="black" fill-opacity="0.2"/>
<g clip-path="url(#clip0_1916_18)">
<path d="M10.652 3.79999H8.816L12.164 12.2H14L10.652 3.79999Z" fill="#1F1F1E"/>
<path d="M5.348 3.79999L2 12.2H3.872L4.55672 10.436H8.05927L8.744 12.2H10.616L7.268 3.79999H5.348ZM5.16224 8.87599L6.308 5.92399L7.45374 8.87599H5.16224Z" fill="#1F1F1E"/>
</g>
<defs>
<clipPath id="clip0_1916_18">
<rect width="12" height="8.4" fill="white" transform="translate(2 3.79999)"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 601 B

View file

@ -80,17 +80,16 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector { impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let model_registry = LanguageModelRegistry::read_global(cx); let focus_handle = self.focus_handle.clone();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type { let model = match self.model_type {
ModelType::Default => model_registry.default_model(), ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(), ModelType::InlineAssistant => model_registry.inline_assistant_model(),
}; };
let (model_name, model_icon) = match model {
let focus_handle = self.focus_handle.clone(); Some(model) => (model.model.name().0, Some(model.provider.icon())),
let model_name = match model { _ => (SharedString::from("No model selected"), None),
Some(model) => model.model.name().0,
_ => SharedString::from("No model selected"),
}; };
LanguageModelSelectorPopoverMenu::new( LanguageModelSelectorPopoverMenu::new(
@ -100,10 +99,16 @@ impl Render for AssistantModelSelector {
.child( .child(
h_flex() h_flex()
.gap_0p5() .gap_0p5()
.children(
model_icon.map(|icon| {
Icon::new(icon).color(Color::Muted).size(IconSize::Small)
}),
)
.child( .child(
Label::new(model_name) Label::new(model_name)
.size(LabelSize::Small) .size(LabelSize::Small)
.color(Color::Muted), .color(Color::Muted)
.ml_1(),
) )
.child( .child(
Icon::new(IconName::ChevronDown) Icon::new(IconName::ChevronDown)

View file

@ -2133,18 +2133,28 @@ async fn test_repeat_toggle_action(cx: &mut gpui::TestAppContext) {
cx.dispatch_action(ToggleFileFinder::default()); cx.dispatch_action(ToggleFileFinder::default());
let picker = active_file_picker(&workspace, cx); let picker = active_file_picker(&workspace, cx);
picker.update_in(cx, |picker, window, cx| {
picker.update_matches(".txt".to_string(), window, cx)
});
cx.run_until_parked();
picker.update(cx, |picker, _| { picker.update(cx, |picker, _| {
assert_eq!(picker.delegate.matches.len(), 6);
assert_eq!(picker.delegate.selected_index, 0); assert_eq!(picker.delegate.selected_index, 0);
assert_eq!(picker.logical_scroll_top_index(), 0);
}); });
// When toggling repeatedly, the picker scrolls to reveal the selected item. // When toggling repeatedly, the picker scrolls to reveal the selected item.
cx.dispatch_action(ToggleFileFinder::default()); cx.dispatch_action(ToggleFileFinder::default());
cx.dispatch_action(ToggleFileFinder::default()); cx.dispatch_action(ToggleFileFinder::default());
cx.dispatch_action(ToggleFileFinder::default()); cx.dispatch_action(ToggleFileFinder::default());
cx.run_until_parked();
picker.update(cx, |picker, _| { picker.update(cx, |picker, _| {
assert_eq!(picker.delegate.matches.len(), 6);
assert_eq!(picker.delegate.selected_index, 3); assert_eq!(picker.delegate.selected_index, 3);
assert_eq!(picker.logical_scroll_top_index(), 3);
}); });
} }

View file

@ -10,7 +10,6 @@ use strum::{EnumIter, EnumString, IntoStaticStr};
pub enum IconName { pub enum IconName {
Ai, Ai,
AiAnthropic, AiAnthropic,
AiAnthropicHosted,
AiBedrock, AiBedrock,
AiDeepSeek, AiDeepSeek,
AiEdit, AiEdit,

View file

@ -174,10 +174,6 @@ impl Default for LanguageModelTextStream {
pub trait LanguageModel: Send + Sync { pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId; fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName; fn name(&self) -> LanguageModelName;
/// If None, falls back to [LanguageModelProvider::icon]
fn icon(&self) -> Option<IconName> {
None
}
fn provider_id(&self) -> LanguageModelProviderId; fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName; fn provider_name(&self) -> LanguageModelProviderName;
fn telemetry_id(&self) -> String; fn telemetry_id(&self) -> String;
@ -304,6 +300,9 @@ pub trait LanguageModelProvider: 'static {
} }
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>; fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>; fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
Vec::new()
}
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {} fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
fn is_authenticated(&self, cx: &App) -> bool; fn is_authenticated(&self, cx: &App) -> bool;
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>; fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;

View file

@ -6,7 +6,6 @@ use client::Client;
use gpui::{ use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
}; };
use icons::IconName;
use proto::{Plan, TypedEnvelope}; use proto::{Plan, TypedEnvelope};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -53,13 +52,6 @@ impl CloudModel {
} }
} }
pub fn icon(&self) -> Option<IconName> {
match self {
Self::Anthropic(_) => Some(IconName::AiAnthropicHosted),
_ => None,
}
}
pub fn max_token_count(&self) -> usize { pub fn max_token_count(&self) -> usize {
match self { match self {
Self::Anthropic(model) => model.max_token_count(), Self::Anthropic(model) => model.max_token_count(),

View file

@ -12,6 +12,7 @@ workspace = true
path = "src/language_model_selector.rs" path = "src/language_model_selector.rs"
[dependencies] [dependencies]
collections.workspace = true
feature_flags.workspace = true feature_flags.workspace = true
gpui.workspace = true gpui.workspace = true
language_model.workspace = true language_model.workspace = true

View file

@ -1,12 +1,13 @@
use std::sync::Arc; use std::sync::Arc;
use collections::{HashSet, IndexMap};
use feature_flags::{Assistant2FeatureFlag, ZedPro}; use feature_flags::{Assistant2FeatureFlag, ZedPro};
use gpui::{ use gpui::{
Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle, Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases, Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
}; };
use language_model::{ use language_model::{
AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry, AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
}; };
use picker::{Picker, PickerDelegate}; use picker::{Picker, PickerDelegate};
use proto::Plan; use proto::Plan;
@ -24,9 +25,6 @@ type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
pub struct LanguageModelSelector { pub struct LanguageModelSelector {
picker: Entity<Picker<LanguageModelPickerDelegate>>, picker: Entity<Picker<LanguageModelPickerDelegate>>,
/// The task used to update the picker's matches when there is a change to
/// the language model registry.
update_matches_task: Option<Task<()>>,
_authenticate_all_providers_task: Task<()>, _authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
@ -40,16 +38,18 @@ impl LanguageModelSelector {
let on_model_changed = Arc::new(on_model_changed); let on_model_changed = Arc::new(on_model_changed);
let all_models = Self::all_models(cx); let all_models = Self::all_models(cx);
let entries = all_models.entries();
let delegate = LanguageModelPickerDelegate { let delegate = LanguageModelPickerDelegate {
language_model_selector: cx.entity().downgrade(), language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(), on_model_changed: on_model_changed.clone(),
all_models: all_models.clone(), all_models: Arc::new(all_models),
filtered_models: all_models, selected_index: Self::get_active_model_index(&entries, cx),
selected_index: Self::get_active_model_index(cx), filtered_entries: entries,
}; };
let picker = cx.new(|cx| { let picker = cx.new(|cx| {
Picker::uniform_list(delegate, window, cx) Picker::list(delegate, window, cx)
.show_scrollbar(true) .show_scrollbar(true)
.width(rems(20.)) .width(rems(20.))
.max_height(Some(rems(20.).into())) .max_height(Some(rems(20.).into()))
@ -59,7 +59,6 @@ impl LanguageModelSelector {
LanguageModelSelector { LanguageModelSelector {
picker, picker,
update_matches_task: None,
_authenticate_all_providers_task: Self::authenticate_all_providers(cx), _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![ _subscriptions: vec![
cx.subscribe_in( cx.subscribe_in(
@ -83,12 +82,13 @@ impl LanguageModelSelector {
language_model::Event::ProviderStateChanged language_model::Event::ProviderStateChanged
| language_model::Event::AddedProvider(_) | language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => { | language_model::Event::RemovedProvider(_) => {
let task = self.picker.update(cx, |this, cx| { self.picker.update(cx, |this, cx| {
let query = this.query(cx); let query = this.query(cx);
this.delegate.all_models = Self::all_models(cx); this.delegate.all_models = Arc::new(Self::all_models(cx));
this.delegate.update_matches(query, window, cx) // Update matches will automatically drop the previous task
// if we get a provider event again
this.update_matches(query, window, cx)
}); });
self.update_matches_task = Some(task);
} }
_ => {} _ => {}
} }
@ -144,34 +144,72 @@ impl LanguageModelSelector {
}) })
} }
fn all_models(cx: &App) -> Vec<ModelInfo> { fn all_models(cx: &App) -> GroupedModels {
LanguageModelRegistry::global(cx) let mut recommended = Vec::new();
let mut recommended_set = HashSet::default();
for provider in LanguageModelRegistry::global(cx)
.read(cx) .read(cx)
.providers() .providers()
.iter() .iter()
.flat_map(|provider| { {
let icon = provider.icon(); let models = provider.recommended_models(cx);
recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
provider.provided_models(cx).into_iter().map(move |model| { recommended.extend(
let model = model.clone(); provider
let icon = model.icon().unwrap_or(icon); .recommended_models(cx)
.into_iter()
ModelInfo { .map(move |model| ModelInfo {
model: model.clone(), model: model.clone(),
icon, icon: provider.icon(),
availability: model.availability(), }),
} );
}) }
let other_models = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
.map(|provider| {
(
provider.id(),
provider
.provided_models(cx)
.into_iter()
.filter_map(|model| {
let not_included =
!recommended_set.contains(&(model.provider_id(), model.id()));
not_included.then(|| ModelInfo {
model: model.clone(),
icon: provider.icon(),
})
})
.collect::<Vec<_>>(),
)
}) })
.collect::<Vec<_>>() .collect::<IndexMap<_, _>>();
GroupedModels {
recommended,
other: other_models,
}
} }
fn get_active_model_index(cx: &App) -> usize { fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
let active_model = LanguageModelRegistry::read_global(cx).default_model(); let active_model = LanguageModelRegistry::read_global(cx).default_model();
Self::all_models(cx) entries
.iter() .iter()
.position(|model_info| { .position(|entry| {
Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id()) if let LanguageModelPickerEntry::Model(model) = entry {
active_model
.as_ref()
.map(|active_model| {
active_model.model.id() == model.model.id()
&& active_model.model.provider_id() == model.model.provider_id()
})
.unwrap_or_default()
} else {
false
}
}) })
.unwrap_or(0) .unwrap_or(0)
} }
@ -254,22 +292,61 @@ where
struct ModelInfo { struct ModelInfo {
model: Arc<dyn LanguageModel>, model: Arc<dyn LanguageModel>,
icon: IconName, icon: IconName,
availability: LanguageModelAvailability,
} }
pub struct LanguageModelPickerDelegate { pub struct LanguageModelPickerDelegate {
language_model_selector: WeakEntity<LanguageModelSelector>, language_model_selector: WeakEntity<LanguageModelSelector>,
on_model_changed: OnModelChanged, on_model_changed: OnModelChanged,
all_models: Vec<ModelInfo>, all_models: Arc<GroupedModels>,
filtered_models: Vec<ModelInfo>, filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize, selected_index: usize,
} }
struct GroupedModels {
recommended: Vec<ModelInfo>,
other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
impl GroupedModels {
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
let mut entries = Vec::new();
if !self.recommended.is_empty() {
entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
entries.extend(
self.recommended
.iter()
.map(|info| LanguageModelPickerEntry::Model(info.clone())),
);
}
for models in self.other.values() {
if models.is_empty() {
continue;
}
entries.push(LanguageModelPickerEntry::Separator(
models[0].model.provider_name().0,
));
entries.extend(
models
.iter()
.map(|info| LanguageModelPickerEntry::Model(info.clone())),
);
}
entries
}
}
enum LanguageModelPickerEntry {
Model(ModelInfo),
Separator(SharedString),
}
impl PickerDelegate for LanguageModelPickerDelegate { impl PickerDelegate for LanguageModelPickerDelegate {
type ListItem = ListItem; type ListItem = AnyElement;
fn match_count(&self) -> usize { fn match_count(&self) -> usize {
self.filtered_models.len() self.filtered_entries.len()
} }
fn selected_index(&self) -> usize { fn selected_index(&self) -> usize {
@ -277,12 +354,24 @@ impl PickerDelegate for LanguageModelPickerDelegate {
} }
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) { fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1)); self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
cx.notify(); cx.notify();
} }
fn can_select(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) -> bool {
match self.filtered_entries.get(ix) {
Some(LanguageModelPickerEntry::Model(_)) => true,
Some(LanguageModelPickerEntry::Separator(_)) | None => false,
}
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> { fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Select a model...".into() "Select a model".into()
} }
fn update_matches( fn update_matches(
@ -307,22 +396,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
cx.spawn_in(window, async move |this, cx| { cx.spawn_in(window, async move |this, cx| {
let filtered_models = cx let filtered_models = cx
.background_spawn(async move { .background_spawn(async move {
let displayed_models = if configured_providers.is_empty() { let filter_models = |model_infos: &[ModelInfo]| {
all_models model_infos
} else { .iter()
all_models
.into_iter()
.filter(|model_info| {
configured_providers.contains(&model_info.model.provider_id())
})
.collect::<Vec<_>>()
};
if query.is_empty() {
displayed_models
} else {
displayed_models
.into_iter()
.filter(|model_info| { .filter(|model_info| {
model_info model_info
.model .model
@ -331,20 +407,33 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.to_lowercase() .to_lowercase()
.contains(&query.to_lowercase()) .contains(&query.to_lowercase())
}) })
.collect() .cloned()
.collect::<Vec<_>>()
};
let recommended_models = filter_models(&all_models.recommended);
let mut other_models = IndexMap::default();
for (provider_id, models) in &all_models.other {
if configured_providers.contains(&provider_id) {
other_models.insert(provider_id.clone(), filter_models(models));
}
}
GroupedModels {
recommended: recommended_models,
other: other_models,
} }
}) })
.await; .await;
this.update_in(cx, |this, window, cx| { this.update_in(cx, |this, window, cx| {
this.delegate.filtered_models = filtered_models; this.delegate.filtered_entries = filtered_models.entries();
// Preserve selection focus // Preserve selection focus
let new_index = if current_index >= this.delegate.filtered_models.len() { let new_index = if current_index >= this.delegate.filtered_entries.len() {
0 0
} else { } else {
current_index current_index
}; };
this.delegate.set_selected_index(new_index, window, cx); this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
cx.notify(); cx.notify();
}) })
.ok(); .ok();
@ -352,7 +441,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
} }
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) { fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
if let Some(model_info) = self.filtered_models.get(self.selected_index) { if let Some(LanguageModelPickerEntry::Model(model_info)) =
self.filtered_entries.get(self.selected_index)
{
let model = model_info.model.clone(); let model = model_info.model.clone();
(self.on_model_changed)(model.clone(), cx); (self.on_model_changed)(model.clone(), cx);
@ -369,29 +460,6 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.ok(); .ok();
} }
fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
let configured_models_count = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
.filter(|provider| provider.is_authenticated(cx))
.count();
if configured_models_count > 0 {
Some(
Label::new("Configured Models")
.size(LabelSize::Small)
.color(Color::Muted)
.mt_1()
.mb_0p5()
.ml_2()
.into_any_element(),
)
} else {
None
}
}
fn render_match( fn render_match(
&self, &self,
ix: usize, ix: usize,
@ -399,77 +467,68 @@ impl PickerDelegate for LanguageModelPickerDelegate {
_: &mut Window, _: &mut Window,
cx: &mut Context<Picker<Self>>, cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> { ) -> Option<Self::ListItem> {
use feature_flags::FeatureFlagAppExt; match self.filtered_entries.get(ix)? {
let show_badges = cx.has_flag::<ZedPro>(); LanguageModelPickerEntry::Separator(title) => Some(
div()
.px_2()
.pb_1()
.when(ix > 1, |this| {
this.mt_1()
.pt_2()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
})
.child(
Label::new(title)
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.into_any_element(),
),
LanguageModelPickerEntry::Model(model_info) => {
let active_model = LanguageModelRegistry::read_global(cx).default_model();
let model_info = self.filtered_models.get(ix)?; let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let provider_name: String = model_info.model.provider_name().0.clone().into(); let active_model_id = active_model.map(|m| m.model.id());
let active_model = LanguageModelRegistry::read_global(cx).default_model(); let is_selected = Some(model_info.model.provider_id()) == active_provider_id
&& Some(model_info.model.id()) == active_model_id;
let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); let model_icon_color = if is_selected {
let active_model_id = active_model.map(|m| m.model.id()); Color::Accent
} else {
Color::Muted
};
let is_selected = Some(model_info.model.provider_id()) == active_provider_id Some(
&& Some(model_info.model.id()) == active_model_id; ListItem::new(ix)
.inset(true)
let model_icon_color = if is_selected { .spacing(ListItemSpacing::Sparse)
Color::Accent .toggle_state(selected)
} else { .start_slot(
Color::Muted Icon::new(model_info.icon)
}; .color(model_icon_color)
.size(IconSize::Small),
Some(
ListItem::new(ix)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.start_slot(
Icon::new(model_info.icon)
.color(model_icon_color)
.size(IconSize::Small),
)
.child(
h_flex()
.w_full()
.items_center()
.gap_1p5()
.pl_0p5()
.w(px(240.))
.child(
div()
.max_w_40()
.child(Label::new(model_info.model.name().0.clone()).truncate()),
) )
.child( .child(
h_flex() h_flex()
.gap_0p5() .w_full()
.child( .pl_0p5()
Label::new(provider_name) .gap_1p5()
.size(LabelSize::XSmall) .w(px(240.))
.color(Color::Muted), .child(Label::new(model_info.model.name().0.clone()).truncate()),
) )
.children(match model_info.availability { .end_slot(div().pr_3().when(is_selected, |this| {
LanguageModelAvailability::Public => None, this.child(
LanguageModelAvailability::RequiresPlan(Plan::Free) => None, Icon::new(IconName::Check)
LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => { .color(Color::Accent)
show_badges.then(|| { .size(IconSize::Small),
Label::new("Pro") )
.size(LabelSize::XSmall) }))
.color(Color::Muted) .into_any_element(),
})
}
}),
),
) )
.end_slot(div().pr_3().when(is_selected, |this| { }
this.child( }
Icon::new(IconName::Check)
.color(Color::Accent)
.size(IconSize::Small),
)
})),
)
} }
fn render_footer( fn render_footer(

View file

@ -192,6 +192,16 @@ impl AnthropicLanguageModelProvider {
Self { http_client, state } Self { http_client, state }
} }
fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
Arc::new(AnthropicModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
}
} }
impl LanguageModelProviderState for AnthropicLanguageModelProvider { impl LanguageModelProviderState for AnthropicLanguageModelProvider {
@ -226,6 +236,16 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
})) }))
} }
fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
[
anthropic::Model::Claude3_7Sonnet,
anthropic::Model::Claude3_7SonnetThinking,
]
.into_iter()
.map(|model| self.create_language_model(model))
.collect()
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> { fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default(); let mut models = BTreeMap::default();
@ -266,15 +286,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
models models
.into_values() .into_values()
.map(|model| { .map(|model| self.create_language_model(model))
Arc::new(AnthropicModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect() .collect()
} }

View file

@ -225,6 +225,20 @@ impl CloudLanguageModelProvider {
_maintain_client_status: maintain_client_status, _maintain_client_status: maintain_client_status,
} }
} }
fn create_language_model(
&self,
model: CloudModel,
llm_api_token: LlmApiToken,
) -> Arc<dyn LanguageModel> {
Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
llm_api_token: llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
}
} }
impl LanguageModelProviderState for CloudLanguageModelProvider { impl LanguageModelProviderState for CloudLanguageModelProvider {
@ -260,6 +274,17 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
})) }))
} }
fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let llm_api_token = self.state.read(cx).llm_api_token.clone();
[
CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
]
.into_iter()
.map(|model| self.create_language_model(model, llm_api_token.clone()))
.collect()
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> { fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default(); let mut models = BTreeMap::default();
@ -345,15 +370,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
let llm_api_token = self.state.read(cx).llm_api_token.clone(); let llm_api_token = self.state.read(cx).llm_api_token.clone();
models models
.into_values() .into_values()
.map(|model| { .map(|model| self.create_language_model(model, llm_api_token.clone()))
Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
llm_api_token: llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect() .collect()
} }
@ -575,10 +592,6 @@ impl LanguageModel for CloudLanguageModel {
LanguageModelName::from(self.model.display_name().to_string()) LanguageModelName::from(self.model.display_name().to_string())
} }
fn icon(&self) -> Option<IconName> {
self.model.icon()
}
fn provider_id(&self) -> LanguageModelProviderId { fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into()) LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
} }

View file

@ -3,8 +3,8 @@ use editor::{Editor, scroll::Autoscroll};
use gpui::{ use gpui::{
AnyElement, App, ClickEvent, Context, DismissEvent, Entity, EventEmitter, FocusHandle, AnyElement, App, ClickEvent, Context, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, Length, ListSizingBehavior, ListState, MouseButton, MouseUpEvent, Render, Focusable, Length, ListSizingBehavior, ListState, MouseButton, MouseUpEvent, Render,
ScrollHandle, ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div, ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div, impl_actions,
impl_actions, list, prelude::*, uniform_list, list, prelude::*, uniform_list,
}; };
use head::Head; use head::Head;
use schemars::JsonSchema; use schemars::JsonSchema;
@ -24,6 +24,11 @@ enum ElementContainer {
UniformList(UniformListScrollHandle), UniformList(UniformListScrollHandle),
} }
pub enum Direction {
Up,
Down,
}
actions!(picker, [ConfirmCompletion]); actions!(picker, [ConfirmCompletion]);
/// ConfirmInput is an alternative editor action which - instead of selecting active picker entry - treats pickers editor input literally, /// ConfirmInput is an alternative editor action which - instead of selecting active picker entry - treats pickers editor input literally,
@ -86,6 +91,15 @@ pub trait PickerDelegate: Sized + 'static {
window: &mut Window, window: &mut Window,
cx: &mut Context<Picker<Self>>, cx: &mut Context<Picker<Self>>,
); );
fn can_select(
&mut self,
_ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) -> bool {
true
}
// Allows binding some optional effect to when the selection changes. // Allows binding some optional effect to when the selection changes.
fn selected_index_changed( fn selected_index_changed(
&self, &self,
@ -271,10 +285,7 @@ impl<D: PickerDelegate> Picker<D> {
ElementContainer::UniformList(scroll_handle) => { ElementContainer::UniformList(scroll_handle) => {
ScrollbarState::new(scroll_handle.clone()) ScrollbarState::new(scroll_handle.clone())
} }
ElementContainer::List(_) => { ElementContainer::List(state) => ScrollbarState::new(state.clone()),
// todo smit: implement for list
ScrollbarState::new(ScrollHandle::new())
}
}; };
let focus_handle = cx.focus_handle(); let focus_handle = cx.focus_handle();
let mut this = Self { let mut this = Self {
@ -359,16 +370,58 @@ impl<D: PickerDelegate> Picker<D> {
} }
/// Handles the selecting an index, and passing the change to the delegate. /// Handles the selecting an index, and passing the change to the delegate.
/// If `scroll_to_index` is true, the new selected index will be scrolled into view. /// If `fallback_direction` is set to `None`, the index will not be selected
/// if the element at that index cannot be selected.
/// If `fallback_direction` is set to
/// `Some(..)`, the next selectable element will be selected in the
/// specified direction (Down or Up), cycling through all elements until
/// finding one that can be selected or returning if there are no selectable elements.
/// If `scroll_to_index` is true, the new selected index will be scrolled into
/// view.
/// ///
/// If some effect is bound to `selected_index_changed`, it will be executed. /// If some effect is bound to `selected_index_changed`, it will be executed.
pub fn set_selected_index( pub fn set_selected_index(
&mut self, &mut self,
ix: usize, mut ix: usize,
fallback_direction: Option<Direction>,
scroll_to_index: bool, scroll_to_index: bool,
window: &mut Window, window: &mut Window,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let match_count = self.delegate.match_count();
if match_count == 0 {
return;
}
if let Some(bias) = fallback_direction {
let mut curr_ix = ix;
while !self.delegate.can_select(curr_ix, window, cx) {
curr_ix = match bias {
Direction::Down => {
if curr_ix == match_count - 1 {
0
} else {
curr_ix + 1
}
}
Direction::Up => {
if curr_ix == 0 {
match_count - 1
} else {
curr_ix - 1
}
}
};
// There is no item that can be selected
if ix == curr_ix {
return;
}
}
ix = curr_ix;
} else if !self.delegate.can_select(ix, window, cx) {
return;
}
let previous_index = self.delegate.selected_index(); let previous_index = self.delegate.selected_index();
self.delegate.set_selected_index(ix, window, cx); self.delegate.set_selected_index(ix, window, cx);
let current_index = self.delegate.selected_index(); let current_index = self.delegate.selected_index();
@ -393,7 +446,7 @@ impl<D: PickerDelegate> Picker<D> {
if count > 0 { if count > 0 {
let index = self.delegate.selected_index(); let index = self.delegate.selected_index();
let ix = if index == count - 1 { 0 } else { index + 1 }; let ix = if index == count - 1 { 0 } else { index + 1 };
self.set_selected_index(ix, true, window, cx); self.set_selected_index(ix, Some(Direction::Down), true, window, cx);
cx.notify(); cx.notify();
} }
} }
@ -408,7 +461,7 @@ impl<D: PickerDelegate> Picker<D> {
if count > 0 { if count > 0 {
let index = self.delegate.selected_index(); let index = self.delegate.selected_index();
let ix = if index == 0 { count - 1 } else { index - 1 }; let ix = if index == 0 { count - 1 } else { index - 1 };
self.set_selected_index(ix, true, window, cx); self.set_selected_index(ix, Some(Direction::Up), true, window, cx);
cx.notify(); cx.notify();
} }
} }
@ -416,7 +469,7 @@ impl<D: PickerDelegate> Picker<D> {
fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context<Self>) { fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context<Self>) {
let count = self.delegate.match_count(); let count = self.delegate.match_count();
if count > 0 { if count > 0 {
self.set_selected_index(0, true, window, cx); self.set_selected_index(0, Some(Direction::Down), true, window, cx);
cx.notify(); cx.notify();
} }
} }
@ -424,7 +477,7 @@ impl<D: PickerDelegate> Picker<D> {
fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context<Self>) { fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context<Self>) {
let count = self.delegate.match_count(); let count = self.delegate.match_count();
if count > 0 { if count > 0 {
self.set_selected_index(count - 1, true, window, cx); self.set_selected_index(count - 1, Some(Direction::Up), true, window, cx);
cx.notify(); cx.notify();
} }
} }
@ -433,7 +486,7 @@ impl<D: PickerDelegate> Picker<D> {
let count = self.delegate.match_count(); let count = self.delegate.match_count();
let index = self.delegate.selected_index(); let index = self.delegate.selected_index();
let new_index = if index + 1 == count { 0 } else { index + 1 }; let new_index = if index + 1 == count { 0 } else { index + 1 };
self.set_selected_index(new_index, true, window, cx); self.set_selected_index(new_index, Some(Direction::Down), true, window, cx);
cx.notify(); cx.notify();
} }
@ -506,14 +559,14 @@ impl<D: PickerDelegate> Picker<D> {
) { ) {
cx.stop_propagation(); cx.stop_propagation();
window.prevent_default(); window.prevent_default();
self.set_selected_index(ix, false, window, cx); self.set_selected_index(ix, None, false, window, cx);
self.do_confirm(secondary, window, cx) self.do_confirm(secondary, window, cx)
} }
fn do_confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context<Self>) { fn do_confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context<Self>) {
if let Some(update_query) = self.delegate.confirm_update_query(window, cx) { if let Some(update_query) = self.delegate.confirm_update_query(window, cx) {
self.set_query(update_query, window, cx); self.set_query(update_query, window, cx);
self.delegate.set_selected_index(0, window, cx); self.set_selected_index(0, Some(Direction::Down), false, window, cx);
} else { } else {
self.delegate.confirm(secondary, window, cx) self.delegate.confirm(secondary, window, cx)
} }

View file

@ -657,7 +657,7 @@ impl PromptLibrary {
.iter() .iter()
.position(|mat| mat.id == prompt_id) .position(|mat| mat.id == prompt_id)
{ {
picker.set_selected_index(ix, true, window, cx); picker.set_selected_index(ix, None, true, window, cx);
} }
} }
} else { } else {