use crate::{ LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, }; use collections::BTreeMap; use gpui::{prelude::*, App, Context, Entity, EventEmitter, Global}; use std::sync::Arc; pub fn init(cx: &mut App) { let registry = cx.new(|_cx| LanguageModelRegistry::default()); cx.set_global(GlobalLanguageModelRegistry(registry)); } struct GlobalLanguageModelRegistry(Entity); impl Global for GlobalLanguageModelRegistry {} #[derive(Default)] pub struct LanguageModelRegistry { active_model: Option, editor_model: Option, providers: BTreeMap>, inline_alternatives: Vec>, } pub struct ActiveModel { provider: Arc, model: Option>, } pub enum Event { ActiveModelChanged, EditorModelChanged, ProviderStateChanged, AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), } impl EventEmitter for LanguageModelRegistry {} impl LanguageModelRegistry { pub fn global(cx: &App) -> Entity { cx.global::().0.clone() } pub fn read_global(cx: &App) -> &Self { cx.global::().0.read(cx) } #[cfg(any(test, feature = "test-support"))] pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider { let fake_provider = crate::fake_provider::FakeLanguageModelProvider; let registry = cx.new(|cx| { let mut registry = Self::default(); registry.register_provider(fake_provider.clone(), cx); let model = fake_provider.provided_models(cx)[0].clone(); registry.set_active_model(Some(model), cx); registry }); cx.set_global(GlobalLanguageModelRegistry(registry)); fake_provider } pub fn register_provider( &mut self, provider: T, cx: &mut Context, ) { let id = provider.id(); let subscription = provider.subscribe(cx, |_, cx| { cx.emit(Event::ProviderStateChanged); }); if let Some(subscription) = subscription { subscription.detach(); } self.providers.insert(id.clone(), Arc::new(provider)); cx.emit(Event::AddedProvider(id)); } pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context) { if self.providers.remove(&id).is_some() { cx.emit(Event::RemovedProvider(id)); } } pub fn providers(&self) -> Vec> { let zed_provider_id = LanguageModelProviderId("zed.dev".into()); let mut providers = Vec::with_capacity(self.providers.len()); if let Some(provider) = self.providers.get(&zed_provider_id) { providers.push(provider.clone()); } providers.extend(self.providers.values().filter_map(|p| { if p.id() != zed_provider_id { Some(p.clone()) } else { None } })); providers } pub fn available_models<'a>( &'a self, cx: &'a App, ) -> impl Iterator> + 'a { self.providers .values() .flat_map(|provider| provider.provided_models(cx)) } pub fn provider(&self, id: &LanguageModelProviderId) -> Option> { self.providers.get(id).cloned() } pub fn select_active_model( &mut self, provider: &LanguageModelProviderId, model_id: &LanguageModelId, cx: &mut Context, ) { let Some(provider) = self.provider(provider) else { return; }; let models = provider.provided_models(cx); if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { self.set_active_model(Some(model), cx); } } pub fn select_editor_model( &mut self, provider: &LanguageModelProviderId, model_id: &LanguageModelId, cx: &mut Context, ) { let Some(provider) = self.provider(provider) else { return; }; let models = provider.provided_models(cx); if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { self.set_editor_model(Some(model), cx); } } pub fn set_active_provider( &mut self, provider: Option>, cx: &mut Context, ) { self.active_model = provider.map(|provider| ActiveModel { provider, model: None, }); cx.emit(Event::ActiveModelChanged); } pub fn set_active_model( &mut self, model: Option>, cx: &mut Context, ) { if let Some(model) = model { let provider_id = model.provider_id(); if let Some(provider) = self.providers.get(&provider_id).cloned() { self.active_model = Some(ActiveModel { provider, model: Some(model), }); cx.emit(Event::ActiveModelChanged); } else { log::warn!("Active model's provider not found in registry"); } } else { self.active_model = None; cx.emit(Event::ActiveModelChanged); } } pub fn set_editor_model( &mut self, model: Option>, cx: &mut Context, ) { if let Some(model) = model { let provider_id = model.provider_id(); if let Some(provider) = self.providers.get(&provider_id).cloned() { self.editor_model = Some(ActiveModel { provider, model: Some(model), }); cx.emit(Event::EditorModelChanged); } else { log::warn!("Active model's provider not found in registry"); } } else { self.editor_model = None; cx.emit(Event::EditorModelChanged); } } pub fn active_provider(&self) -> Option> { Some(self.active_model.as_ref()?.provider.clone()) } pub fn active_model(&self) -> Option> { self.active_model.as_ref()?.model.clone() } pub fn editor_model(&self) -> Option> { self.editor_model.as_ref()?.model.clone() } /// Selects and sets the inline alternatives for language models based on /// provider name and id. pub fn select_inline_alternative_models( &mut self, alternatives: impl IntoIterator, cx: &mut Context, ) { let mut selected_alternatives = Vec::new(); for (provider_id, model_id) in alternatives { if let Some(provider) = self.providers.get(&provider_id) { if let Some(model) = provider .provided_models(cx) .iter() .find(|m| m.id() == model_id) { selected_alternatives.push(model.clone()); } } } self.inline_alternatives = selected_alternatives; } /// The models to use for inline assists. Returns the union of the active /// model and all inline alternatives. When there are multiple models, the /// user will be able to cycle through results. pub fn inline_alternative_models(&self) -> &[Arc] { &self.inline_alternatives } } #[cfg(test)] mod tests { use super::*; use crate::fake_provider::FakeLanguageModelProvider; #[gpui::test] fn test_register_providers(cx: &mut App) { let registry = cx.new(|_| LanguageModelRegistry::default()); registry.update(cx, |registry, cx| { registry.register_provider(FakeLanguageModelProvider, cx); }); let providers = registry.read(cx).providers(); assert_eq!(providers.len(), 1); assert_eq!(providers[0].id(), crate::fake_provider::provider_id()); registry.update(cx, |registry, cx| { registry.unregister_provider(crate::fake_provider::provider_id(), cx); }); let providers = registry.read(cx).providers(); assert!(providers.is_empty()); } }