use std::{cmp::Reverse, sync::Arc}; use cloud_llm_client::Plan; use collections::{HashSet, IndexMap}; use feature_flags::ZedProFeatureFlag; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task}; use language_model::{ ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; use ui::{ListItem, ListItemSpacing, prelude::*}; const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro"; type OnModelChanged = Arc, &mut App) + 'static>; type GetActiveModel = Arc Option + 'static>; pub type LanguageModelSelector = Picker; pub fn language_model_selector( get_active_model: impl Fn(&App) -> Option + 'static, on_model_changed: impl Fn(Arc, &mut App) + 'static, window: &mut Window, cx: &mut Context, ) -> LanguageModelSelector { let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx); Picker::list(delegate, window, cx) .show_scrollbar(true) .width(rems(20.)) .max_height(Some(rems(20.).into())) } fn all_models(cx: &App) -> GroupedModels { let providers = LanguageModelRegistry::global(cx).read(cx).providers(); let recommended = providers .iter() .flat_map(|provider| { provider .recommended_models(cx) .into_iter() .map(|model| ModelInfo { model, icon: provider.icon(), }) }) .collect(); let other = providers .iter() .flat_map(|provider| { provider .provided_models(cx) .into_iter() .map(|model| ModelInfo { model, icon: provider.icon(), }) }) .collect(); GroupedModels::new(other, recommended) } #[derive(Clone)] struct ModelInfo { model: Arc, icon: IconName, } pub struct LanguageModelPickerDelegate { on_model_changed: OnModelChanged, get_active_model: GetActiveModel, all_models: Arc, filtered_entries: Vec, selected_index: usize, _subscriptions: Vec, } impl LanguageModelPickerDelegate { fn new( get_active_model: impl Fn(&App) -> Option + 'static, on_model_changed: impl Fn(Arc, &mut App) + 'static, window: &mut Window, cx: &mut Context>, ) -> Self { let on_model_changed = Arc::new(on_model_changed); let models = all_models(cx); let entries = models.entries(); Self { on_model_changed, all_models: Arc::new(models), selected_index: Self::get_active_model_index(&entries, get_active_model(cx)), filtered_entries: entries, get_active_model: Arc::new(get_active_model), _subscriptions: vec![cx.subscribe_in( &LanguageModelRegistry::global(cx), window, |picker, _, event, window, cx| { match event { language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { let query = picker.query(cx); picker.delegate.all_models = Arc::new(all_models(cx)); // Update matches will automatically drop the previous task // if we get a provider event again picker.update_matches(query, window, cx) } _ => {} } }, )], } } fn get_active_model_index( entries: &[LanguageModelPickerEntry], active_model: Option, ) -> usize { entries .iter() .position(|entry| { if let LanguageModelPickerEntry::Model(model) = entry { active_model .as_ref() .map(|active_model| { active_model.model.id() == model.model.id() && active_model.provider.id() == model.model.provider_id() }) .unwrap_or_default() } else { false } }) .unwrap_or(0) } pub fn active_model(&self, cx: &App) -> Option { (self.get_active_model)(cx) } } struct GroupedModels { recommended: Vec, other: IndexMap>, } impl GroupedModels { pub fn new(other: Vec, recommended: Vec) -> Self { let recommended_ids = recommended .iter() .map(|info| (info.model.provider_id(), info.model.id())) .collect::>(); let mut other_by_provider: IndexMap<_, Vec> = IndexMap::default(); for model in other { if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) { continue; } let provider = model.model.provider_id(); if let Some(models) = other_by_provider.get_mut(&provider) { models.push(model); } else { other_by_provider.insert(provider, vec![model]); } } Self { recommended, other: other_by_provider, } } fn entries(&self) -> Vec { let mut entries = Vec::new(); if !self.recommended.is_empty() { entries.push(LanguageModelPickerEntry::Separator("Recommended".into())); entries.extend( self.recommended .iter() .map(|info| LanguageModelPickerEntry::Model(info.clone())), ); } for models in self.other.values() { if models.is_empty() { continue; } entries.push(LanguageModelPickerEntry::Separator( models[0].model.provider_name().0, )); entries.extend( models .iter() .map(|info| LanguageModelPickerEntry::Model(info.clone())), ); } entries } fn model_infos(&self) -> Vec { let other = self .other .values() .flat_map(|model| model.iter()) .cloned() .collect::>(); self.recommended .iter() .chain(&other) .cloned() .collect::>() } } enum LanguageModelPickerEntry { Model(ModelInfo), Separator(SharedString), } struct ModelMatcher { models: Vec, bg_executor: BackgroundExecutor, candidates: Vec, } impl ModelMatcher { fn new(models: Vec, bg_executor: BackgroundExecutor) -> ModelMatcher { let candidates = Self::make_match_candidates(&models); Self { models, bg_executor, candidates, } } pub fn fuzzy_search(&self, query: &str) -> Vec { let mut matches = self.bg_executor.block(match_strings( &self.candidates, query, false, true, 100, &Default::default(), self.bg_executor.clone(), )); let sorting_key = |mat: &StringMatch| { let candidate = &self.candidates[mat.candidate_id]; (Reverse(OrderedFloat(mat.score)), candidate.id) }; matches.sort_unstable_by_key(sorting_key); let matched_models: Vec<_> = matches .into_iter() .map(|mat| self.models[mat.candidate_id].clone()) .collect(); matched_models } pub fn exact_search(&self, query: &str) -> Vec { self.models .iter() .filter(|m| { m.model .name() .0 .to_lowercase() .contains(&query.to_lowercase()) }) .cloned() .collect::>() } fn make_match_candidates(model_infos: &Vec) -> Vec { model_infos .iter() .enumerate() .map(|(index, model)| { StringMatchCandidate::new( index, &format!( "{}/{}", &model.model.provider_name().0, &model.model.name().0 ), ) }) .collect::>() } } impl PickerDelegate for LanguageModelPickerDelegate { type ListItem = AnyElement; fn match_count(&self) -> usize { self.filtered_entries.len() } fn selected_index(&self) -> usize { self.selected_index } fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context>) { self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1)); cx.notify(); } fn can_select( &mut self, ix: usize, _window: &mut Window, _cx: &mut Context>, ) -> bool { match self.filtered_entries.get(ix) { Some(LanguageModelPickerEntry::Model(_)) => true, Some(LanguageModelPickerEntry::Separator(_)) | None => false, } } fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { "Select a model…".into() } fn update_matches( &mut self, query: String, window: &mut Window, cx: &mut Context>, ) -> Task<()> { let all_models = self.all_models.clone(); let active_model = (self.get_active_model)(cx); let bg_executor = cx.background_executor(); let language_model_registry = LanguageModelRegistry::global(cx); let configured_providers = language_model_registry .read(cx) .providers() .into_iter() .filter(|provider| provider.is_authenticated(cx)) .collect::>(); let configured_provider_ids = configured_providers .iter() .map(|provider| provider.id()) .collect::>(); let recommended_models = all_models .recommended .iter() .filter(|m| configured_provider_ids.contains(&m.model.provider_id())) .cloned() .collect::>(); let available_models = all_models .model_infos() .iter() .filter(|m| configured_provider_ids.contains(&m.model.provider_id())) .cloned() .collect::>(); let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone()); let matcher_all = ModelMatcher::new(available_models, bg_executor.clone()); let recommended = matcher_rec.exact_search(&query); let all = matcher_all.fuzzy_search(&query); let filtered_models = GroupedModels::new(all, recommended); cx.spawn_in(window, async move |this, cx| { this.update_in(cx, |this, window, cx| { this.delegate.filtered_entries = filtered_models.entries(); // Finds the currently selected model in the list let new_index = Self::get_active_model_index(&this.delegate.filtered_entries, active_model); this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx); cx.notify(); }) .ok(); }) } fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { if let Some(LanguageModelPickerEntry::Model(model_info)) = self.filtered_entries.get(self.selected_index) { let model = model_info.model.clone(); (self.on_model_changed)(model.clone(), cx); let current_index = self.selected_index; self.set_selected_index(current_index, window, cx); cx.emit(DismissEvent); } } fn dismissed(&mut self, _: &mut Window, cx: &mut Context>) { cx.emit(DismissEvent); } fn render_match( &self, ix: usize, selected: bool, _: &mut Window, cx: &mut Context>, ) -> Option { match self.filtered_entries.get(ix)? { LanguageModelPickerEntry::Separator(title) => Some( div() .px_2() .pb_1() .when(ix > 1, |this| { this.mt_1() .pt_2() .border_t_1() .border_color(cx.theme().colors().border_variant) }) .child( Label::new(title) .size(LabelSize::XSmall) .color(Color::Muted), ) .into_any_element(), ), LanguageModelPickerEntry::Model(model_info) => { let active_model = (self.get_active_model)(cx); let active_provider_id = active_model.as_ref().map(|m| m.provider.id()); let active_model_id = active_model.map(|m| m.model.id()); let is_selected = Some(model_info.model.provider_id()) == active_provider_id && Some(model_info.model.id()) == active_model_id; let model_icon_color = if is_selected { Color::Accent } else { Color::Muted }; Some( ListItem::new(ix) .inset(true) .spacing(ListItemSpacing::Sparse) .toggle_state(selected) .start_slot( Icon::new(model_info.icon) .color(model_icon_color) .size(IconSize::Small), ) .child( h_flex() .w_full() .pl_0p5() .gap_1p5() .w(px(240.)) .child(Label::new(model_info.model.name().0).truncate()), ) .end_slot(div().pr_3().when(is_selected, |this| { this.child( Icon::new(IconName::Check) .color(Color::Accent) .size(IconSize::Small), ) })) .into_any_element(), ) } } } fn render_footer( &self, _: &mut Window, cx: &mut Context>, ) -> Option { use feature_flags::FeatureFlagAppExt; let plan = Plan::ZedPro; Some( h_flex() .w_full() .border_t_1() .border_color(cx.theme().colors().border_variant) .p_1() .gap_4() .justify_between() .when(cx.has_flag::(), |this| { this.child(match plan { Plan::ZedPro => Button::new("zed-pro", "Zed Pro") .icon(IconName::ZedAssistant) .icon_size(IconSize::Small) .icon_color(Color::Muted) .icon_position(IconPosition::Start) .on_click(|_, window, cx| { window .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx) }), Plan::ZedFree | Plan::ZedProTrial => Button::new( "try-pro", if plan == Plan::ZedProTrial { "Upgrade to Pro" } else { "Try Pro" }, ) .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)), }) }) .child( Button::new("configure", "Configure") .icon(IconName::Settings) .icon_size(IconSize::Small) .icon_color(Color::Muted) .icon_position(IconPosition::Start) .on_click(|_, window, cx| { window.dispatch_action( zed_actions::agent::OpenSettings.boxed_clone(), cx, ); }), ) .into_any(), ) } } #[cfg(test)] mod tests { use super::*; use futures::{future::BoxFuture, stream::BoxStream}; use gpui::{AsyncApp, TestAppContext, http_client}; use language_model::{ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, LanguageModelToolChoice, }; use ui::IconName; #[derive(Clone)] struct TestLanguageModel { name: LanguageModelName, id: LanguageModelId, provider_id: LanguageModelProviderId, provider_name: LanguageModelProviderName, } impl TestLanguageModel { fn new(name: &str, provider: &str) -> Self { Self { name: LanguageModelName::from(name.to_string()), id: LanguageModelId::from(name.to_string()), provider_id: LanguageModelProviderId::from(provider.to_string()), provider_name: LanguageModelProviderName::from(provider.to_string()), } } } impl LanguageModel for TestLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() } fn name(&self) -> LanguageModelName { self.name.clone() } fn provider_id(&self) -> LanguageModelProviderId { self.provider_id.clone() } fn provider_name(&self) -> LanguageModelProviderName { self.provider_name.clone() } fn supports_tools(&self) -> bool { false } fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { false } fn supports_images(&self) -> bool { false } fn telemetry_id(&self) -> String { format!("{}/{}", self.provider_id.0, self.name.0) } fn max_token_count(&self) -> u64 { 1000 } fn count_tokens( &self, _: LanguageModelRequest, _: &App, ) -> BoxFuture<'static, http_client::Result> { unimplemented!() } fn stream_completion( &self, _: LanguageModelRequest, _: &AsyncApp, ) -> BoxFuture< 'static, Result< BoxStream< 'static, Result, >, LanguageModelCompletionError, >, > { unimplemented!() } } fn create_models(model_specs: Vec<(&str, &str)>) -> Vec { model_specs .into_iter() .map(|(provider, name)| ModelInfo { model: Arc::new(TestLanguageModel::new(name, provider)), icon: IconName::Ai, }) .collect() } fn assert_models_eq(result: Vec, expected: Vec<&str>) { assert_eq!( result.len(), expected.len(), "Number of models doesn't match" ); for (i, expected_name) in expected.iter().enumerate() { assert_eq!( result[i].model.telemetry_id(), *expected_name, "Model at position {} doesn't match expected model", i ); } } #[gpui::test] fn test_exact_match(cx: &mut TestAppContext) { let models = create_models(vec![ ("zed", "Claude 3.7 Sonnet"), ("zed", "Claude 3.7 Sonnet Thinking"), ("zed", "gpt-4.1"), ("zed", "gpt-4.1-nano"), ("openai", "gpt-3.5-turbo"), ("openai", "gpt-4.1"), ("openai", "gpt-4.1-nano"), ("ollama", "mistral"), ("ollama", "deepseek"), ]); let matcher = ModelMatcher::new(models, cx.background_executor.clone()); // The order of models should be maintained, case doesn't matter let results = matcher.exact_search("GPT-4.1"); assert_models_eq( results, vec![ "zed/gpt-4.1", "zed/gpt-4.1-nano", "openai/gpt-4.1", "openai/gpt-4.1-nano", ], ); } #[gpui::test] fn test_fuzzy_match(cx: &mut TestAppContext) { let models = create_models(vec![ ("zed", "Claude 3.7 Sonnet"), ("zed", "Claude 3.7 Sonnet Thinking"), ("zed", "gpt-4.1"), ("zed", "gpt-4.1-nano"), ("openai", "gpt-3.5-turbo"), ("openai", "gpt-4.1"), ("openai", "gpt-4.1-nano"), ("ollama", "mistral"), ("ollama", "deepseek"), ]); let matcher = ModelMatcher::new(models, cx.background_executor.clone()); // Results should preserve models order whenever possible. // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical // similarity scores, but `zed/gpt-4.1` was higher in the models list, // so it should appear first in the results. let results = matcher.fuzzy_search("41"); assert_models_eq( results, vec![ "zed/gpt-4.1", "openai/gpt-4.1", "zed/gpt-4.1-nano", "openai/gpt-4.1-nano", ], ); // Model provider should be searchable as well let results = matcher.fuzzy_search("ol"); // meaning "ollama" assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]); // Fuzzy search let results = matcher.fuzzy_search("z4n"); assert_models_eq(results, vec!["zed/gpt-4.1-nano"]); } #[gpui::test] fn test_exclude_recommended_models(_cx: &mut TestAppContext) { let recommended_models = create_models(vec![("zed", "claude")]); let all_models = create_models(vec![ ("zed", "claude"), // Should be filtered out from "other" ("zed", "gemini"), ("copilot", "o3"), ]); let grouped_models = GroupedModels::new(all_models, recommended_models); let actual_other_models = grouped_models .other .values() .flatten() .cloned() .collect::>(); // Recommended models should not appear in "other" assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]); } #[gpui::test] fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) { let recommended_models = create_models(vec![("zed", "claude")]); let all_models = create_models(vec![ ("zed", "claude"), // Should be filtered out from "other" ("zed", "gemini"), ("copilot", "claude"), // Should not be filtered out from "other" ]); let grouped_models = GroupedModels::new(all_models, recommended_models); let actual_other_models = grouped_models .other .values() .flatten() .cloned() .collect::>(); // Recommended models should not appear in "other" assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]); } }