assistant: Fix issues when configuring different providers (#15072)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
ba6c36f370
commit
af4b9805c9
16 changed files with 225 additions and 148 deletions
|
|
@ -9,7 +9,7 @@ use crate::{
|
|||
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
|
||||
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
|
||||
},
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
|
||||
};
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
|
||||
|
|
@ -48,7 +48,7 @@ fn register_language_model_providers(
|
|||
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
|
||||
} else {
|
||||
registry.unregister_provider(
|
||||
&LanguageModelProviderName::from(
|
||||
&LanguageModelProviderId::from(
|
||||
crate::provider::cloud::PROVIDER_NAME.to_string(),
|
||||
),
|
||||
cx,
|
||||
|
|
@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {}
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct LanguageModelRegistry {
|
||||
providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
|
||||
providers: HashMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
|
||||
}
|
||||
|
||||
impl LanguageModelRegistry {
|
||||
|
|
@ -94,7 +94,7 @@ impl LanguageModelRegistry {
|
|||
provider: T,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
let name = provider.name();
|
||||
let name = provider.id();
|
||||
|
||||
if let Some(subscription) = provider.subscribe(cx) {
|
||||
subscription.detach();
|
||||
|
|
@ -106,7 +106,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
pub fn unregister_provider(
|
||||
&mut self,
|
||||
name: &LanguageModelProviderName,
|
||||
name: &LanguageModelProviderId,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if self.providers.remove(name).is_some() {
|
||||
|
|
@ -116,7 +116,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
pub fn providers(
|
||||
&self,
|
||||
) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
|
||||
) -> impl Iterator<Item = (&LanguageModelProviderId, &Arc<dyn LanguageModelProvider>)> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
|
|
@ -130,7 +130,7 @@ impl LanguageModelRegistry {
|
|||
pub fn available_models_grouped_by_provider(
|
||||
&self,
|
||||
cx: &AppContext,
|
||||
) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
|
||||
) -> HashMap<LanguageModelProviderId, Vec<Arc<dyn LanguageModel>>> {
|
||||
self.providers
|
||||
.iter()
|
||||
.map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
|
||||
|
|
@ -139,7 +139,7 @@ impl LanguageModelRegistry {
|
|||
|
||||
pub fn provider(
|
||||
&self,
|
||||
name: &LanguageModelProviderName,
|
||||
name: &LanguageModelProviderId,
|
||||
) -> Option<Arc<dyn LanguageModelProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
}
|
||||
|
|
@ -160,10 +160,10 @@ mod tests {
|
|||
|
||||
let providers = registry.read(cx).providers().collect::<Vec<_>>();
|
||||
assert_eq!(providers.len(), 1);
|
||||
assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
|
||||
assert_eq!(providers[0].0, &crate::provider::fake::provider_id());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
|
||||
registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
|
||||
});
|
||||
|
||||
let providers = registry.read(cx).providers().collect::<Vec<_>>();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue