diff --git a/Cargo.lock b/Cargo.lock index 53a0da2a81..a29906beb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7813,9 +7813,12 @@ version = "0.1.0" dependencies = [ "collections", "feature_flags", + "futures 0.3.31", + "fuzzy", "gpui", "language_model", "log", + "ordered-float 2.10.1", "picker", "proto", "ui", diff --git a/crates/language_model_selector/Cargo.toml b/crates/language_model_selector/Cargo.toml index 39bc8a59f9..0237fe530b 100644 --- a/crates/language_model_selector/Cargo.toml +++ b/crates/language_model_selector/Cargo.toml @@ -11,14 +11,26 @@ workspace = true [lib] path = "src/language_model_selector.rs" +[features] +test-support = [ + "gpui/test-support", +] + [dependencies] collections.workspace = true feature_flags.workspace = true +futures.workspace = true +fuzzy.workspace = true gpui.workspace = true language_model.workspace = true log.workspace = true +ordered-float.workspace = true picker.workspace = true proto.workspace = true ui.workspace = true workspace-hack.workspace = true zed_actions.workspace = true + +[dev-dependencies] +gpui = { workspace = true, "features" = ["test-support"] } +language_model = { workspace = true, "features" = ["test-support"] } diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs index a6a25ae426..e1dbb1cc42 100644 --- a/crates/language_model_selector/src/language_model_selector.rs +++ b/crates/language_model_selector/src/language_model_selector.rs @@ -1,15 +1,18 @@ -use std::sync::Arc; +use std::{cmp::Reverse, sync::Arc}; use collections::{HashSet, IndexMap}; use feature_flags::ZedProFeatureFlag; +use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; use gpui::{ - Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle, - Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases, + Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity, + EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, + action_with_deprecated_aliases, }; use language_model::{ AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, }; +use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; use proto::Plan; use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*}; @@ -322,6 +325,23 @@ struct GroupedModels { } impl GroupedModels { + pub fn new(other: Vec, recommended: Vec) -> Self { + let mut other_by_provider: IndexMap<_, Vec> = IndexMap::default(); + for model in other { + let provider = model.model.provider_id(); + if let Some(models) = other_by_provider.get_mut(&provider) { + models.push(model); + } else { + other_by_provider.insert(provider, vec![model]); + } + } + + Self { + recommended, + other: other_by_provider, + } + } + fn entries(&self) -> Vec { let mut entries = Vec::new(); @@ -349,6 +369,20 @@ impl GroupedModels { } entries } + + fn model_infos(&self) -> Vec { + let other = self + .other + .values() + .flat_map(|model| model.iter()) + .cloned() + .collect::>(); + self.recommended + .iter() + .chain(&other) + .cloned() + .collect::>() + } } enum LanguageModelPickerEntry { @@ -356,6 +390,78 @@ enum LanguageModelPickerEntry { Separator(SharedString), } +struct ModelMatcher { + models: Vec, + bg_executor: BackgroundExecutor, + candidates: Vec, +} + +impl ModelMatcher { + fn new(models: Vec, bg_executor: BackgroundExecutor) -> ModelMatcher { + let candidates = Self::make_match_candidates(&models); + Self { + models, + bg_executor, + candidates, + } + } + + pub fn fuzzy_search(&self, query: &str) -> Vec { + let mut matches = self.bg_executor.block(match_strings( + &self.candidates, + &query, + false, + 100, + &Default::default(), + self.bg_executor.clone(), + )); + + let sorting_key = |mat: &StringMatch| { + let candidate = &self.candidates[mat.candidate_id]; + (Reverse(OrderedFloat(mat.score)), candidate.id) + }; + matches.sort_unstable_by_key(sorting_key); + + let matched_models: Vec<_> = matches + .into_iter() + .map(|mat| self.models[mat.candidate_id].clone()) + .collect(); + + matched_models + } + + pub fn exact_search(&self, query: &str) -> Vec { + self.models + .iter() + .filter(|m| { + m.model + .name() + .0 + .to_lowercase() + .contains(&query.to_lowercase()) + }) + .cloned() + .collect::>() + } + + fn make_match_candidates(model_infos: &Vec) -> Vec { + model_infos + .iter() + .enumerate() + .map(|(index, model)| { + StringMatchCandidate::new( + index, + &format!( + "{}/{}", + &model.model.provider_name().0, + &model.model.name().0 + ), + ) + }) + .collect::>() + } +} + impl PickerDelegate for LanguageModelPickerDelegate { type ListItem = AnyElement; @@ -396,56 +502,45 @@ impl PickerDelegate for LanguageModelPickerDelegate { ) -> Task<()> { let all_models = self.all_models.clone(); let current_index = self.selected_index; + let bg_executor = cx.background_executor(); let language_model_registry = LanguageModelRegistry::global(cx); let configured_providers = language_model_registry .read(cx) .providers() - .iter() + .into_iter() .filter(|provider| provider.is_authenticated(cx)) + .collect::>(); + + let configured_provider_ids = configured_providers + .iter() .map(|provider| provider.id()) .collect::>(); + let recommended_models = all_models + .recommended + .iter() + .filter(|m| configured_provider_ids.contains(&m.model.provider_id())) + .cloned() + .collect::>(); + + let available_models = all_models + .model_infos() + .iter() + .filter(|m| configured_provider_ids.contains(&m.model.provider_id())) + .cloned() + .collect::>(); + + let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone()); + let matcher_all = ModelMatcher::new(available_models, bg_executor.clone()); + + let recommended = matcher_rec.exact_search(&query); + let all = matcher_all.fuzzy_search(&query); + + let filtered_models = GroupedModels::new(all, recommended); + cx.spawn_in(window, async move |this, cx| { - let filtered_models = cx - .background_spawn(async move { - let matches = |info: &ModelInfo| { - info.model - .name() - .0 - .to_lowercase() - .contains(&query.to_lowercase()) - }; - - let recommended_models = all_models - .recommended - .iter() - .filter(|r| { - configured_providers.contains(&r.model.provider_id()) && matches(r) - }) - .cloned() - .collect(); - let mut other_models = IndexMap::default(); - for (provider_id, models) in &all_models.other { - if configured_providers.contains(&provider_id) { - other_models.insert( - provider_id.clone(), - models - .iter() - .filter(|m| matches(m)) - .cloned() - .collect::>(), - ); - } - } - GroupedModels { - recommended: recommended_models, - other: other_models, - } - }) - .await; - this.update_in(cx, |this, window, cx| { this.delegate.filtered_entries = filtered_models.entries(); // Preserve selection focus @@ -607,3 +702,187 @@ impl PickerDelegate for LanguageModelPickerDelegate { ) } } + +#[cfg(test)] +mod tests { + use super::*; + use futures::{future::BoxFuture, stream::BoxStream}; + use gpui::{AsyncApp, TestAppContext, http_client}; + use language_model::{ + LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, + LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelRequest, LanguageModelToolChoice, + }; + use ui::IconName; + + #[derive(Clone)] + struct TestLanguageModel { + name: LanguageModelName, + id: LanguageModelId, + provider_id: LanguageModelProviderId, + provider_name: LanguageModelProviderName, + } + + impl TestLanguageModel { + fn new(name: &str, provider: &str) -> Self { + Self { + name: LanguageModelName::from(name.to_string()), + id: LanguageModelId::from(name.to_string()), + provider_id: LanguageModelProviderId::from(provider.to_string()), + provider_name: LanguageModelProviderName::from(provider.to_string()), + } + } + } + + impl LanguageModel for TestLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + self.name.clone() + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.provider_id.clone() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.provider_name.clone() + } + + fn supports_tools(&self) -> bool { + false + } + + fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { + false + } + + fn telemetry_id(&self) -> String { + format!("{}/{}", self.provider_id.0, self.name.0) + } + + fn max_token_count(&self) -> usize { + 1000 + } + + fn count_tokens( + &self, + _: LanguageModelRequest, + _: &App, + ) -> BoxFuture<'static, http_client::Result> { + unimplemented!() + } + + fn stream_completion( + &self, + _: LanguageModelRequest, + _: &AsyncApp, + ) -> BoxFuture< + 'static, + http_client::Result< + BoxStream< + 'static, + http_client::Result, + >, + >, + > { + unimplemented!() + } + } + + fn create_models(model_specs: Vec<(&str, &str)>) -> Vec { + model_specs + .into_iter() + .map(|(provider, name)| ModelInfo { + model: Arc::new(TestLanguageModel::new(name, provider)), + icon: IconName::Ai, + }) + .collect() + } + + fn assert_models_eq(result: Vec, expected: Vec<&str>) { + assert_eq!( + result.len(), + expected.len(), + "Number of models doesn't match" + ); + + for (i, expected_name) in expected.iter().enumerate() { + assert_eq!( + result[i].model.telemetry_id(), + *expected_name, + "Model at position {} doesn't match expected model", + i + ); + } + } + + #[gpui::test] + fn test_exact_match(cx: &mut TestAppContext) { + let models = create_models(vec![ + ("zed", "Claude 3.7 Sonnet"), + ("zed", "Claude 3.7 Sonnet Thinking"), + ("zed", "gpt-4.1"), + ("zed", "gpt-4.1-nano"), + ("openai", "gpt-3.5-turbo"), + ("openai", "gpt-4.1"), + ("openai", "gpt-4.1-nano"), + ("ollama", "mistral"), + ("ollama", "deepseek"), + ]); + let matcher = ModelMatcher::new(models, cx.background_executor.clone()); + + // The order of models should be maintained, case doesn't matter + let results = matcher.exact_search("GPT-4.1"); + assert_models_eq( + results, + vec![ + "zed/gpt-4.1", + "zed/gpt-4.1-nano", + "openai/gpt-4.1", + "openai/gpt-4.1-nano", + ], + ); + } + + #[gpui::test] + fn test_fuzzy_match(cx: &mut TestAppContext) { + let models = create_models(vec![ + ("zed", "Claude 3.7 Sonnet"), + ("zed", "Claude 3.7 Sonnet Thinking"), + ("zed", "gpt-4.1"), + ("zed", "gpt-4.1-nano"), + ("openai", "gpt-3.5-turbo"), + ("openai", "gpt-4.1"), + ("openai", "gpt-4.1-nano"), + ("ollama", "mistral"), + ("ollama", "deepseek"), + ]); + let matcher = ModelMatcher::new(models, cx.background_executor.clone()); + + // Results should preserve models order whenever possible. + // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical + // similarity scores, but `zed/gpt-4.1` was higher in the models list, + // so it should appear first in the results. + let results = matcher.fuzzy_search("41"); + assert_models_eq( + results, + vec![ + "zed/gpt-4.1", + "openai/gpt-4.1", + "zed/gpt-4.1-nano", + "openai/gpt-4.1-nano", + ], + ); + + // Model provider should be searchable as well + let results = matcher.fuzzy_search("ol"); // meaning "ollama" + assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]); + + // Fuzzy search + let results = matcher.fuzzy_search("z4n"); + assert_models_eq(results, vec!["zed/gpt-4.1-nano"]); + } +}