ZIm/crates/language_model/src/registry.rs
Ben Brandt 9427833fdf
Distinguish between missing models and registries in error messages (#32678)
Consolidates configuration error handling by moving the error type and
logic from assistant_context_editor to language_model::registry.

The registry now provides a single method to check for configuration
errors, making the error handling more consistent across the agent panel
and context editor.

This also now checks whether the issue is that we don't have any authenticated providers,
or whether we just can't find the model.

Previously, an incorrect model name showed up as having no providers,
which is very confusing.

Release Notes:

- N/A
2025-06-13 10:31:52 +00:00

421 lines
14 KiB
Rust

use crate::{
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
use collections::BTreeMap;
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use util::maybe;
/// Creates the application-wide [`LanguageModelRegistry`] entity and installs
/// it as a gpui global, so it can later be reached through
/// [`LanguageModelRegistry::global`] and [`LanguageModelRegistry::read_global`].
pub fn init(cx: &mut App) {
    let registry_entity = cx.new(|_| LanguageModelRegistry::default());
    let global = GlobalLanguageModelRegistry(registry_entity);
    cx.set_global(global);
}
/// Newtype wrapper that lets the registry entity be stored as a gpui global.
struct GlobalLanguageModelRegistry(Entity<LanguageModelRegistry>);

impl Global for GlobalLanguageModelRegistry {}
#[derive(Error)]
pub enum ConfigurationError {
#[error("Configure at least one LLM provider to start using the panel.")]
NoProvider,
#[error("LLM Provider is not configured or does not support the configured model.")]
ModelNotFound,
#[error("{} LLM provider is not configured.", .0.name().0)]
ProviderNotAuthenticated(Arc<dyn LanguageModelProvider>),
#[error("Using the {} LLM provider requires accepting the Terms of Service.",
.0.name().0)]
ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
}
/// Hand-written `Debug`: provider-carrying variants are rendered with the
/// provider's id (via its `Display` impl) instead of the provider object.
impl std::fmt::Debug for ConfigurationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, provider) = match self {
            Self::NoProvider => return f.write_str("NoProvider"),
            Self::ModelNotFound => return f.write_str("ModelNotFound"),
            Self::ProviderNotAuthenticated(provider) => ("ProviderNotAuthenticated", provider),
            Self::ProviderPendingTermsAcceptance(provider) => {
                ("ProviderPendingTermsAcceptance", provider)
            }
        };
        write!(f, "{variant}({})", provider.id())
    }
}
/// Central store of registered language model providers and of the models
/// selected for each feature.
#[derive(Default)]
pub struct LanguageModelRegistry {
    /// Primary model; the per-feature getters fall back to it.
    default_model: Option<ConfiguredModel>,
    /// Fast model derived from the default model's provider in `set_default_model`.
    default_fast_model: Option<ConfiguredModel>,
    /// Override for the inline assistant (falls back to `default_model`).
    inline_assistant_model: Option<ConfiguredModel>,
    /// Override for commit messages (falls back to the fast, then default model).
    commit_message_model: Option<ConfiguredModel>,
    /// Override for thread summaries (falls back to the fast, then default model).
    thread_summary_model: Option<ConfiguredModel>,
    /// Registered providers keyed by provider id (kept in id order).
    providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
    /// Additional models offered for inline assists.
    inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
/// A `(provider id, model id)` pair naming a model, typically parsed from a
/// `provider_id/model_id` string.
#[derive(Debug)]
pub struct SelectedModel {
    /// Id of the provider expected to supply the model.
    pub provider: LanguageModelProviderId,
    /// Id of the model within that provider.
    pub model: LanguageModelId,
}
impl FromStr for SelectedModel {
    type Err = String;

    /// Parses an identifier of the form `provider_id/model_id`.
    ///
    /// Fails when the string does not contain exactly one `/`, or when
    /// either side of the separator is empty.
    fn from_str(id: &str) -> Result<SelectedModel, Self::Err> {
        let mut segments = id.split('/');
        // Exactly two segments: a first, a second, and nothing after that.
        let (Some(provider_id), Some(model_id), None) =
            (segments.next(), segments.next(), segments.next())
        else {
            return Err(format!(
                "Invalid model identifier format: `{id}`. Expected `provider_id/model_id`"
            ));
        };

        if model_id.is_empty() || provider_id.is_empty() {
            return Err(format!("Provider and model ids can't be empty: `{id}`"));
        }

        Ok(SelectedModel {
            provider: LanguageModelProviderId(provider_id.to_string().into()),
            model: LanguageModelId(model_id.to_string().into()),
        })
    }
}
/// A resolved model together with the provider that supplies it.
#[derive(Clone)]
pub struct ConfiguredModel {
    /// Provider that supplies `model`.
    pub provider: Arc<dyn LanguageModelProvider>,
    /// The resolved model.
    pub model: Arc<dyn LanguageModel>,
}
impl ConfiguredModel {
    /// Returns `true` when `other` has the same model id and the same
    /// provider id as `self`.
    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
        let same_model = self.model.id() == other.model.id();
        same_model && self.provider.id() == other.provider.id()
    }

    /// Returns `true` when this model's provider id is
    /// [`crate::ZED_CLOUD_PROVIDER_ID`].
    pub fn is_provided_by_zed(&self) -> bool {
        let provider_id = self.provider.id();
        provider_id.0 == crate::ZED_CLOUD_PROVIDER_ID
    }
}
/// Notifications emitted by [`LanguageModelRegistry`].
pub enum Event {
    /// Emitted by `set_default_model` when the selected default model changes.
    DefaultModelChanged,
    /// Emitted by `set_inline_assistant_model` when its selection changes.
    InlineAssistantModelChanged,
    /// Emitted by `set_commit_message_model` when its selection changes.
    CommitMessageModelChanged,
    /// Emitted by `set_thread_summary_model` when its selection changes.
    ThreadSummaryModelChanged,
    /// Re-emitted whenever a registered provider's subscription fires.
    ProviderStateChanged,
    /// A provider with this id was registered.
    AddedProvider(LanguageModelProviderId),
    /// A previously registered provider with this id was removed.
    RemovedProvider(LanguageModelProviderId),
}

impl EventEmitter<Event> for LanguageModelRegistry {}
impl LanguageModelRegistry {
pub fn global(cx: &App) -> Entity<Self> {
cx.global::<GlobalLanguageModelRegistry>().0.clone()
}
pub fn read_global(cx: &App) -> &Self {
cx.global::<GlobalLanguageModelRegistry>().0.read(cx)
}
#[cfg(any(test, feature = "test-support"))]
pub fn test(cx: &mut App) -> crate::fake_provider::FakeLanguageModelProvider {
let fake_provider = crate::fake_provider::FakeLanguageModelProvider;
let registry = cx.new(|cx| {
let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone();
let configured_model = ConfiguredModel {
provider: Arc::new(fake_provider.clone()),
model,
};
registry.set_default_model(Some(configured_model), cx);
registry
});
cx.set_global(GlobalLanguageModelRegistry(registry));
fake_provider
}
pub fn register_provider<T: LanguageModelProvider + LanguageModelProviderState>(
&mut self,
provider: T,
cx: &mut Context<Self>,
) {
let id = provider.id();
let subscription = provider.subscribe(cx, |_, cx| {
cx.emit(Event::ProviderStateChanged);
});
if let Some(subscription) = subscription {
subscription.detach();
}
self.providers.insert(id.clone(), Arc::new(provider));
cx.emit(Event::AddedProvider(id));
}
pub fn unregister_provider(&mut self, id: LanguageModelProviderId, cx: &mut Context<Self>) {
if self.providers.remove(&id).is_some() {
cx.emit(Event::RemovedProvider(id));
}
}
pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
let zed_provider_id = LanguageModelProviderId("zed.dev".into());
let mut providers = Vec::with_capacity(self.providers.len());
if let Some(provider) = self.providers.get(&zed_provider_id) {
providers.push(provider.clone());
}
providers.extend(self.providers.values().filter_map(|p| {
if p.id() != zed_provider_id {
Some(p.clone())
} else {
None
}
}));
providers
}
pub fn configuration_error(
&self,
model: Option<ConfiguredModel>,
cx: &App,
) -> Option<ConfigurationError> {
let Some(model) = model else {
if !self.has_authenticated_provider(cx) {
return Some(ConfigurationError::NoProvider);
}
return Some(ConfigurationError::ModelNotFound);
};
if !model.provider.is_authenticated(cx) {
return Some(ConfigurationError::ProviderNotAuthenticated(model.provider));
}
if model.provider.must_accept_terms(cx) {
return Some(ConfigurationError::ProviderPendingTermsAcceptance(
model.provider,
));
}
None
}
/// Check that we have at least one provider that is authenticated.
fn has_authenticated_provider(&self, cx: &App) -> bool {
self.providers.values().any(|p| p.is_authenticated(cx))
}
pub fn available_models<'a>(
&'a self,
cx: &'a App,
) -> impl Iterator<Item = Arc<dyn LanguageModel>> + 'a {
self.providers
.values()
.flat_map(|provider| provider.provided_models(cx))
}
pub fn provider(&self, id: &LanguageModelProviderId) -> Option<Arc<dyn LanguageModelProvider>> {
self.providers.get(id).cloned()
}
pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_default_model(configured_model, cx);
}
pub fn select_inline_assistant_model(
&mut self,
model: Option<&SelectedModel>,
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_inline_assistant_model(configured_model, cx);
}
pub fn select_commit_message_model(
&mut self,
model: Option<&SelectedModel>,
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_commit_message_model(configured_model, cx);
}
pub fn select_thread_summary_model(
&mut self,
model: Option<&SelectedModel>,
cx: &mut Context<Self>,
) {
let configured_model = model.and_then(|model| self.select_model(model, cx));
self.set_thread_summary_model(configured_model, cx);
}
/// Selects and sets the inline alternatives for language models based on
/// provider name and id.
pub fn select_inline_alternative_models(
&mut self,
alternatives: impl IntoIterator<Item = SelectedModel>,
cx: &mut Context<Self>,
) {
self.inline_alternatives = alternatives
.into_iter()
.flat_map(|alternative| {
self.select_model(&alternative, cx)
.map(|configured_model| configured_model.model)
})
.collect::<Vec<_>>();
}
pub fn select_model(
&mut self,
selected_model: &SelectedModel,
cx: &mut Context<Self>,
) -> Option<ConfiguredModel> {
let provider = self.provider(&selected_model.provider)?;
let model = provider
.provided_models(cx)
.iter()
.find(|model| model.id() == selected_model.model)?
.clone();
Some(ConfiguredModel { provider, model })
}
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
match (self.default_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
}
self.default_fast_model = maybe!({
let provider = &model.as_ref()?.provider;
let fast_model = provider.default_fast_model(cx)?;
Some(ConfiguredModel {
provider: provider.clone(),
model: fast_model,
})
});
self.default_model = model;
}
pub fn set_inline_assistant_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
match (self.inline_assistant_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::InlineAssistantModelChanged),
}
self.inline_assistant_model = model;
}
pub fn set_commit_message_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
match (self.commit_message_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::CommitMessageModelChanged),
}
self.commit_message_model = model;
}
pub fn set_thread_summary_model(
&mut self,
model: Option<ConfiguredModel>,
cx: &mut Context<Self>,
) {
match (self.thread_summary_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::ThreadSummaryModelChanged),
}
self.thread_summary_model = model;
}
pub fn default_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
self.default_model.clone()
}
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
self.inline_assistant_model
.clone()
.or_else(|| self.default_model.clone())
}
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
self.commit_message_model
.clone()
.or_else(|| self.default_fast_model.clone())
.or_else(|| self.default_model.clone())
}
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
return None;
}
self.thread_summary_model
.clone()
.or_else(|| self.default_fast_model.clone())
.or_else(|| self.default_model.clone())
}
/// The models to use for inline assists. Returns the union of the active
/// model and all inline alternatives. When there are multiple models, the
/// user will be able to cycle through results.
pub fn inline_alternative_models(&self) -> &[Arc<dyn LanguageModel>] {
&self.inline_alternatives
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fake_provider::FakeLanguageModelProvider;

    /// Registering then unregistering a provider is reflected in `providers()`.
    #[gpui::test]
    fn test_register_providers(cx: &mut App) {
        let registry = cx.new(|_| LanguageModelRegistry::default());
        registry.update(cx, |registry, cx| {
            registry.register_provider(FakeLanguageModelProvider, cx);
        });
        let providers = registry.read(cx).providers();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].id(), crate::fake_provider::provider_id());
        registry.update(cx, |registry, cx| {
            registry.unregister_provider(crate::fake_provider::provider_id(), cx);
        });
        let providers = registry.read(cx).providers();
        assert!(providers.is_empty());
    }

    /// `SelectedModel` parsing accepts exactly `provider_id/model_id` with
    /// both halves non-empty.
    #[test]
    fn test_selected_model_from_str() {
        let selected: SelectedModel = "zed.dev/some-model".parse().unwrap();
        assert_eq!(
            selected.provider,
            LanguageModelProviderId("zed.dev".to_string().into())
        );
        assert_eq!(selected.model, LanguageModelId("some-model".to_string().into()));

        // Exactly one `/` separator is required.
        for invalid in ["", "no-separator", "a/b/c"] {
            assert_eq!(
                invalid.parse::<SelectedModel>().unwrap_err(),
                format!(
                    "Invalid model identifier format: `{invalid}`. Expected `provider_id/model_id`"
                )
            );
        }

        // Both halves must be non-empty.
        for empty_part in ["/model", "provider/", "/"] {
            assert_eq!(
                empty_part.parse::<SelectedModel>().unwrap_err(),
                format!("Provider and model ids can't be empty: `{empty_part}`")
            );
        }
    }
}