diff --git a/Cargo.lock b/Cargo.lock
index fc906cf259..9af249804c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7657,6 +7657,7 @@ dependencies = [
name = "language_model_selector"
version = "0.1.0"
dependencies = [
+ "collections",
"feature_flags",
"gpui",
"language_model",
diff --git a/assets/icons/ai_anthropic_hosted.svg b/assets/icons/ai_anthropic_hosted.svg
deleted file mode 100644
index b088520490..0000000000
--- a/assets/icons/ai_anthropic_hosted.svg
+++ /dev/null
@@ -1,12 +0,0 @@
-
diff --git a/crates/agent/src/assistant_model_selector.rs b/crates/agent/src/assistant_model_selector.rs
index 11726b2574..091071af29 100644
--- a/crates/agent/src/assistant_model_selector.rs
+++ b/crates/agent/src/assistant_model_selector.rs
@@ -80,17 +80,16 @@ impl AssistantModelSelector {
impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement {
- let model_registry = LanguageModelRegistry::read_global(cx);
+ let focus_handle = self.focus_handle.clone();
+ let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
-
- let focus_handle = self.focus_handle.clone();
- let model_name = match model {
- Some(model) => model.model.name().0,
- _ => SharedString::from("No model selected"),
+ let (model_name, model_icon) = match model {
+ Some(model) => (model.model.name().0, Some(model.provider.icon())),
+ _ => (SharedString::from("No model selected"), None),
};
LanguageModelSelectorPopoverMenu::new(
@@ -100,10 +99,16 @@ impl Render for AssistantModelSelector {
.child(
h_flex()
.gap_0p5()
+ .children(
+ model_icon.map(|icon| {
+ Icon::new(icon).color(Color::Muted).size(IconSize::Small)
+ }),
+ )
.child(
Label::new(model_name)
.size(LabelSize::Small)
- .color(Color::Muted),
+ .color(Color::Muted)
+ .ml_1(),
)
.child(
Icon::new(IconName::ChevronDown)
diff --git a/crates/file_finder/src/file_finder_tests.rs b/crates/file_finder/src/file_finder_tests.rs
index d5d3582858..d2a5f1402d 100644
--- a/crates/file_finder/src/file_finder_tests.rs
+++ b/crates/file_finder/src/file_finder_tests.rs
@@ -2133,18 +2133,28 @@ async fn test_repeat_toggle_action(cx: &mut gpui::TestAppContext) {
cx.dispatch_action(ToggleFileFinder::default());
let picker = active_file_picker(&workspace, cx);
+
+ picker.update_in(cx, |picker, window, cx| {
+ picker.update_matches(".txt".to_string(), window, cx)
+ });
+
+ cx.run_until_parked();
+
picker.update(cx, |picker, _| {
+ assert_eq!(picker.delegate.matches.len(), 6);
assert_eq!(picker.delegate.selected_index, 0);
- assert_eq!(picker.logical_scroll_top_index(), 0);
});
// When toggling repeatedly, the picker scrolls to reveal the selected item.
cx.dispatch_action(ToggleFileFinder::default());
cx.dispatch_action(ToggleFileFinder::default());
cx.dispatch_action(ToggleFileFinder::default());
+
+ cx.run_until_parked();
+
picker.update(cx, |picker, _| {
+ assert_eq!(picker.delegate.matches.len(), 6);
assert_eq!(picker.delegate.selected_index, 3);
- assert_eq!(picker.logical_scroll_top_index(), 3);
});
}
diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs
index 6c448c03ed..d7f4a820da 100644
--- a/crates/icons/src/icons.rs
+++ b/crates/icons/src/icons.rs
@@ -10,7 +10,6 @@ use strum::{EnumIter, EnumString, IntoStaticStr};
pub enum IconName {
Ai,
AiAnthropic,
- AiAnthropicHosted,
AiBedrock,
AiDeepSeek,
AiEdit,
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index aa060f7b30..98456e7db4 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -174,10 +174,6 @@ impl Default for LanguageModelTextStream {
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
- /// If None, falls back to [LanguageModelProvider::icon]
- fn icon(&self) -> Option {
- None
- }
fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
fn telemetry_id(&self) -> String;
@@ -304,6 +300,9 @@ pub trait LanguageModelProvider: 'static {
}
fn default_model(&self, cx: &App) -> Option>;
fn provided_models(&self, cx: &App) -> Vec>;
+ fn recommended_models(&self, _cx: &App) -> Vec> {
+ Vec::new()
+ }
fn load_model(&self, _model: Arc, _cx: &App) {}
fn is_authenticated(&self, cx: &App) -> bool;
fn authenticate(&self, cx: &mut App) -> Task>;
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
index e5c66670d8..cc15ce3364 100644
--- a/crates/language_model/src/model/cloud_model.rs
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -6,7 +6,6 @@ use client::Client;
use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
};
-use icons::IconName;
use proto::{Plan, TypedEnvelope};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -53,13 +52,6 @@ impl CloudModel {
}
}
- pub fn icon(&self) -> Option {
- match self {
- Self::Anthropic(_) => Some(IconName::AiAnthropicHosted),
- _ => None,
- }
- }
-
pub fn max_token_count(&self) -> usize {
match self {
Self::Anthropic(model) => model.max_token_count(),
diff --git a/crates/language_model_selector/Cargo.toml b/crates/language_model_selector/Cargo.toml
index 1257ae564c..39bc8a59f9 100644
--- a/crates/language_model_selector/Cargo.toml
+++ b/crates/language_model_selector/Cargo.toml
@@ -12,6 +12,7 @@ workspace = true
path = "src/language_model_selector.rs"
[dependencies]
+collections.workspace = true
feature_flags.workspace = true
gpui.workspace = true
language_model.workspace = true
diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs
index 90747a01f3..7f18b4d9fd 100644
--- a/crates/language_model_selector/src/language_model_selector.rs
+++ b/crates/language_model_selector/src/language_model_selector.rs
@@ -1,12 +1,13 @@
use std::sync::Arc;
+use collections::{HashSet, IndexMap};
use feature_flags::{Assistant2FeatureFlag, ZedPro};
use gpui::{
Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
};
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
+ AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use picker::{Picker, PickerDelegate};
use proto::Plan;
@@ -24,9 +25,6 @@ type OnModelChanged = Arc, &App) + 'static>;
pub struct LanguageModelSelector {
picker: Entity>,
- /// The task used to update the picker's matches when there is a change to
- /// the language model registry.
- update_matches_task: Option>,
_authenticate_all_providers_task: Task<()>,
_subscriptions: Vec,
}
@@ -40,16 +38,18 @@ impl LanguageModelSelector {
let on_model_changed = Arc::new(on_model_changed);
let all_models = Self::all_models(cx);
+ let entries = all_models.entries();
+
let delegate = LanguageModelPickerDelegate {
language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(),
- all_models: all_models.clone(),
- filtered_models: all_models,
- selected_index: Self::get_active_model_index(cx),
+ all_models: Arc::new(all_models),
+ selected_index: Self::get_active_model_index(&entries, cx),
+ filtered_entries: entries,
};
let picker = cx.new(|cx| {
- Picker::uniform_list(delegate, window, cx)
+ Picker::list(delegate, window, cx)
.show_scrollbar(true)
.width(rems(20.))
.max_height(Some(rems(20.).into()))
@@ -59,7 +59,6 @@ impl LanguageModelSelector {
LanguageModelSelector {
picker,
- update_matches_task: None,
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![
cx.subscribe_in(
@@ -83,12 +82,13 @@ impl LanguageModelSelector {
language_model::Event::ProviderStateChanged
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
- let task = self.picker.update(cx, |this, cx| {
+ self.picker.update(cx, |this, cx| {
let query = this.query(cx);
- this.delegate.all_models = Self::all_models(cx);
- this.delegate.update_matches(query, window, cx)
+ this.delegate.all_models = Arc::new(Self::all_models(cx));
+ // Update matches will automatically drop the previous task
+ // if we get a provider event again
+ this.update_matches(query, window, cx)
});
- self.update_matches_task = Some(task);
}
_ => {}
}
@@ -144,34 +144,72 @@ impl LanguageModelSelector {
})
}
- fn all_models(cx: &App) -> Vec {
- LanguageModelRegistry::global(cx)
+ fn all_models(cx: &App) -> GroupedModels {
+ let mut recommended = Vec::new();
+ let mut recommended_set = HashSet::default();
+ for provider in LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
- .flat_map(|provider| {
- let icon = provider.icon();
-
- provider.provided_models(cx).into_iter().map(move |model| {
- let model = model.clone();
- let icon = model.icon().unwrap_or(icon);
-
- ModelInfo {
+ {
+ let models = provider.recommended_models(cx);
+ recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
+ recommended.extend(
+ provider
+ .recommended_models(cx)
+ .into_iter()
+ .map(move |model| ModelInfo {
model: model.clone(),
- icon,
- availability: model.availability(),
- }
- })
+ icon: provider.icon(),
+ }),
+ );
+ }
+
+ let other_models = LanguageModelRegistry::global(cx)
+ .read(cx)
+ .providers()
+ .iter()
+ .map(|provider| {
+ (
+ provider.id(),
+ provider
+ .provided_models(cx)
+ .into_iter()
+ .filter_map(|model| {
+ let not_included =
+ !recommended_set.contains(&(model.provider_id(), model.id()));
+ not_included.then(|| ModelInfo {
+ model: model.clone(),
+ icon: provider.icon(),
+ })
+ })
+ .collect::>(),
+ )
})
- .collect::>()
+ .collect::>();
+
+ GroupedModels {
+ recommended,
+ other: other_models,
+ }
}
- fn get_active_model_index(cx: &App) -> usize {
+ fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
let active_model = LanguageModelRegistry::read_global(cx).default_model();
- Self::all_models(cx)
+ entries
.iter()
- .position(|model_info| {
- Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
+ .position(|entry| {
+ if let LanguageModelPickerEntry::Model(model) = entry {
+ active_model
+ .as_ref()
+ .map(|active_model| {
+ active_model.model.id() == model.model.id()
+ && active_model.model.provider_id() == model.model.provider_id()
+ })
+ .unwrap_or_default()
+ } else {
+ false
+ }
})
.unwrap_or(0)
}
@@ -254,22 +292,61 @@ where
struct ModelInfo {
model: Arc,
icon: IconName,
- availability: LanguageModelAvailability,
}
pub struct LanguageModelPickerDelegate {
language_model_selector: WeakEntity,
on_model_changed: OnModelChanged,
- all_models: Vec,
- filtered_models: Vec,
+ all_models: Arc,
+ filtered_entries: Vec,
selected_index: usize,
}
+struct GroupedModels {
+ recommended: Vec,
+ other: IndexMap>,
+}
+
+impl GroupedModels {
+ fn entries(&self) -> Vec {
+ let mut entries = Vec::new();
+
+ if !self.recommended.is_empty() {
+ entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
+ entries.extend(
+ self.recommended
+ .iter()
+ .map(|info| LanguageModelPickerEntry::Model(info.clone())),
+ );
+ }
+
+ for models in self.other.values() {
+ if models.is_empty() {
+ continue;
+ }
+ entries.push(LanguageModelPickerEntry::Separator(
+ models[0].model.provider_name().0,
+ ));
+ entries.extend(
+ models
+ .iter()
+ .map(|info| LanguageModelPickerEntry::Model(info.clone())),
+ );
+ }
+ entries
+ }
+}
+
+enum LanguageModelPickerEntry {
+ Model(ModelInfo),
+ Separator(SharedString),
+}
+
impl PickerDelegate for LanguageModelPickerDelegate {
- type ListItem = ListItem;
+ type ListItem = AnyElement;
fn match_count(&self) -> usize {
- self.filtered_models.len()
+ self.filtered_entries.len()
}
fn selected_index(&self) -> usize {
@@ -277,12 +354,24 @@ impl PickerDelegate for LanguageModelPickerDelegate {
}
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context>) {
- self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
+ self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
cx.notify();
}
+ fn can_select(
+ &mut self,
+ ix: usize,
+ _window: &mut Window,
+ _cx: &mut Context>,
+ ) -> bool {
+ match self.filtered_entries.get(ix) {
+ Some(LanguageModelPickerEntry::Model(_)) => true,
+ Some(LanguageModelPickerEntry::Separator(_)) | None => false,
+ }
+ }
+
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc {
- "Select a model...".into()
+ "Select a model…".into()
}
fn update_matches(
@@ -307,22 +396,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
cx.spawn_in(window, async move |this, cx| {
let filtered_models = cx
.background_spawn(async move {
- let displayed_models = if configured_providers.is_empty() {
- all_models
- } else {
- all_models
- .into_iter()
- .filter(|model_info| {
- configured_providers.contains(&model_info.model.provider_id())
- })
- .collect::>()
- };
-
- if query.is_empty() {
- displayed_models
- } else {
- displayed_models
- .into_iter()
+ let filter_models = |model_infos: &[ModelInfo]| {
+ model_infos
+ .iter()
.filter(|model_info| {
model_info
.model
@@ -331,20 +407,33 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.to_lowercase()
.contains(&query.to_lowercase())
})
- .collect()
+ .cloned()
+ .collect::>()
+ };
+
+ let recommended_models = filter_models(&all_models.recommended);
+ let mut other_models = IndexMap::default();
+ for (provider_id, models) in &all_models.other {
+ if configured_providers.contains(&provider_id) {
+ other_models.insert(provider_id.clone(), filter_models(models));
+ }
+ }
+ GroupedModels {
+ recommended: recommended_models,
+ other: other_models,
}
})
.await;
this.update_in(cx, |this, window, cx| {
- this.delegate.filtered_models = filtered_models;
+ this.delegate.filtered_entries = filtered_models.entries();
// Preserve selection focus
- let new_index = if current_index >= this.delegate.filtered_models.len() {
+ let new_index = if current_index >= this.delegate.filtered_entries.len() {
0
} else {
current_index
};
- this.delegate.set_selected_index(new_index, window, cx);
+ this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
cx.notify();
})
.ok();
@@ -352,7 +441,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) {
- if let Some(model_info) = self.filtered_models.get(self.selected_index) {
+ if let Some(LanguageModelPickerEntry::Model(model_info)) =
+ self.filtered_entries.get(self.selected_index)
+ {
let model = model_info.model.clone();
(self.on_model_changed)(model.clone(), cx);
@@ -369,29 +460,6 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.ok();
}
- fn render_header(&self, _: &mut Window, cx: &mut Context>) -> Option {
- let configured_models_count = LanguageModelRegistry::global(cx)
- .read(cx)
- .providers()
- .iter()
- .filter(|provider| provider.is_authenticated(cx))
- .count();
-
- if configured_models_count > 0 {
- Some(
- Label::new("Configured Models")
- .size(LabelSize::Small)
- .color(Color::Muted)
- .mt_1()
- .mb_0p5()
- .ml_2()
- .into_any_element(),
- )
- } else {
- None
- }
- }
-
fn render_match(
&self,
ix: usize,
@@ -399,77 +467,68 @@ impl PickerDelegate for LanguageModelPickerDelegate {
_: &mut Window,
cx: &mut Context>,
) -> Option {
- use feature_flags::FeatureFlagAppExt;
- let show_badges = cx.has_flag::();
+ match self.filtered_entries.get(ix)? {
+ LanguageModelPickerEntry::Separator(title) => Some(
+ div()
+ .px_2()
+ .pb_1()
+ .when(ix > 1, |this| {
+ this.mt_1()
+ .pt_2()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ })
+ .child(
+ Label::new(title)
+ .size(LabelSize::XSmall)
+ .color(Color::Muted),
+ )
+ .into_any_element(),
+ ),
+ LanguageModelPickerEntry::Model(model_info) => {
+ let active_model = LanguageModelRegistry::read_global(cx).default_model();
- let model_info = self.filtered_models.get(ix)?;
- let provider_name: String = model_info.model.provider_name().0.clone().into();
+ let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
+ let active_model_id = active_model.map(|m| m.model.id());
- let active_model = LanguageModelRegistry::read_global(cx).default_model();
+ let is_selected = Some(model_info.model.provider_id()) == active_provider_id
+ && Some(model_info.model.id()) == active_model_id;
- let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
- let active_model_id = active_model.map(|m| m.model.id());
+ let model_icon_color = if is_selected {
+ Color::Accent
+ } else {
+ Color::Muted
+ };
- let is_selected = Some(model_info.model.provider_id()) == active_provider_id
- && Some(model_info.model.id()) == active_model_id;
-
- let model_icon_color = if is_selected {
- Color::Accent
- } else {
- Color::Muted
- };
-
- Some(
- ListItem::new(ix)
- .inset(true)
- .spacing(ListItemSpacing::Sparse)
- .toggle_state(selected)
- .start_slot(
- Icon::new(model_info.icon)
- .color(model_icon_color)
- .size(IconSize::Small),
- )
- .child(
- h_flex()
- .w_full()
- .items_center()
- .gap_1p5()
- .pl_0p5()
- .w(px(240.))
- .child(
- div()
- .max_w_40()
- .child(Label::new(model_info.model.name().0.clone()).truncate()),
+ Some(
+ ListItem::new(ix)
+ .inset(true)
+ .spacing(ListItemSpacing::Sparse)
+ .toggle_state(selected)
+ .start_slot(
+ Icon::new(model_info.icon)
+ .color(model_icon_color)
+ .size(IconSize::Small),
)
.child(
h_flex()
- .gap_0p5()
- .child(
- Label::new(provider_name)
- .size(LabelSize::XSmall)
- .color(Color::Muted),
- )
- .children(match model_info.availability {
- LanguageModelAvailability::Public => None,
- LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
- LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
- show_badges.then(|| {
- Label::new("Pro")
- .size(LabelSize::XSmall)
- .color(Color::Muted)
- })
- }
- }),
- ),
+ .w_full()
+ .pl_0p5()
+ .gap_1p5()
+ .w(px(240.))
+ .child(Label::new(model_info.model.name().0.clone()).truncate()),
+ )
+ .end_slot(div().pr_3().when(is_selected, |this| {
+ this.child(
+ Icon::new(IconName::Check)
+ .color(Color::Accent)
+ .size(IconSize::Small),
+ )
+ }))
+ .into_any_element(),
)
- .end_slot(div().pr_3().when(is_selected, |this| {
- this.child(
- Icon::new(IconName::Check)
- .color(Color::Accent)
- .size(IconSize::Small),
- )
- })),
- )
+ }
+ }
}
fn render_footer(
diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs
index bce985a872..4540a08268 100644
--- a/crates/language_models/src/provider/anthropic.rs
+++ b/crates/language_models/src/provider/anthropic.rs
@@ -192,6 +192,16 @@ impl AnthropicLanguageModelProvider {
Self { http_client, state }
}
+
+ fn create_language_model(&self, model: anthropic::Model) -> Arc {
+ Arc::new(AnthropicModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ }) as Arc
+ }
}
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
@@ -226,6 +236,16 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}))
}
+ fn recommended_models(&self, _cx: &App) -> Vec> {
+ [
+ anthropic::Model::Claude3_7Sonnet,
+ anthropic::Model::Claude3_7SonnetThinking,
+ ]
+ .into_iter()
+ .map(|model| self.create_language_model(model))
+ .collect()
+ }
+
fn provided_models(&self, cx: &App) -> Vec> {
let mut models = BTreeMap::default();
@@ -266,15 +286,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
models
.into_values()
- .map(|model| {
- Arc::new(AnthropicModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }) as Arc
- })
+ .map(|model| self.create_language_model(model))
.collect()
}
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 9377bf315f..6a08f48522 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -225,6 +225,20 @@ impl CloudLanguageModelProvider {
_maintain_client_status: maintain_client_status,
}
}
+
+ fn create_language_model(
+ &self,
+ model: CloudModel,
+ llm_api_token: LlmApiToken,
+ ) -> Arc {
+ Arc::new(CloudLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ llm_api_token: llm_api_token.clone(),
+ client: self.client.clone(),
+ request_limiter: RateLimiter::new(4),
+ }) as Arc
+ }
}
impl LanguageModelProviderState for CloudLanguageModelProvider {
@@ -260,6 +274,17 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}))
}
+ fn recommended_models(&self, cx: &App) -> Vec> {
+ let llm_api_token = self.state.read(cx).llm_api_token.clone();
+ [
+ CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
+ CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
+ ]
+ .into_iter()
+ .map(|model| self.create_language_model(model, llm_api_token.clone()))
+ .collect()
+ }
+
fn provided_models(&self, cx: &App) -> Vec> {
let mut models = BTreeMap::default();
@@ -345,15 +370,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
let llm_api_token = self.state.read(cx).llm_api_token.clone();
models
.into_values()
- .map(|model| {
- Arc::new(CloudLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- llm_api_token: llm_api_token.clone(),
- client: self.client.clone(),
- request_limiter: RateLimiter::new(4),
- }) as Arc
- })
+ .map(|model| self.create_language_model(model, llm_api_token.clone()))
.collect()
}
@@ -575,10 +592,6 @@ impl LanguageModel for CloudLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
- fn icon(&self) -> Option {
- self.model.icon()
- }
-
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
}
diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs
index 2caa9ff756..54b50453ce 100644
--- a/crates/picker/src/picker.rs
+++ b/crates/picker/src/picker.rs
@@ -3,8 +3,8 @@ use editor::{Editor, scroll::Autoscroll};
use gpui::{
AnyElement, App, ClickEvent, Context, DismissEvent, Entity, EventEmitter, FocusHandle,
Focusable, Length, ListSizingBehavior, ListState, MouseButton, MouseUpEvent, Render,
- ScrollHandle, ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div,
- impl_actions, list, prelude::*, uniform_list,
+ ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div, impl_actions,
+ list, prelude::*, uniform_list,
};
use head::Head;
use schemars::JsonSchema;
@@ -24,6 +24,11 @@ enum ElementContainer {
UniformList(UniformListScrollHandle),
}
+pub enum Direction {
+ Up,
+ Down,
+}
+
actions!(picker, [ConfirmCompletion]);
/// ConfirmInput is an alternative editor action which - instead of selecting active picker entry - treats pickers editor input literally,
@@ -86,6 +91,15 @@ pub trait PickerDelegate: Sized + 'static {
window: &mut Window,
cx: &mut Context>,
);
+ fn can_select(
+ &mut self,
+ _ix: usize,
+ _window: &mut Window,
+ _cx: &mut Context>,
+ ) -> bool {
+ true
+ }
+
// Allows binding some optional effect to when the selection changes.
fn selected_index_changed(
&self,
@@ -271,10 +285,7 @@ impl Picker {
ElementContainer::UniformList(scroll_handle) => {
ScrollbarState::new(scroll_handle.clone())
}
- ElementContainer::List(_) => {
- // todo smit: implement for list
- ScrollbarState::new(ScrollHandle::new())
- }
+ ElementContainer::List(state) => ScrollbarState::new(state.clone()),
};
let focus_handle = cx.focus_handle();
let mut this = Self {
@@ -359,16 +370,58 @@ impl Picker {
}
/// Handles the selecting an index, and passing the change to the delegate.
- /// If `scroll_to_index` is true, the new selected index will be scrolled into view.
+ /// If `fallback_direction` is set to `None`, the index will not be selected
+ /// if the element at that index cannot be selected.
+ /// If `fallback_direction` is set to
+ /// `Some(..)`, the next selectable element will be selected in the
+ /// specified direction (Down or Up), cycling through all elements until
+ /// finding one that can be selected or returning if there are no selectable elements.
+ /// If `scroll_to_index` is true, the new selected index will be scrolled into
+ /// view.
///
/// If some effect is bound to `selected_index_changed`, it will be executed.
pub fn set_selected_index(
&mut self,
- ix: usize,
+ mut ix: usize,
+ fallback_direction: Option,
scroll_to_index: bool,
window: &mut Window,
cx: &mut Context,
) {
+ let match_count = self.delegate.match_count();
+ if match_count == 0 {
+ return;
+ }
+
+ if let Some(bias) = fallback_direction {
+ let mut curr_ix = ix;
+ while !self.delegate.can_select(curr_ix, window, cx) {
+ curr_ix = match bias {
+ Direction::Down => {
+ if curr_ix == match_count - 1 {
+ 0
+ } else {
+ curr_ix + 1
+ }
+ }
+ Direction::Up => {
+ if curr_ix == 0 {
+ match_count - 1
+ } else {
+ curr_ix - 1
+ }
+ }
+ };
+ // There is no item that can be selected
+ if ix == curr_ix {
+ return;
+ }
+ }
+ ix = curr_ix;
+ } else if !self.delegate.can_select(ix, window, cx) {
+ return;
+ }
+
let previous_index = self.delegate.selected_index();
self.delegate.set_selected_index(ix, window, cx);
let current_index = self.delegate.selected_index();
@@ -393,7 +446,7 @@ impl Picker {
if count > 0 {
let index = self.delegate.selected_index();
let ix = if index == count - 1 { 0 } else { index + 1 };
- self.set_selected_index(ix, true, window, cx);
+ self.set_selected_index(ix, Some(Direction::Down), true, window, cx);
cx.notify();
}
}
@@ -408,7 +461,7 @@ impl Picker {
if count > 0 {
let index = self.delegate.selected_index();
let ix = if index == 0 { count - 1 } else { index - 1 };
- self.set_selected_index(ix, true, window, cx);
+ self.set_selected_index(ix, Some(Direction::Up), true, window, cx);
cx.notify();
}
}
@@ -416,7 +469,7 @@ impl Picker {
fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context) {
let count = self.delegate.match_count();
if count > 0 {
- self.set_selected_index(0, true, window, cx);
+ self.set_selected_index(0, Some(Direction::Down), true, window, cx);
cx.notify();
}
}
@@ -424,7 +477,7 @@ impl Picker {
fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context) {
let count = self.delegate.match_count();
if count > 0 {
- self.set_selected_index(count - 1, true, window, cx);
+ self.set_selected_index(count - 1, Some(Direction::Up), true, window, cx);
cx.notify();
}
}
@@ -433,7 +486,7 @@ impl Picker {
let count = self.delegate.match_count();
let index = self.delegate.selected_index();
let new_index = if index + 1 == count { 0 } else { index + 1 };
- self.set_selected_index(new_index, true, window, cx);
+ self.set_selected_index(new_index, Some(Direction::Down), true, window, cx);
cx.notify();
}
@@ -506,14 +559,14 @@ impl Picker {
) {
cx.stop_propagation();
window.prevent_default();
- self.set_selected_index(ix, false, window, cx);
+ self.set_selected_index(ix, None, false, window, cx);
self.do_confirm(secondary, window, cx)
}
fn do_confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context) {
if let Some(update_query) = self.delegate.confirm_update_query(window, cx) {
self.set_query(update_query, window, cx);
- self.delegate.set_selected_index(0, window, cx);
+ self.set_selected_index(0, Some(Direction::Down), false, window, cx);
} else {
self.delegate.confirm(secondary, window, cx)
}
diff --git a/crates/prompt_library/src/prompt_library.rs b/crates/prompt_library/src/prompt_library.rs
index c2c1f3da60..7fff6d1258 100644
--- a/crates/prompt_library/src/prompt_library.rs
+++ b/crates/prompt_library/src/prompt_library.rs
@@ -657,7 +657,7 @@ impl PromptLibrary {
.iter()
.position(|mat| mat.id == prompt_id)
{
- picker.set_selected_index(ix, true, window, cx);
+ picker.set_selected_index(ix, None, true, window, cx);
}
}
} else {