From b22faf96e00297f43e362e5e274b4d2e817a9682 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Fri, 11 Apr 2025 17:02:50 -0600 Subject: [PATCH] agent: Refine language model selector (#28597) Release Notes: - agent: Show recommended models in the agent model selector and display the provider in the model selector's trigger. --------- Co-authored-by: Danilo Leal Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com> --- Cargo.lock | 1 + assets/icons/ai_anthropic_hosted.svg | 12 - crates/agent/src/assistant_model_selector.rs | 19 +- crates/file_finder/src/file_finder_tests.rs | 14 +- crates/icons/src/icons.rs | 1 - crates/language_model/src/language_model.rs | 7 +- .../language_model/src/model/cloud_model.rs | 8 - crates/language_model_selector/Cargo.toml | 1 + .../src/language_model_selector.rs | 351 ++++++++++-------- .../language_models/src/provider/anthropic.rs | 30 +- crates/language_models/src/provider/cloud.rs | 39 +- crates/picker/src/picker.rs | 83 ++++- crates/prompt_library/src/prompt_library.rs | 2 +- 13 files changed, 350 insertions(+), 218 deletions(-) delete mode 100644 assets/icons/ai_anthropic_hosted.svg diff --git a/Cargo.lock b/Cargo.lock index fc906cf259..9af249804c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7657,6 +7657,7 @@ dependencies = [ name = "language_model_selector" version = "0.1.0" dependencies = [ + "collections", "feature_flags", "gpui", "language_model", diff --git a/assets/icons/ai_anthropic_hosted.svg b/assets/icons/ai_anthropic_hosted.svg deleted file mode 100644 index b088520490..0000000000 --- a/assets/icons/ai_anthropic_hosted.svg +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - - - - - - - diff --git a/crates/agent/src/assistant_model_selector.rs b/crates/agent/src/assistant_model_selector.rs index 11726b2574..091071af29 100644 --- a/crates/agent/src/assistant_model_selector.rs +++ b/crates/agent/src/assistant_model_selector.rs @@ -80,17 +80,16 @@ impl AssistantModelSelector { impl Render for AssistantModelSelector { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let model_registry = LanguageModelRegistry::read_global(cx); + let focus_handle = self.focus_handle.clone(); + let model_registry = LanguageModelRegistry::read_global(cx); let model = match self.model_type { ModelType::Default => model_registry.default_model(), ModelType::InlineAssistant => model_registry.inline_assistant_model(), }; - - let focus_handle = self.focus_handle.clone(); - let model_name = match model { - Some(model) => model.model.name().0, - _ => SharedString::from("No model selected"), + let (model_name, model_icon) = match model { + Some(model) => (model.model.name().0, Some(model.provider.icon())), + _ => (SharedString::from("No model selected"), None), }; LanguageModelSelectorPopoverMenu::new( @@ -100,10 +99,16 @@ impl Render for AssistantModelSelector { .child( h_flex() .gap_0p5() + .children( + model_icon.map(|icon| { + Icon::new(icon).color(Color::Muted).size(IconSize::Small) + }), + ) .child( Label::new(model_name) .size(LabelSize::Small) - .color(Color::Muted), + .color(Color::Muted) + .ml_1(), ) .child( Icon::new(IconName::ChevronDown) diff --git a/crates/file_finder/src/file_finder_tests.rs b/crates/file_finder/src/file_finder_tests.rs index d5d3582858..d2a5f1402d 100644 --- a/crates/file_finder/src/file_finder_tests.rs +++ b/crates/file_finder/src/file_finder_tests.rs @@ -2133,18 +2133,28 @@ async fn test_repeat_toggle_action(cx: &mut gpui::TestAppContext) { cx.dispatch_action(ToggleFileFinder::default()); let picker = active_file_picker(&workspace, cx); + + picker.update_in(cx, |picker, window, cx| { + picker.update_matches(".txt".to_string(), window, cx) + }); + + cx.run_until_parked(); + picker.update(cx, |picker, _| { + assert_eq!(picker.delegate.matches.len(), 6); assert_eq!(picker.delegate.selected_index, 0); - assert_eq!(picker.logical_scroll_top_index(), 0); }); // When toggling repeatedly, the picker scrolls to reveal the selected item. cx.dispatch_action(ToggleFileFinder::default()); cx.dispatch_action(ToggleFileFinder::default()); cx.dispatch_action(ToggleFileFinder::default()); + + cx.run_until_parked(); + picker.update(cx, |picker, _| { + assert_eq!(picker.delegate.matches.len(), 6); assert_eq!(picker.delegate.selected_index, 3); - assert_eq!(picker.logical_scroll_top_index(), 3); }); } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 6c448c03ed..d7f4a820da 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -10,7 +10,6 @@ use strum::{EnumIter, EnumString, IntoStaticStr}; pub enum IconName { Ai, AiAnthropic, - AiAnthropicHosted, AiBedrock, AiDeepSeek, AiEdit, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index aa060f7b30..98456e7db4 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -174,10 +174,6 @@ impl Default for LanguageModelTextStream { pub trait LanguageModel: Send + Sync { fn id(&self) -> LanguageModelId; fn name(&self) -> LanguageModelName; - /// If None, falls back to [LanguageModelProvider::icon] - fn icon(&self) -> Option { - None - } fn provider_id(&self) -> LanguageModelProviderId; fn provider_name(&self) -> LanguageModelProviderName; fn telemetry_id(&self) -> String; @@ -304,6 +300,9 @@ pub trait LanguageModelProvider: 'static { } fn default_model(&self, cx: &App) -> Option>; fn provided_models(&self, cx: &App) -> Vec>; + fn recommended_models(&self, _cx: &App) -> Vec> { + Vec::new() + } fn load_model(&self, _model: Arc, _cx: &App) {} fn is_authenticated(&self, cx: &App) -> bool; fn authenticate(&self, cx: &mut App) -> Task>; diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index e5c66670d8..cc15ce3364 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -6,7 +6,6 @@ use client::Client; use gpui::{ App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, }; -use icons::IconName; use proto::{Plan, TypedEnvelope}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -53,13 +52,6 @@ impl CloudModel { } } - pub fn icon(&self) -> Option { - match self { - Self::Anthropic(_) => Some(IconName::AiAnthropicHosted), - _ => None, - } - } - pub fn max_token_count(&self) -> usize { match self { Self::Anthropic(model) => model.max_token_count(), diff --git a/crates/language_model_selector/Cargo.toml b/crates/language_model_selector/Cargo.toml index 1257ae564c..39bc8a59f9 100644 --- a/crates/language_model_selector/Cargo.toml +++ b/crates/language_model_selector/Cargo.toml @@ -12,6 +12,7 @@ workspace = true path = "src/language_model_selector.rs" [dependencies] +collections.workspace = true feature_flags.workspace = true gpui.workspace = true language_model.workspace = true diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs index 90747a01f3..7f18b4d9fd 100644 --- a/crates/language_model_selector/src/language_model_selector.rs +++ b/crates/language_model_selector/src/language_model_selector.rs @@ -1,12 +1,13 @@ use std::sync::Arc; +use collections::{HashSet, IndexMap}; use feature_flags::{Assistant2FeatureFlag, ZedPro}; use gpui::{ Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases, }; use language_model::{ - AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry, + AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, }; use picker::{Picker, PickerDelegate}; use proto::Plan; @@ -24,9 +25,6 @@ type OnModelChanged = Arc, &App) + 'static>; pub struct LanguageModelSelector { picker: Entity>, - /// The task used to update the picker's matches when there is a change to - /// the language model registry. - update_matches_task: Option>, _authenticate_all_providers_task: Task<()>, _subscriptions: Vec, } @@ -40,16 +38,18 @@ impl LanguageModelSelector { let on_model_changed = Arc::new(on_model_changed); let all_models = Self::all_models(cx); + let entries = all_models.entries(); + let delegate = LanguageModelPickerDelegate { language_model_selector: cx.entity().downgrade(), on_model_changed: on_model_changed.clone(), - all_models: all_models.clone(), - filtered_models: all_models, - selected_index: Self::get_active_model_index(cx), + all_models: Arc::new(all_models), + selected_index: Self::get_active_model_index(&entries, cx), + filtered_entries: entries, }; let picker = cx.new(|cx| { - Picker::uniform_list(delegate, window, cx) + Picker::list(delegate, window, cx) .show_scrollbar(true) .width(rems(20.)) .max_height(Some(rems(20.).into())) @@ -59,7 +59,6 @@ impl LanguageModelSelector { LanguageModelSelector { picker, - update_matches_task: None, _authenticate_all_providers_task: Self::authenticate_all_providers(cx), _subscriptions: vec![ cx.subscribe_in( @@ -83,12 +82,13 @@ impl LanguageModelSelector { language_model::Event::ProviderStateChanged | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { - let task = self.picker.update(cx, |this, cx| { + self.picker.update(cx, |this, cx| { let query = this.query(cx); - this.delegate.all_models = Self::all_models(cx); - this.delegate.update_matches(query, window, cx) + this.delegate.all_models = Arc::new(Self::all_models(cx)); + // Update matches will automatically drop the previous task + // if we get a provider event again + this.update_matches(query, window, cx) }); - self.update_matches_task = Some(task); } _ => {} } @@ -144,34 +144,72 @@ impl LanguageModelSelector { }) } - fn all_models(cx: &App) -> Vec { - LanguageModelRegistry::global(cx) + fn all_models(cx: &App) -> GroupedModels { + let mut recommended = Vec::new(); + let mut recommended_set = HashSet::default(); + for provider in LanguageModelRegistry::global(cx) .read(cx) .providers() .iter() - .flat_map(|provider| { - let icon = provider.icon(); - - provider.provided_models(cx).into_iter().map(move |model| { - let model = model.clone(); - let icon = model.icon().unwrap_or(icon); - - ModelInfo { + { + let models = provider.recommended_models(cx); + recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id()))); + recommended.extend( + provider + .recommended_models(cx) + .into_iter() + .map(move |model| ModelInfo { model: model.clone(), - icon, - availability: model.availability(), - } - }) + icon: provider.icon(), + }), + ); + } + + let other_models = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .iter() + .map(|provider| { + ( + provider.id(), + provider + .provided_models(cx) + .into_iter() + .filter_map(|model| { + let not_included = + !recommended_set.contains(&(model.provider_id(), model.id())); + not_included.then(|| ModelInfo { + model: model.clone(), + icon: provider.icon(), + }) + }) + .collect::>(), + ) }) - .collect::>() + .collect::>(); + + GroupedModels { + recommended, + other: other_models, + } } - fn get_active_model_index(cx: &App) -> usize { + fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize { let active_model = LanguageModelRegistry::read_global(cx).default_model(); - Self::all_models(cx) + entries .iter() - .position(|model_info| { - Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id()) + .position(|entry| { + if let LanguageModelPickerEntry::Model(model) = entry { + active_model + .as_ref() + .map(|active_model| { + active_model.model.id() == model.model.id() + && active_model.model.provider_id() == model.model.provider_id() + }) + .unwrap_or_default() + } else { + false + } }) .unwrap_or(0) } @@ -254,22 +292,61 @@ where struct ModelInfo { model: Arc, icon: IconName, - availability: LanguageModelAvailability, } pub struct LanguageModelPickerDelegate { language_model_selector: WeakEntity, on_model_changed: OnModelChanged, - all_models: Vec, - filtered_models: Vec, + all_models: Arc, + filtered_entries: Vec, selected_index: usize, } +struct GroupedModels { + recommended: Vec, + other: IndexMap>, +} + +impl GroupedModels { + fn entries(&self) -> Vec { + let mut entries = Vec::new(); + + if !self.recommended.is_empty() { + entries.push(LanguageModelPickerEntry::Separator("Recommended".into())); + entries.extend( + self.recommended + .iter() + .map(|info| LanguageModelPickerEntry::Model(info.clone())), + ); + } + + for models in self.other.values() { + if models.is_empty() { + continue; + } + entries.push(LanguageModelPickerEntry::Separator( + models[0].model.provider_name().0, + )); + entries.extend( + models + .iter() + .map(|info| LanguageModelPickerEntry::Model(info.clone())), + ); + } + entries + } +} + +enum LanguageModelPickerEntry { + Model(ModelInfo), + Separator(SharedString), +} + impl PickerDelegate for LanguageModelPickerDelegate { - type ListItem = ListItem; + type ListItem = AnyElement; fn match_count(&self) -> usize { - self.filtered_models.len() + self.filtered_entries.len() } fn selected_index(&self) -> usize { @@ -277,12 +354,24 @@ impl PickerDelegate for LanguageModelPickerDelegate { } fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context>) { - self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1)); + self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1)); cx.notify(); } + fn can_select( + &mut self, + ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) -> bool { + match self.filtered_entries.get(ix) { + Some(LanguageModelPickerEntry::Model(_)) => true, + Some(LanguageModelPickerEntry::Separator(_)) | None => false, + } + } + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { - "Select a model...".into() + "Select a model…".into() } fn update_matches( @@ -307,22 +396,9 @@ impl PickerDelegate for LanguageModelPickerDelegate { cx.spawn_in(window, async move |this, cx| { let filtered_models = cx .background_spawn(async move { - let displayed_models = if configured_providers.is_empty() { - all_models - } else { - all_models - .into_iter() - .filter(|model_info| { - configured_providers.contains(&model_info.model.provider_id()) - }) - .collect::>() - }; - - if query.is_empty() { - displayed_models - } else { - displayed_models - .into_iter() + let filter_models = |model_infos: &[ModelInfo]| { + model_infos + .iter() .filter(|model_info| { model_info .model @@ -331,20 +407,33 @@ impl PickerDelegate for LanguageModelPickerDelegate { .to_lowercase() .contains(&query.to_lowercase()) }) - .collect() + .cloned() + .collect::>() + }; + + let recommended_models = filter_models(&all_models.recommended); + let mut other_models = IndexMap::default(); + for (provider_id, models) in &all_models.other { + if configured_providers.contains(&provider_id) { + other_models.insert(provider_id.clone(), filter_models(models)); + } + } + GroupedModels { + recommended: recommended_models, + other: other_models, } }) .await; this.update_in(cx, |this, window, cx| { - this.delegate.filtered_models = filtered_models; + this.delegate.filtered_entries = filtered_models.entries(); // Preserve selection focus - let new_index = if current_index >= this.delegate.filtered_models.len() { + let new_index = if current_index >= this.delegate.filtered_entries.len() { 0 } else { current_index }; - this.delegate.set_selected_index(new_index, window, cx); + this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx); cx.notify(); }) .ok(); @@ -352,7 +441,9 @@ impl PickerDelegate for LanguageModelPickerDelegate { } fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { - if let Some(model_info) = self.filtered_models.get(self.selected_index) { + if let Some(LanguageModelPickerEntry::Model(model_info)) = + self.filtered_entries.get(self.selected_index) + { let model = model_info.model.clone(); (self.on_model_changed)(model.clone(), cx); @@ -369,29 +460,6 @@ impl PickerDelegate for LanguageModelPickerDelegate { .ok(); } - fn render_header(&self, _: &mut Window, cx: &mut Context>) -> Option { - let configured_models_count = LanguageModelRegistry::global(cx) - .read(cx) - .providers() - .iter() - .filter(|provider| provider.is_authenticated(cx)) - .count(); - - if configured_models_count > 0 { - Some( - Label::new("Configured Models") - .size(LabelSize::Small) - .color(Color::Muted) - .mt_1() - .mb_0p5() - .ml_2() - .into_any_element(), - ) - } else { - None - } - } - fn render_match( &self, ix: usize, @@ -399,77 +467,68 @@ impl PickerDelegate for LanguageModelPickerDelegate { _: &mut Window, cx: &mut Context>, ) -> Option { - use feature_flags::FeatureFlagAppExt; - let show_badges = cx.has_flag::(); + match self.filtered_entries.get(ix)? { + LanguageModelPickerEntry::Separator(title) => Some( + div() + .px_2() + .pb_1() + .when(ix > 1, |this| { + this.mt_1() + .pt_2() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + }) + .child( + Label::new(title) + .size(LabelSize::XSmall) + .color(Color::Muted), + ) + .into_any_element(), + ), + LanguageModelPickerEntry::Model(model_info) => { + let active_model = LanguageModelRegistry::read_global(cx).default_model(); - let model_info = self.filtered_models.get(ix)?; - let provider_name: String = model_info.model.provider_name().0.clone().into(); + let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); + let active_model_id = active_model.map(|m| m.model.id()); - let active_model = LanguageModelRegistry::read_global(cx).default_model(); + let is_selected = Some(model_info.model.provider_id()) == active_provider_id + && Some(model_info.model.id()) == active_model_id; - let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); - let active_model_id = active_model.map(|m| m.model.id()); + let model_icon_color = if is_selected { + Color::Accent + } else { + Color::Muted + }; - let is_selected = Some(model_info.model.provider_id()) == active_provider_id - && Some(model_info.model.id()) == active_model_id; - - let model_icon_color = if is_selected { - Color::Accent - } else { - Color::Muted - }; - - Some( - ListItem::new(ix) - .inset(true) - .spacing(ListItemSpacing::Sparse) - .toggle_state(selected) - .start_slot( - Icon::new(model_info.icon) - .color(model_icon_color) - .size(IconSize::Small), - ) - .child( - h_flex() - .w_full() - .items_center() - .gap_1p5() - .pl_0p5() - .w(px(240.)) - .child( - div() - .max_w_40() - .child(Label::new(model_info.model.name().0.clone()).truncate()), + Some( + ListItem::new(ix) + .inset(true) + .spacing(ListItemSpacing::Sparse) + .toggle_state(selected) + .start_slot( + Icon::new(model_info.icon) + .color(model_icon_color) + .size(IconSize::Small), ) .child( h_flex() - .gap_0p5() - .child( - Label::new(provider_name) - .size(LabelSize::XSmall) - .color(Color::Muted), - ) - .children(match model_info.availability { - LanguageModelAvailability::Public => None, - LanguageModelAvailability::RequiresPlan(Plan::Free) => None, - LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => { - show_badges.then(|| { - Label::new("Pro") - .size(LabelSize::XSmall) - .color(Color::Muted) - }) - } - }), - ), + .w_full() + .pl_0p5() + .gap_1p5() + .w(px(240.)) + .child(Label::new(model_info.model.name().0.clone()).truncate()), + ) + .end_slot(div().pr_3().when(is_selected, |this| { + this.child( + Icon::new(IconName::Check) + .color(Color::Accent) + .size(IconSize::Small), + ) + })) + .into_any_element(), ) - .end_slot(div().pr_3().when(is_selected, |this| { - this.child( - Icon::new(IconName::Check) - .color(Color::Accent) - .size(IconSize::Small), - ) - })), - ) + } + } } fn render_footer( diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index bce985a872..4540a08268 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -192,6 +192,16 @@ impl AnthropicLanguageModelProvider { Self { http_client, state } } + + fn create_language_model(&self, model: anthropic::Model) -> Arc { + Arc::new(AnthropicModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) as Arc + } } impl LanguageModelProviderState for AnthropicLanguageModelProvider { @@ -226,6 +236,16 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { })) } + fn recommended_models(&self, _cx: &App) -> Vec> { + [ + anthropic::Model::Claude3_7Sonnet, + anthropic::Model::Claude3_7SonnetThinking, + ] + .into_iter() + .map(|model| self.create_language_model(model)) + .collect() + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); @@ -266,15 +286,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { models .into_values() - .map(|model| { - Arc::new(AnthropicModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model| self.create_language_model(model)) .collect() } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 9377bf315f..6a08f48522 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -225,6 +225,20 @@ impl CloudLanguageModelProvider { _maintain_client_status: maintain_client_status, } } + + fn create_language_model( + &self, + model: CloudModel, + llm_api_token: LlmApiToken, + ) -> Arc { + Arc::new(CloudLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + llm_api_token: llm_api_token.clone(), + client: self.client.clone(), + request_limiter: RateLimiter::new(4), + }) as Arc + } } impl LanguageModelProviderState for CloudLanguageModelProvider { @@ -260,6 +274,17 @@ impl LanguageModelProvider for CloudLanguageModelProvider { })) } + fn recommended_models(&self, cx: &App) -> Vec> { + let llm_api_token = self.state.read(cx).llm_api_token.clone(); + [ + CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet), + CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking), + ] + .into_iter() + .map(|model| self.create_language_model(model, llm_api_token.clone())) + .collect() + } + fn provided_models(&self, cx: &App) -> Vec> { let mut models = BTreeMap::default(); @@ -345,15 +370,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider { let llm_api_token = self.state.read(cx).llm_api_token.clone(); models .into_values() - .map(|model| { - Arc::new(CloudLanguageModel { - id: LanguageModelId::from(model.id().to_string()), - model, - llm_api_token: llm_api_token.clone(), - client: self.client.clone(), - request_limiter: RateLimiter::new(4), - }) as Arc - }) + .map(|model| self.create_language_model(model, llm_api_token.clone())) .collect() } @@ -575,10 +592,6 @@ impl LanguageModel for CloudLanguageModel { LanguageModelName::from(self.model.display_name().to_string()) } - fn icon(&self) -> Option { - self.model.icon() - } - fn provider_id(&self) -> LanguageModelProviderId { LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into()) } diff --git a/crates/picker/src/picker.rs b/crates/picker/src/picker.rs index 2caa9ff756..54b50453ce 100644 --- a/crates/picker/src/picker.rs +++ b/crates/picker/src/picker.rs @@ -3,8 +3,8 @@ use editor::{Editor, scroll::Autoscroll}; use gpui::{ AnyElement, App, ClickEvent, Context, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Length, ListSizingBehavior, ListState, MouseButton, MouseUpEvent, Render, - ScrollHandle, ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div, - impl_actions, list, prelude::*, uniform_list, + ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div, impl_actions, + list, prelude::*, uniform_list, }; use head::Head; use schemars::JsonSchema; @@ -24,6 +24,11 @@ enum ElementContainer { UniformList(UniformListScrollHandle), } +pub enum Direction { + Up, + Down, +} + actions!(picker, [ConfirmCompletion]); /// ConfirmInput is an alternative editor action which - instead of selecting active picker entry - treats pickers editor input literally, @@ -86,6 +91,15 @@ pub trait PickerDelegate: Sized + 'static { window: &mut Window, cx: &mut Context>, ); + fn can_select( + &mut self, + _ix: usize, + _window: &mut Window, + _cx: &mut Context>, + ) -> bool { + true + } + // Allows binding some optional effect to when the selection changes. fn selected_index_changed( &self, @@ -271,10 +285,7 @@ impl Picker { ElementContainer::UniformList(scroll_handle) => { ScrollbarState::new(scroll_handle.clone()) } - ElementContainer::List(_) => { - // todo smit: implement for list - ScrollbarState::new(ScrollHandle::new()) - } + ElementContainer::List(state) => ScrollbarState::new(state.clone()), }; let focus_handle = cx.focus_handle(); let mut this = Self { @@ -359,16 +370,58 @@ impl Picker { } /// Handles the selecting an index, and passing the change to the delegate. - /// If `scroll_to_index` is true, the new selected index will be scrolled into view. + /// If `fallback_direction` is set to `None`, the index will not be selected + /// if the element at that index cannot be selected. + /// If `fallback_direction` is set to + /// `Some(..)`, the next selectable element will be selected in the + /// specified direction (Down or Up), cycling through all elements until + /// finding one that can be selected or returning if there are no selectable elements. + /// If `scroll_to_index` is true, the new selected index will be scrolled into + /// view. /// /// If some effect is bound to `selected_index_changed`, it will be executed. pub fn set_selected_index( &mut self, - ix: usize, + mut ix: usize, + fallback_direction: Option, scroll_to_index: bool, window: &mut Window, cx: &mut Context, ) { + let match_count = self.delegate.match_count(); + if match_count == 0 { + return; + } + + if let Some(bias) = fallback_direction { + let mut curr_ix = ix; + while !self.delegate.can_select(curr_ix, window, cx) { + curr_ix = match bias { + Direction::Down => { + if curr_ix == match_count - 1 { + 0 + } else { + curr_ix + 1 + } + } + Direction::Up => { + if curr_ix == 0 { + match_count - 1 + } else { + curr_ix - 1 + } + } + }; + // There is no item that can be selected + if ix == curr_ix { + return; + } + } + ix = curr_ix; + } else if !self.delegate.can_select(ix, window, cx) { + return; + } + let previous_index = self.delegate.selected_index(); self.delegate.set_selected_index(ix, window, cx); let current_index = self.delegate.selected_index(); @@ -393,7 +446,7 @@ impl Picker { if count > 0 { let index = self.delegate.selected_index(); let ix = if index == count - 1 { 0 } else { index + 1 }; - self.set_selected_index(ix, true, window, cx); + self.set_selected_index(ix, Some(Direction::Down), true, window, cx); cx.notify(); } } @@ -408,7 +461,7 @@ impl Picker { if count > 0 { let index = self.delegate.selected_index(); let ix = if index == 0 { count - 1 } else { index - 1 }; - self.set_selected_index(ix, true, window, cx); + self.set_selected_index(ix, Some(Direction::Up), true, window, cx); cx.notify(); } } @@ -416,7 +469,7 @@ impl Picker { fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context) { let count = self.delegate.match_count(); if count > 0 { - self.set_selected_index(0, true, window, cx); + self.set_selected_index(0, Some(Direction::Down), true, window, cx); cx.notify(); } } @@ -424,7 +477,7 @@ impl Picker { fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context) { let count = self.delegate.match_count(); if count > 0 { - self.set_selected_index(count - 1, true, window, cx); + self.set_selected_index(count - 1, Some(Direction::Up), true, window, cx); cx.notify(); } } @@ -433,7 +486,7 @@ impl Picker { let count = self.delegate.match_count(); let index = self.delegate.selected_index(); let new_index = if index + 1 == count { 0 } else { index + 1 }; - self.set_selected_index(new_index, true, window, cx); + self.set_selected_index(new_index, Some(Direction::Down), true, window, cx); cx.notify(); } @@ -506,14 +559,14 @@ impl Picker { ) { cx.stop_propagation(); window.prevent_default(); - self.set_selected_index(ix, false, window, cx); + self.set_selected_index(ix, None, false, window, cx); self.do_confirm(secondary, window, cx) } fn do_confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context) { if let Some(update_query) = self.delegate.confirm_update_query(window, cx) { self.set_query(update_query, window, cx); - self.delegate.set_selected_index(0, window, cx); + self.set_selected_index(0, Some(Direction::Down), false, window, cx); } else { self.delegate.confirm(secondary, window, cx) } diff --git a/crates/prompt_library/src/prompt_library.rs b/crates/prompt_library/src/prompt_library.rs index c2c1f3da60..7fff6d1258 100644 --- a/crates/prompt_library/src/prompt_library.rs +++ b/crates/prompt_library/src/prompt_library.rs @@ -657,7 +657,7 @@ impl PromptLibrary { .iter() .position(|mat| mat.id == prompt_id) { - picker.set_selected_index(ix, true, window, cx); + picker.set_selected_index(ix, None, true, window, cx); } } } else {