Fix model deduplication to use provider ID and model ID (#31750)

Previously only used model ID for deduplication, which incorrectly
filtered models with the same name from different providers.

Release Notes:

- Fix to make sure all provider models are shown in the model picker
This commit is contained in:
Ben Brandt 2025-05-30 15:49:09 +02:00 committed by GitHub
parent 310ea43048
commit 97c01c6720
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -46,53 +46,35 @@ pub fn language_model_selector(
} }
fn all_models(cx: &App) -> GroupedModels { fn all_models(cx: &App) -> GroupedModels {
let mut recommended = Vec::new(); let providers = LanguageModelRegistry::global(cx).read(cx).providers();
let mut recommended_set = HashSet::default();
for provider in LanguageModelRegistry::global(cx) let recommended = providers
.read(cx)
.providers()
.iter() .iter()
{ .flat_map(|provider| {
let models = provider.recommended_models(cx);
recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
recommended.extend(
provider provider
.recommended_models(cx) .recommended_models(cx)
.into_iter() .into_iter()
.map(move |model| ModelInfo { .map(|model| ModelInfo {
model: model.clone(), model,
icon: provider.icon(), icon: provider.icon(),
}), })
); })
} .collect();
let other_models = LanguageModelRegistry::global(cx) let other = providers
.read(cx)
.providers()
.iter() .iter()
.map(|provider| { .flat_map(|provider| {
(
provider.id(),
provider provider
.provided_models(cx) .provided_models(cx)
.into_iter() .into_iter()
.filter_map(|model| { .map(|model| ModelInfo {
let not_included = model,
!recommended_set.contains(&(model.provider_id(), model.id()));
not_included.then(|| ModelInfo {
model: model.clone(),
icon: provider.icon(), icon: provider.icon(),
}) })
}) })
.collect::<Vec<_>>(), .collect();
)
})
.collect::<IndexMap<_, _>>();
GroupedModels { GroupedModels::new(other, recommended)
recommended,
other: other_models,
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -234,11 +216,14 @@ struct GroupedModels {
impl GroupedModels { impl GroupedModels {
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self { pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect(); let recommended_ids = recommended
.iter()
.map(|info| (info.model.provider_id(), info.model.id()))
.collect::<HashSet<_>>();
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default(); let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
for model in other { for model in other {
if recommended_ids.contains(&model.model.id()) { if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
continue; continue;
} }
@ -823,4 +808,26 @@ mod tests {
// Recommended models should not appear in "other" // Recommended models should not appear in "other"
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]); assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
} }
#[gpui::test]
fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
let recommended_models = create_models(vec![("zed", "claude")]);
let all_models = create_models(vec![
("zed", "claude"), // Should be filtered out from "other"
("zed", "gemini"),
("copilot", "claude"), // Should not be filtered out from "other"
]);
let grouped_models = GroupedModels::new(all_models, recommended_models);
let actual_other_models = grouped_models
.other
.values()
.flatten()
.cloned()
.collect::<Vec<_>>();
// Recommended models should not appear in "other"
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
}
} }