diff --git a/crates/assistant_context_editor/src/language_model_selector.rs b/crates/assistant_context_editor/src/language_model_selector.rs index ee29c2eb5e..049c4d24bd 100644 --- a/crates/assistant_context_editor/src/language_model_selector.rs +++ b/crates/assistant_context_editor/src/language_model_selector.rs @@ -46,53 +46,35 @@ pub fn language_model_selector( } fn all_models(cx: &App) -> GroupedModels { - let mut recommended = Vec::new(); - let mut recommended_set = HashSet::default(); - for provider in LanguageModelRegistry::global(cx) - .read(cx) - .providers() + let providers = LanguageModelRegistry::global(cx).read(cx).providers(); + + let recommended = providers .iter() - { - let models = provider.recommended_models(cx); - recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id()))); - recommended.extend( + .flat_map(|provider| { provider .recommended_models(cx) .into_iter() - .map(move |model| ModelInfo { - model: model.clone(), + .map(|model| ModelInfo { + model, icon: provider.icon(), - }), - ); - } - - let other_models = LanguageModelRegistry::global(cx) - .read(cx) - .providers() - .iter() - .map(|provider| { - ( - provider.id(), - provider - .provided_models(cx) - .into_iter() - .filter_map(|model| { - let not_included = - !recommended_set.contains(&(model.provider_id(), model.id())); - not_included.then(|| ModelInfo { - model: model.clone(), - icon: provider.icon(), - }) - }) - .collect::>(), - ) + }) }) - .collect::>(); + .collect(); - GroupedModels { - recommended, - other: other_models, - } + let other = providers + .iter() + .flat_map(|provider| { + provider + .provided_models(cx) + .into_iter() + .map(|model| ModelInfo { + model, + icon: provider.icon(), + }) + }) + .collect(); + + GroupedModels::new(other, recommended) } #[derive(Clone)] @@ -234,11 +216,14 @@ struct GroupedModels { impl GroupedModels { pub fn new(other: Vec, recommended: Vec) -> Self { - let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect(); + let recommended_ids = recommended + .iter() + .map(|info| (info.model.provider_id(), info.model.id())) + .collect::>(); let mut other_by_provider: IndexMap<_, Vec> = IndexMap::default(); for model in other { - if recommended_ids.contains(&model.model.id()) { + if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) { continue; } @@ -823,4 +808,26 @@ mod tests { // Recommended models should not appear in "other" assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]); } + + #[gpui::test] + fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) { + let recommended_models = create_models(vec![("zed", "claude")]); + let all_models = create_models(vec![ + ("zed", "claude"), // Should be filtered out from "other" + ("zed", "gemini"), + ("copilot", "claude"), // Should not be filtered out from "other" + ]); + + let grouped_models = GroupedModels::new(all_models, recommended_models); + + let actual_other_models = grouped_models + .other + .values() + .flatten() + .cloned() + .collect::>(); + + // Recommended models should not appear in "other" + assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]); + } }