From d88b06a5dc7413b065275695006d3499f9606e3d Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Sat, 19 Apr 2025 02:26:42 -0600 Subject: [PATCH] Simplify language model registry + only emit change events on change (#29086) * Now only does default fallback logic in the registry * Only emits change events when there is actually a change Release Notes: - N/A --- Cargo.lock | 1 - crates/assistant/src/assistant.rs | 73 +++------- crates/eval/src/eval.rs | 41 +++--- crates/language_model/Cargo.toml | 1 - crates/language_model/src/registry.rs | 197 +++++++++++--------------- 5 files changed, 119 insertions(+), 194 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e372c6a562..d6ff13e748 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7651,7 +7651,6 @@ dependencies = [ "http_client", "icons", "image", - "log", "open_ai", "parking_lot", "proto", diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 16b775865d..33add7b4b4 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -8,7 +8,7 @@ mod terminal_inline_assistant; use std::sync::Arc; -use assistant_settings::AssistantSettings; +use assistant_settings::{AssistantSettings, LanguageModelSelection}; use assistant_slash_command::SlashCommandRegistry; use client::Client; use command_palette_hooks::CommandPaletteFilter; @@ -161,71 +161,38 @@ fn init_language_model_settings(cx: &mut App) { fn update_active_language_model_from_settings(cx: &mut App) { let settings = AssistantSettings::get_global(cx); - // Default model - used as fallback - let active_model_provider_name = - LanguageModelProviderId::from(settings.default_model.provider.clone()); - let active_model_id = LanguageModelId::from(settings.default_model.model.clone()); - // Inline assistant model - let inline_assistant_model = settings + fn to_selected_model(selection: &LanguageModelSelection) -> language_model::SelectedModel { + language_model::SelectedModel { + provider: LanguageModelProviderId::from(selection.provider.clone()), + model: LanguageModelId::from(selection.model.clone()), + } + } + + let default = to_selected_model(&settings.default_model); + let inline_assistant = settings .inline_assistant_model .as_ref() - .unwrap_or(&settings.default_model); - let inline_assistant_provider_name = - LanguageModelProviderId::from(inline_assistant_model.provider.clone()); - let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone()); - - // Commit message model - let commit_message_model = settings + .map(to_selected_model); + let commit_message = settings .commit_message_model .as_ref() - .unwrap_or(&settings.default_model); - let commit_message_provider_name = - LanguageModelProviderId::from(commit_message_model.provider.clone()); - let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone()); - - // Thread summary model - let thread_summary_model = settings + .map(to_selected_model); + let thread_summary = settings .thread_summary_model .as_ref() - .unwrap_or(&settings.default_model); - let thread_summary_provider_name = - LanguageModelProviderId::from(thread_summary_model.provider.clone()); - let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone()); - + .map(to_selected_model); let inline_alternatives = settings .inline_alternatives .iter() - .map(|alternative| { - ( - LanguageModelProviderId::from(alternative.provider.clone()), - LanguageModelId::from(alternative.model.clone()), - ) - }) + .map(to_selected_model) .collect::>(); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - // Set the default model - registry.select_default_model(&active_model_provider_name, &active_model_id, cx); - - // Set the specific models - registry.select_inline_assistant_model( - &inline_assistant_provider_name, - &inline_assistant_model_id, - cx, - ); - registry.select_commit_message_model( - &commit_message_provider_name, - &commit_message_model_id, - cx, - ); - registry.select_thread_summary_model( - &thread_summary_provider_name, - &thread_summary_model_id, - cx, - ); - - // Set the alternatives + registry.select_default_model(Some(&default), cx); + registry.select_inline_assistant_model(inline_assistant.as_ref(), cx); + registry.select_commit_message_model(commit_message.as_ref(), cx); + registry.select_thread_summary_model(thread_summary.as_ref(), cx); registry.select_inline_alternative_models(inline_alternatives, cx); }); } diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 5b465a5e74..90d11f616b 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -11,12 +11,10 @@ use clap::Parser; use extension::ExtensionHostProxy; use futures::{StreamExt, future}; use gpui::http_client::{Uri, read_proxy_from_env}; -use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal}; +use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal}; use gpui_tokio::Tokio; use language::LanguageRegistry; -use language_model::{ - AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, -}; +use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry}; use node_runtime::{NodeBinaryOptions, NodeRuntime}; use project::Project; use project::project_settings::ProjectSettings; @@ -94,18 +92,25 @@ fn main() { .telemetry() .start(system_id, installation_id, session_id, cx); - let model = find_model("claude-3-7-sonnet-latest", cx).unwrap(); + let model_registry = LanguageModelRegistry::read_global(cx); + let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap(); + let model_provider_id = model.provider_id(); + let model_provider = model_registry.provider(&model_provider_id).unwrap(); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_default_model(Some(model.clone()), cx); + registry.set_default_model( + Some(ConfiguredModel { + provider: model_provider.clone(), + model: model.clone(), + }), + cx, + ); }); - let model_provider_id = model.provider_id(); - - let authenticate = authenticate_model_provider(model_provider_id.clone(), cx); + let authenticate_task = model_provider.authenticate(cx); cx.spawn(async move |cx| { - authenticate.await.unwrap(); + authenticate_task.await.unwrap(); std::fs::create_dir_all(REPOS_DIR)?; std::fs::create_dir_all(WORKTREES_DIR)?; @@ -498,8 +503,11 @@ pub fn init(cx: &mut App) -> Arc { }) } -pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result> { - let model_registry = LanguageModelRegistry::read_global(cx); +pub fn find_model( + model_name: &str, + model_registry: &LanguageModelRegistry, + cx: &App, +) -> anyhow::Result> { let model = model_registry .available_models(cx) .find(|model| model.id().0 == model_name); @@ -519,15 +527,6 @@ pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result Task> { - let model_registry = LanguageModelRegistry::read_global(cx); - let model_provider = model_registry.provider(&provider_id).unwrap(); - model_provider.authenticate(cx) -} - pub async fn get_current_commit_id(repo_path: &Path) -> Option { (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok() } diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index c468ff8297..f9b4afda79 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -27,7 +27,6 @@ gpui.workspace = true http_client.workspace = true icons.workspace = true image.workspace = true -log.workspace = true open_ai = { workspace = true, features = ["schemars"] } parking_lot.workspace = true proto.workspace = true diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 9ef183189e..45be22457f 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -25,12 +25,23 @@ pub struct LanguageModelRegistry { inline_alternatives: Vec>, } +pub struct SelectedModel { + pub provider: LanguageModelProviderId, + pub model: LanguageModelId, +} + #[derive(Clone)] pub struct ConfiguredModel { pub provider: Arc, pub model: Arc, } +impl ConfiguredModel { + pub fn is_same_as(&self, other: &ConfiguredModel) -> bool { + self.model.id() == other.model.id() && self.provider.id() == other.provider.id() + } +} + pub enum Event { DefaultModelChanged, InlineAssistantModelChanged, @@ -59,7 +70,11 @@ impl LanguageModelRegistry { let mut registry = Self::default(); registry.register_provider(fake_provider.clone(), cx); let model = fake_provider.provided_models(cx)[0].clone(); - registry.set_default_model(Some(model), cx); + let configured_model = ConfiguredModel { + provider: Arc::new(fake_provider.clone()), + model, + }; + registry.set_default_model(Some(configured_model), cx); registry }); cx.set_global(GlobalLanguageModelRegistry(registry)); @@ -119,144 +134,114 @@ impl LanguageModelRegistry { self.providers.get(id).cloned() } - pub fn select_default_model( - &mut self, - provider: &LanguageModelProviderId, - model_id: &LanguageModelId, - cx: &mut Context, - ) { - let Some(provider) = self.provider(provider) else { - return; - }; - - let models = provider.provided_models(cx); - if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { - self.set_default_model(Some(model), cx); - } + pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context) { + let configured_model = model.and_then(|model| self.select_model(model, cx)); + self.set_default_model(configured_model, cx); } pub fn select_inline_assistant_model( &mut self, - provider: &LanguageModelProviderId, - model_id: &LanguageModelId, + model: Option<&SelectedModel>, cx: &mut Context, ) { - let Some(provider) = self.provider(provider) else { - return; - }; - - let models = provider.provided_models(cx); - if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { - self.set_inline_assistant_model(Some(model), cx); - } + let configured_model = model.and_then(|model| self.select_model(model, cx)); + self.set_inline_assistant_model(configured_model, cx); } pub fn select_commit_message_model( &mut self, - provider: &LanguageModelProviderId, - model_id: &LanguageModelId, + model: Option<&SelectedModel>, cx: &mut Context, ) { - let Some(provider) = self.provider(provider) else { - return; - }; - - let models = provider.provided_models(cx); - if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { - self.set_commit_message_model(Some(model), cx); - } + let configured_model = model.and_then(|model| self.select_model(model, cx)); + self.set_commit_message_model(configured_model, cx); } pub fn select_thread_summary_model( &mut self, - provider: &LanguageModelProviderId, - model_id: &LanguageModelId, + model: Option<&SelectedModel>, cx: &mut Context, ) { - let Some(provider) = self.provider(provider) else { - return; - }; - - let models = provider.provided_models(cx); - if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() { - self.set_thread_summary_model(Some(model), cx); - } + let configured_model = model.and_then(|model| self.select_model(model, cx)); + self.set_thread_summary_model(configured_model, cx); } - pub fn set_default_model( + /// Selects and sets the inline alternatives for language models based on + /// provider name and id. + pub fn select_inline_alternative_models( &mut self, - model: Option>, + alternatives: impl IntoIterator, cx: &mut Context, ) { - if let Some(model) = model { - let provider_id = model.provider_id(); - if let Some(provider) = self.providers.get(&provider_id).cloned() { - self.default_model = Some(ConfiguredModel { provider, model }); - cx.emit(Event::DefaultModelChanged); - } else { - log::warn!("Active model's provider not found in registry"); - } - } else { - self.default_model = None; - cx.emit(Event::DefaultModelChanged); + self.inline_alternatives = alternatives + .into_iter() + .flat_map(|alternative| { + self.select_model(&alternative, cx) + .map(|configured_model| configured_model.model) + }) + .collect::>(); + } + + fn select_model( + &mut self, + selected_model: &SelectedModel, + cx: &mut Context, + ) -> Option { + let provider = self.provider(&selected_model.provider)?; + let model = provider + .provided_models(cx) + .iter() + .find(|model| model.id() == selected_model.model)? + .clone(); + Some(ConfiguredModel { provider, model }) + } + + pub fn set_default_model(&mut self, model: Option, cx: &mut Context) { + match (self.default_model.as_ref(), model.as_ref()) { + (Some(old), Some(new)) if old.is_same_as(new) => {} + (None, None) => {} + _ => cx.emit(Event::DefaultModelChanged), } + self.default_model = model; } pub fn set_inline_assistant_model( &mut self, - model: Option>, + model: Option, cx: &mut Context, ) { - if let Some(model) = model { - let provider_id = model.provider_id(); - if let Some(provider) = self.providers.get(&provider_id).cloned() { - self.inline_assistant_model = Some(ConfiguredModel { provider, model }); - cx.emit(Event::InlineAssistantModelChanged); - } else { - log::warn!("Inline assistant model's provider not found in registry"); - } - } else { - self.inline_assistant_model = None; - cx.emit(Event::InlineAssistantModelChanged); + match (self.inline_assistant_model.as_ref(), model.as_ref()) { + (Some(old), Some(new)) if old.is_same_as(new) => {} + (None, None) => {} + _ => cx.emit(Event::InlineAssistantModelChanged), } + self.inline_assistant_model = model; } pub fn set_commit_message_model( &mut self, - model: Option>, + model: Option, cx: &mut Context, ) { - if let Some(model) = model { - let provider_id = model.provider_id(); - if let Some(provider) = self.providers.get(&provider_id).cloned() { - self.commit_message_model = Some(ConfiguredModel { provider, model }); - cx.emit(Event::CommitMessageModelChanged); - } else { - log::warn!("Commit message model's provider not found in registry"); - } - } else { - self.commit_message_model = None; - cx.emit(Event::CommitMessageModelChanged); + match (self.commit_message_model.as_ref(), model.as_ref()) { + (Some(old), Some(new)) if old.is_same_as(new) => {} + (None, None) => {} + _ => cx.emit(Event::CommitMessageModelChanged), } + self.commit_message_model = model; } pub fn set_thread_summary_model( &mut self, - model: Option>, + model: Option, cx: &mut Context, ) { - if let Some(model) = model { - let provider_id = model.provider_id(); - if let Some(provider) = self.providers.get(&provider_id).cloned() { - self.thread_summary_model = Some(ConfiguredModel { provider, model }); - cx.emit(Event::ThreadSummaryModelChanged); - } else { - log::warn!("Thread summary model's provider not found in registry"); - } - } else { - self.thread_summary_model = None; - cx.emit(Event::ThreadSummaryModelChanged); + match (self.thread_summary_model.as_ref(), model.as_ref()) { + (Some(old), Some(new)) if old.is_same_as(new) => {} + (None, None) => {} + _ => cx.emit(Event::ThreadSummaryModelChanged), } + self.thread_summary_model = model; } pub fn default_model(&self) -> Option { @@ -286,30 +271,6 @@ impl LanguageModelRegistry { .or_else(|| self.default_model()) } - /// Selects and sets the inline alternatives for language models based on - /// provider name and id. - pub fn select_inline_alternative_models( - &mut self, - alternatives: impl IntoIterator, - cx: &mut Context, - ) { - let mut selected_alternatives = Vec::new(); - - for (provider_id, model_id) in alternatives { - if let Some(provider) = self.providers.get(&provider_id) { - if let Some(model) = provider - .provided_models(cx) - .iter() - .find(|m| m.id() == model_id) - { - selected_alternatives.push(model.clone()); - } - } - } - - self.inline_alternatives = selected_alternatives; - } - /// The models to use for inline assists. Returns the union of the active /// model and all inline alternatives. When there are multiple models, the /// user will be able to cycle through results.