Fix model deduplication to use provider ID and model ID (#31750)
Previously only used model ID for deduplication, which incorrectly filtered models with the same name from different providers. Release Notes: - Fix to make sure all provider models are shown in the model picker
This commit is contained in:
parent
310ea43048
commit
97c01c6720
1 changed files with 49 additions and 42 deletions
|
@ -46,53 +46,35 @@ pub fn language_model_selector(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn all_models(cx: &App) -> GroupedModels {
|
fn all_models(cx: &App) -> GroupedModels {
|
||||||
let mut recommended = Vec::new();
|
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
|
||||||
let mut recommended_set = HashSet::default();
|
|
||||||
for provider in LanguageModelRegistry::global(cx)
|
let recommended = providers
|
||||||
.read(cx)
|
|
||||||
.providers()
|
|
||||||
.iter()
|
.iter()
|
||||||
{
|
.flat_map(|provider| {
|
||||||
let models = provider.recommended_models(cx);
|
|
||||||
recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
|
|
||||||
recommended.extend(
|
|
||||||
provider
|
provider
|
||||||
.recommended_models(cx)
|
.recommended_models(cx)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(move |model| ModelInfo {
|
.map(|model| ModelInfo {
|
||||||
model: model.clone(),
|
model,
|
||||||
icon: provider.icon(),
|
icon: provider.icon(),
|
||||||
}),
|
})
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let other_models = LanguageModelRegistry::global(cx)
|
|
||||||
.read(cx)
|
|
||||||
.providers()
|
|
||||||
.iter()
|
|
||||||
.map(|provider| {
|
|
||||||
(
|
|
||||||
provider.id(),
|
|
||||||
provider
|
|
||||||
.provided_models(cx)
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|model| {
|
|
||||||
let not_included =
|
|
||||||
!recommended_set.contains(&(model.provider_id(), model.id()));
|
|
||||||
not_included.then(|| ModelInfo {
|
|
||||||
model: model.clone(),
|
|
||||||
icon: provider.icon(),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
.collect::<IndexMap<_, _>>();
|
.collect();
|
||||||
|
|
||||||
GroupedModels {
|
let other = providers
|
||||||
recommended,
|
.iter()
|
||||||
other: other_models,
|
.flat_map(|provider| {
|
||||||
}
|
provider
|
||||||
|
.provided_models(cx)
|
||||||
|
.into_iter()
|
||||||
|
.map(|model| ModelInfo {
|
||||||
|
model,
|
||||||
|
icon: provider.icon(),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
GroupedModels::new(other, recommended)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
@ -234,11 +216,14 @@ struct GroupedModels {
|
||||||
|
|
||||||
impl GroupedModels {
|
impl GroupedModels {
|
||||||
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
|
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
|
||||||
let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect();
|
let recommended_ids = recommended
|
||||||
|
.iter()
|
||||||
|
.map(|info| (info.model.provider_id(), info.model.id()))
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
|
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
|
||||||
for model in other {
|
for model in other {
|
||||||
if recommended_ids.contains(&model.model.id()) {
|
if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -823,4 +808,26 @@ mod tests {
|
||||||
// Recommended models should not appear in "other"
|
// Recommended models should not appear in "other"
|
||||||
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
|
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
|
||||||
|
let recommended_models = create_models(vec![("zed", "claude")]);
|
||||||
|
let all_models = create_models(vec![
|
||||||
|
("zed", "claude"), // Should be filtered out from "other"
|
||||||
|
("zed", "gemini"),
|
||||||
|
("copilot", "claude"), // Should not be filtered out from "other"
|
||||||
|
]);
|
||||||
|
|
||||||
|
let grouped_models = GroupedModels::new(all_models, recommended_models);
|
||||||
|
|
||||||
|
let actual_other_models = grouped_models
|
||||||
|
.other
|
||||||
|
.values()
|
||||||
|
.flatten()
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Recommended models should not appear in "other"
|
||||||
|
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue