use crate::{ LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, }; use collections::BTreeMap; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; use std::{str::FromStr, sync::Arc}; use thiserror::Error; use util::maybe; pub fn init(cx: &mut App) { let registry = cx.new(|_cx| LanguageModelRegistry::default()); cx.set_global(GlobalLanguageModelRegistry(registry)); } struct GlobalLanguageModelRegistry(Entity); impl Global for GlobalLanguageModelRegistry {} #[derive(Error)] pub enum ConfigurationError { #[error("Configure at least one LLM provider to start using the panel.")] NoProvider, #[error("LLM provider is not configured or does not support the configured model.")] ModelNotFound, #[error("{} LLM provider is not configured.", .0.name().0)] ProviderNotAuthenticated(Arc), } impl std::fmt::Debug for ConfigurationError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::NoProvider => write!(f, "NoProvider"), Self::ModelNotFound => write!(f, "ModelNotFound"), Self::ProviderNotAuthenticated(provider) => { write!(f, "ProviderNotAuthenticated({})", provider.id()) } } } } #[derive(Default)] pub struct LanguageModelRegistry { default_model: Option, default_fast_model: Option, inline_assistant_model: Option, commit_message_model: Option, thread_summary_model: Option, providers: BTreeMap>, inline_alternatives: Vec>, } #[derive(Debug)] pub struct SelectedModel { pub provider: LanguageModelProviderId, pub model: LanguageModelId, } impl FromStr for SelectedModel { type Err = String; /// Parse string identifiers like `provider_id/model_id` into a `SelectedModel` fn from_str(id: &str) -> Result { let parts: Vec<&str> = id.split('/').collect(); let [provider_id, model_id] = parts.as_slice() else { return Err(format!( "Invalid model identifier format: `{}`. Expected `provider_id/model_id`", id )); }; if provider_id.is_empty() || model_id.is_empty() { return Err(format!("Provider and model ids can't be empty: `{}`", id)); } Ok(SelectedModel { provider: LanguageModelProviderId(provider_id.to_string().into()), model: LanguageModelId(model_id.to_string().into()), }) } } #[derive(Clone)] pub struct ConfiguredModel { pub provider: Arc, pub model: Arc, } impl ConfiguredModel { pub fn is_same_as(&self, other: &ConfiguredModel) -> bool { self.model.id() == other.model.id() && self.provider.id() == other.provider.id() } pub fn is_provided_by_zed(&self) -> bool { self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID } } pub enum Event { DefaultModelChanged, InlineAssistantModelChanged, CommitMessageModelChanged, ThreadSummaryModelChanged, ProviderStateChanged(LanguageModelProviderId), AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), } impl EventEmitter for LanguageModelRegistry {} impl LanguageModelRegistry { pub fn global(cx: &App) -> Entity { cx.global::().0.clone() } pub fn read_global(cx: &App) -> &Self { cx.global::().0.read(cx) } #[cfg(any(test, feature = "test-support"))] pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider { let fake_provider = crate::fake_provider::FakeLanguageModelProvider::default(); let registry = cx.new(|cx| { let mut registry = Self::default(); registry.register_provider(fake_provider.clone(), cx); let model = fake_provider.provided_models(cx)[0].clone(); let configured_model = ConfiguredModel { provider: Arc::new(fake_provider.clone()), model, }; registry.set_default_model(Some(configured_model), cx); registry }); cx.set_global(GlobalLanguageModelRegistry(registry)); fake_provider } pub fn register_provider( &mut self, provider: T, cx: &mut Context, ) { let id = provider.id(); let subscription = provider.subscribe(cx, { let id = id.clone(); move |_, cx| { cx.emit(Event::ProviderStateChanged(id.clone())); } }); if let Some(subscription) = subscription { subscription.detach(); } self.providers.insert(id.clone(), Arc::new(provider)); cx.emit(Event::AddedProvider(id)); } pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context) { if self.providers.remove(&id).is_some() { cx.emit(Event::RemovedProvider(id)); } } pub fn providers(&self) -> Vec> { let zed_provider_id = LanguageModelProviderId("zed.dev".into()); let mut providers = Vec::with_capacity(self.providers.len()); if let Some(provider) = self.providers.get(&zed_provider_id) { providers.push(provider.clone()); } providers.extend(self.providers.values().filter_map(|p| { if p.id() != zed_provider_id { Some(p.clone()) } else { None } })); providers } pub fn configuration_error( &self, model: Option, cx: &App, ) -> Option { let Some(model) = model else { if !self.has_authenticated_provider(cx) { return Some(ConfigurationError::NoProvider); } return Some(ConfigurationError::ModelNotFound); }; if !model.provider.is_authenticated(cx) { return Some(ConfigurationError::ProviderNotAuthenticated(model.provider)); } None } /// Returns `true` if at least one provider that is authenticated. pub fn has_authenticated_provider(&self, cx: &App) -> bool { self.providers.values().any(|p| p.is_authenticated(cx)) } pub fn available_models<'a>( &'a self, cx: &'a App, ) -> impl Iterator> + 'a { self.providers .values() .flat_map(|provider| provider.provided_models(cx)) } pub fn provider(&self, id: &LanguageModelProviderId) -> Option> { self.providers.get(id).cloned() } pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context) { let configured_model = model.and_then(|model| self.select_model(model, cx)); self.set_default_model(configured_model, cx); } pub fn select_inline_assistant_model( &mut self, model: Option<&SelectedModel>, cx: &mut Context, ) { let configured_model = model.and_then(|model| self.select_model(model, cx)); self.set_inline_assistant_model(configured_model, cx); } pub fn select_commit_message_model( &mut self, model: Option<&SelectedModel>, cx: &mut Context, ) { let configured_model = model.and_then(|model| self.select_model(model, cx)); self.set_commit_message_model(configured_model, cx); } pub fn select_thread_summary_model( &mut self, model: Option<&SelectedModel>, cx: &mut Context, ) { let configured_model = model.and_then(|model| self.select_model(model, cx)); self.set_thread_summary_model(configured_model, cx); } /// Selects and sets the inline alternatives for language models based on /// provider name and id. pub fn select_inline_alternative_models( &mut self, alternatives: impl IntoIterator, cx: &mut Context, ) { self.inline_alternatives = alternatives .into_iter() .flat_map(|alternative| { self.select_model(&alternative, cx) .map(|configured_model| configured_model.model) }) .collect::>(); } pub fn select_model( &mut self, selected_model: &SelectedModel, cx: &mut Context, ) -> Option { let provider = self.provider(&selected_model.provider)?; let model = provider .provided_models(cx) .iter() .find(|model| model.id() == selected_model.model)? .clone(); Some(ConfiguredModel { provider, model }) } pub fn set_default_model(&mut self, model: Option, cx: &mut Context) { match (self.default_model.as_ref(), model.as_ref()) { (Some(old), Some(new)) if old.is_same_as(new) => {} (None, None) => {} _ => cx.emit(Event::DefaultModelChanged), } self.default_fast_model = maybe!({ let provider = &model.as_ref()?.provider; let fast_model = provider.default_fast_model(cx)?; Some(ConfiguredModel { provider: provider.clone(), model: fast_model, }) }); self.default_model = model; } pub fn set_inline_assistant_model( &mut self, model: Option, cx: &mut Context, ) { match (self.inline_assistant_model.as_ref(), model.as_ref()) { (Some(old), Some(new)) if old.is_same_as(new) => {} (None, None) => {} _ => cx.emit(Event::InlineAssistantModelChanged), } self.inline_assistant_model = model; } pub fn set_commit_message_model( &mut self, model: Option, cx: &mut Context, ) { match (self.commit_message_model.as_ref(), model.as_ref()) { (Some(old), Some(new)) if old.is_same_as(new) => {} (None, None) => {} _ => cx.emit(Event::CommitMessageModelChanged), } self.commit_message_model = model; } pub fn set_thread_summary_model( &mut self, model: Option, cx: &mut Context, ) { match (self.thread_summary_model.as_ref(), model.as_ref()) { (Some(old), Some(new)) if old.is_same_as(new) => {} (None, None) => {} _ => cx.emit(Event::ThreadSummaryModelChanged), } self.thread_summary_model = model; } pub fn default_model(&self) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; } self.default_model.clone() } pub fn inline_assistant_model(&self) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; } self.inline_assistant_model .clone() .or_else(|| self.default_model.clone()) } pub fn commit_message_model(&self) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; } self.commit_message_model .clone() .or_else(|| self.default_fast_model.clone()) .or_else(|| self.default_model.clone()) } pub fn thread_summary_model(&self) -> Option { #[cfg(debug_assertions)] if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() { return None; } self.thread_summary_model .clone() .or_else(|| self.default_fast_model.clone()) .or_else(|| self.default_model.clone()) } /// The models to use for inline assists. Returns the union of the active /// model and all inline alternatives. When there are multiple models, the /// user will be able to cycle through results. pub fn inline_alternative_models(&self) -> &[Arc] { &self.inline_alternatives } } #[cfg(test)] mod tests { use super::*; use crate::fake_provider::FakeLanguageModelProvider; #[gpui::test] fn test_register_providers(cx: &mut App) { let registry = cx.new(|_| LanguageModelRegistry::default()); let provider = FakeLanguageModelProvider::default(); registry.update(cx, |registry, cx| { registry.register_provider(provider.clone(), cx); }); let providers = registry.read(cx).providers(); assert_eq!(providers.len(), 1); assert_eq!(providers[0].id(), provider.id()); registry.update(cx, |registry, cx| { registry.unregister_provider(provider.id(), cx); }); let providers = registry.read(cx).providers(); assert!(providers.is_empty()); } }