diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index f499c4c8f5..727f51ea55 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -60,6 +60,7 @@ impl RenderOnce for ModelSelector { for (index, provider) in LanguageModelRegistry::global(cx) .read(cx) .providers() + .into_iter() .enumerate() { if index > 0 { diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 8bda65b07a..a3af7e6b18 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -132,8 +132,20 @@ impl LanguageModelRegistry { } } - pub fn providers(&self) -> impl Iterator> { - self.providers.values() + pub fn providers(&self) -> Vec> { + let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into()); + let mut providers = Vec::with_capacity(self.providers.len()); + if let Some(provider) = self.providers.get(&zed_provider_id) { + providers.push(provider.clone()); + } + providers.extend(self.providers.values().filter_map(|p| { + if p.id() != zed_provider_id { + Some(p.clone()) + } else { + None + } + })); + providers } pub fn available_models(&self, cx: &AppContext) -> Vec> { @@ -222,7 +234,7 @@ mod tests { registry.register_provider(FakeLanguageModelProvider::default(), cx); }); - let providers = registry.read(cx).providers().collect::>(); + let providers = registry.read(cx).providers(); assert_eq!(providers.len(), 1); assert_eq!(providers[0].id(), crate::provider::fake::provider_id()); @@ -230,7 +242,7 @@ mod tests { registry.unregister_provider(&crate::provider::fake::provider_id(), cx); }); - let providers = registry.read(cx).providers().collect::>(); + let providers = registry.read(cx).providers(); assert!(providers.is_empty()); } }