assistant: Fix issues when configuring different providers (#15072)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Bennet Bo Fenner 2024-07-24 11:21:31 +02:00 committed by GitHub
parent ba6c36f370
commit af4b9805c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 225 additions and 148 deletions

View file

@ -9,7 +9,7 @@ use crate::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
};
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
@ -48,7 +48,7 @@ fn register_language_model_providers(
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
} else {
registry.unregister_provider(
&LanguageModelProviderName::from(
&LanguageModelProviderId::from(
crate::provider::cloud::PROVIDER_NAME.to_string(),
),
cx,
@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
providers: HashMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
impl LanguageModelRegistry {
@ -94,7 +94,7 @@ impl LanguageModelRegistry {
provider: T,
cx: &mut ModelContext<Self>,
) {
let name = provider.name();
let name = provider.id();
if let Some(subscription) = provider.subscribe(cx) {
subscription.detach();
@ -106,7 +106,7 @@ impl LanguageModelRegistry {
pub fn unregister_provider(
&mut self,
name: &LanguageModelProviderName,
name: &LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
if self.providers.remove(name).is_some() {
@ -116,7 +116,7 @@ impl LanguageModelRegistry {
pub fn providers(
&self,
) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
) -> impl Iterator<Item = (&LanguageModelProviderId, &Arc<dyn LanguageModelProvider>)> {
self.providers.iter()
}
@ -130,7 +130,7 @@ impl LanguageModelRegistry {
pub fn available_models_grouped_by_provider(
&self,
cx: &AppContext,
) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
) -> HashMap<LanguageModelProviderId, Vec<Arc<dyn LanguageModel>>> {
self.providers
.iter()
.map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
@ -139,7 +139,7 @@ impl LanguageModelRegistry {
pub fn provider(
&self,
name: &LanguageModelProviderName,
name: &LanguageModelProviderId,
) -> Option<Arc<dyn LanguageModelProvider>> {
self.providers.get(name).cloned()
}
@ -160,10 +160,10 @@ mod tests {
let providers = registry.read(cx).providers().collect::<Vec<_>>();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
assert_eq!(providers[0].0, &crate::provider::fake::provider_id());
registry.update(cx, |registry, cx| {
registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
});
let providers = registry.read(cx).providers().collect::<Vec<_>>();