Simplify language model registry + only emit change events on change (#29086)

* Now only does default fallback logic in the registry

* Only emits change events when there is actually a change

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-04-19 02:26:42 -06:00 committed by GitHub
parent 98ceffe026
commit d88b06a5dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 119 additions and 194 deletions

1
Cargo.lock generated
View file

@ -7651,7 +7651,6 @@ dependencies = [
"http_client", "http_client",
"icons", "icons",
"image", "image",
"log",
"open_ai", "open_ai",
"parking_lot", "parking_lot",
"proto", "proto",

View file

@ -8,7 +8,7 @@ mod terminal_inline_assistant;
use std::sync::Arc; use std::sync::Arc;
use assistant_settings::AssistantSettings; use assistant_settings::{AssistantSettings, LanguageModelSelection};
use assistant_slash_command::SlashCommandRegistry; use assistant_slash_command::SlashCommandRegistry;
use client::Client; use client::Client;
use command_palette_hooks::CommandPaletteFilter; use command_palette_hooks::CommandPaletteFilter;
@ -161,71 +161,38 @@ fn init_language_model_settings(cx: &mut App) {
fn update_active_language_model_from_settings(cx: &mut App) { fn update_active_language_model_from_settings(cx: &mut App) {
let settings = AssistantSettings::get_global(cx); let settings = AssistantSettings::get_global(cx);
// Default model - used as fallback
let active_model_provider_name =
LanguageModelProviderId::from(settings.default_model.provider.clone());
let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
// Inline assistant model fn to_selected_model(selection: &LanguageModelSelection) -> language_model::SelectedModel {
let inline_assistant_model = settings language_model::SelectedModel {
provider: LanguageModelProviderId::from(selection.provider.clone()),
model: LanguageModelId::from(selection.model.clone()),
}
}
let default = to_selected_model(&settings.default_model);
let inline_assistant = settings
.inline_assistant_model .inline_assistant_model
.as_ref() .as_ref()
.unwrap_or(&settings.default_model); .map(to_selected_model);
let inline_assistant_provider_name = let commit_message = settings
LanguageModelProviderId::from(inline_assistant_model.provider.clone());
let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone());
// Commit message model
let commit_message_model = settings
.commit_message_model .commit_message_model
.as_ref() .as_ref()
.unwrap_or(&settings.default_model); .map(to_selected_model);
let commit_message_provider_name = let thread_summary = settings
LanguageModelProviderId::from(commit_message_model.provider.clone());
let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone());
// Thread summary model
let thread_summary_model = settings
.thread_summary_model .thread_summary_model
.as_ref() .as_ref()
.unwrap_or(&settings.default_model); .map(to_selected_model);
let thread_summary_provider_name =
LanguageModelProviderId::from(thread_summary_model.provider.clone());
let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone());
let inline_alternatives = settings let inline_alternatives = settings
.inline_alternatives .inline_alternatives
.iter() .iter()
.map(|alternative| { .map(to_selected_model)
(
LanguageModelProviderId::from(alternative.provider.clone()),
LanguageModelId::from(alternative.model.clone()),
)
})
.collect::<Vec<_>>(); .collect::<Vec<_>>();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| { LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
// Set the default model registry.select_default_model(Some(&default), cx);
registry.select_default_model(&active_model_provider_name, &active_model_id, cx); registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
registry.select_commit_message_model(commit_message.as_ref(), cx);
// Set the specific models registry.select_thread_summary_model(thread_summary.as_ref(), cx);
registry.select_inline_assistant_model(
&inline_assistant_provider_name,
&inline_assistant_model_id,
cx,
);
registry.select_commit_message_model(
&commit_message_provider_name,
&commit_message_model_id,
cx,
);
registry.select_thread_summary_model(
&thread_summary_provider_name,
&thread_summary_model_id,
cx,
);
// Set the alternatives
registry.select_inline_alternative_models(inline_alternatives, cx); registry.select_inline_alternative_models(inline_alternatives, cx);
}); });
} }

View file

@ -11,12 +11,10 @@ use clap::Parser;
use extension::ExtensionHostProxy; use extension::ExtensionHostProxy;
use futures::{StreamExt, future}; use futures::{StreamExt, future};
use gpui::http_client::{Uri, read_proxy_from_env}; use gpui::http_client::{Uri, read_proxy_from_env};
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal}; use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
use gpui_tokio::Tokio; use gpui_tokio::Tokio;
use language::LanguageRegistry; use language::LanguageRegistry;
use language_model::{ use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use node_runtime::{NodeBinaryOptions, NodeRuntime}; use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project; use project::Project;
use project::project_settings::ProjectSettings; use project::project_settings::ProjectSettings;
@ -94,18 +92,25 @@ fn main() {
.telemetry() .telemetry()
.start(system_id, installation_id, session_id, cx); .start(system_id, installation_id, session_id, cx);
let model = find_model("claude-3-7-sonnet-latest", cx).unwrap(); let model_registry = LanguageModelRegistry::read_global(cx);
let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
let model_provider_id = model.provider_id();
let model_provider = model_registry.provider(&model_provider_id).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| { LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx); registry.set_default_model(
Some(ConfiguredModel {
provider: model_provider.clone(),
model: model.clone(),
}),
cx,
);
}); });
let model_provider_id = model.provider_id(); let authenticate_task = model_provider.authenticate(cx);
let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
authenticate.await.unwrap(); authenticate_task.await.unwrap();
std::fs::create_dir_all(REPOS_DIR)?; std::fs::create_dir_all(REPOS_DIR)?;
std::fs::create_dir_all(WORKTREES_DIR)?; std::fs::create_dir_all(WORKTREES_DIR)?;
@ -498,8 +503,11 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
}) })
} }
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> { pub fn find_model(
let model_registry = LanguageModelRegistry::read_global(cx); model_name: &str,
model_registry: &LanguageModelRegistry,
cx: &App,
) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model = model_registry let model = model_registry
.available_models(cx) .available_models(cx)
.find(|model| model.id().0 == model_name); .find(|model| model.id().0 == model_name);
@ -519,15 +527,6 @@ pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn Language
Ok(model) Ok(model)
} }
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}
pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> { pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
(run_git(repo_path, &["rev-parse", "HEAD"]).await).ok() (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
} }

View file

@ -27,7 +27,6 @@ gpui.workspace = true
http_client.workspace = true http_client.workspace = true
icons.workspace = true icons.workspace = true
image.workspace = true image.workspace = true
log.workspace = true
open_ai = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] }
parking_lot.workspace = true parking_lot.workspace = true
proto.workspace = true proto.workspace = true

View file

@ -25,12 +25,23 @@ pub struct LanguageModelRegistry {
inline_alternatives: Vec<Arc<dyn LanguageModel>>, inline_alternatives: Vec<Arc<dyn LanguageModel>>,
} }
pub struct SelectedModel {
pub provider: LanguageModelProviderId,
pub model: LanguageModelId,
}
#[derive(Clone)] #[derive(Clone)]
pub struct ConfiguredModel { pub struct ConfiguredModel {
pub provider: Arc<dyn LanguageModelProvider>, pub provider: Arc<dyn LanguageModelProvider>,
pub model: Arc<dyn LanguageModel>, pub model: Arc<dyn LanguageModel>,
} }
impl ConfiguredModel {
pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
}
}
pub enum Event { pub enum Event {
DefaultModelChanged, DefaultModelChanged,
InlineAssistantModelChanged, InlineAssistantModelChanged,
@ -59,7 +70,11 @@ impl LanguageModelRegistry {
let mut registry = Self::default(); let mut registry = Self::default();
registry.register_provider(fake_provider.clone(), cx); registry.register_provider(fake_provider.clone(), cx);
let model = fake_provider.provided_models(cx)[0].clone(); let model = fake_provider.provided_models(cx)[0].clone();
registry.set_default_model(Some(model), cx); let configured_model = ConfiguredModel {
provider: Arc::new(fake_provider.clone()),
model,
};
registry.set_default_model(Some(configured_model), cx);
registry registry
}); });
cx.set_global(GlobalLanguageModelRegistry(registry)); cx.set_global(GlobalLanguageModelRegistry(registry));
@ -119,144 +134,114 @@ impl LanguageModelRegistry {
self.providers.get(id).cloned() self.providers.get(id).cloned()
} }
pub fn select_default_model( pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
&mut self, let configured_model = model.and_then(|model| self.select_model(model, cx));
provider: &LanguageModelProviderId, self.set_default_model(configured_model, cx);
model_id: &LanguageModelId,
cx: &mut Context<Self>,
) {
let Some(provider) = self.provider(provider) else {
return;
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_default_model(Some(model), cx);
}
} }
pub fn select_inline_assistant_model( pub fn select_inline_assistant_model(
&mut self, &mut self,
provider: &LanguageModelProviderId, model: Option<&SelectedModel>,
model_id: &LanguageModelId,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let Some(provider) = self.provider(provider) else { let configured_model = model.and_then(|model| self.select_model(model, cx));
return; self.set_inline_assistant_model(configured_model, cx);
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_inline_assistant_model(Some(model), cx);
}
} }
pub fn select_commit_message_model( pub fn select_commit_message_model(
&mut self, &mut self,
provider: &LanguageModelProviderId, model: Option<&SelectedModel>,
model_id: &LanguageModelId,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let Some(provider) = self.provider(provider) else { let configured_model = model.and_then(|model| self.select_model(model, cx));
return; self.set_commit_message_model(configured_model, cx);
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_commit_message_model(Some(model), cx);
}
} }
pub fn select_thread_summary_model( pub fn select_thread_summary_model(
&mut self, &mut self,
provider: &LanguageModelProviderId, model: Option<&SelectedModel>,
model_id: &LanguageModelId,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
let Some(provider) = self.provider(provider) else { let configured_model = model.and_then(|model| self.select_model(model, cx));
return; self.set_thread_summary_model(configured_model, cx);
};
let models = provider.provided_models(cx);
if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
self.set_thread_summary_model(Some(model), cx);
}
} }
pub fn set_default_model( /// Selects and sets the inline alternatives for language models based on
/// provider name and id.
pub fn select_inline_alternative_models(
&mut self, &mut self,
model: Option<Arc<dyn LanguageModel>>, alternatives: impl IntoIterator<Item = SelectedModel>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Some(model) = model { self.inline_alternatives = alternatives
let provider_id = model.provider_id(); .into_iter()
if let Some(provider) = self.providers.get(&provider_id).cloned() { .flat_map(|alternative| {
self.default_model = Some(ConfiguredModel { provider, model }); self.select_model(&alternative, cx)
cx.emit(Event::DefaultModelChanged); .map(|configured_model| configured_model.model)
} else { })
log::warn!("Active model's provider not found in registry"); .collect::<Vec<_>>();
} }
} else {
self.default_model = None; fn select_model(
cx.emit(Event::DefaultModelChanged); &mut self,
selected_model: &SelectedModel,
cx: &mut Context<Self>,
) -> Option<ConfiguredModel> {
let provider = self.provider(&selected_model.provider)?;
let model = provider
.provided_models(cx)
.iter()
.find(|model| model.id() == selected_model.model)?
.clone();
Some(ConfiguredModel { provider, model })
}
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
match (self.default_model.as_ref(), model.as_ref()) {
(Some(old), Some(new)) if old.is_same_as(new) => {}
(None, None) => {}
_ => cx.emit(Event::DefaultModelChanged),
} }
self.default_model = model;
} }
pub fn set_inline_assistant_model( pub fn set_inline_assistant_model(
&mut self, &mut self,
model: Option<Arc<dyn LanguageModel>>, model: Option<ConfiguredModel>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Some(model) = model { match (self.inline_assistant_model.as_ref(), model.as_ref()) {
let provider_id = model.provider_id(); (Some(old), Some(new)) if old.is_same_as(new) => {}
if let Some(provider) = self.providers.get(&provider_id).cloned() { (None, None) => {}
self.inline_assistant_model = Some(ConfiguredModel { provider, model }); _ => cx.emit(Event::InlineAssistantModelChanged),
cx.emit(Event::InlineAssistantModelChanged);
} else {
log::warn!("Inline assistant model's provider not found in registry");
}
} else {
self.inline_assistant_model = None;
cx.emit(Event::InlineAssistantModelChanged);
} }
self.inline_assistant_model = model;
} }
pub fn set_commit_message_model( pub fn set_commit_message_model(
&mut self, &mut self,
model: Option<Arc<dyn LanguageModel>>, model: Option<ConfiguredModel>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Some(model) = model { match (self.commit_message_model.as_ref(), model.as_ref()) {
let provider_id = model.provider_id(); (Some(old), Some(new)) if old.is_same_as(new) => {}
if let Some(provider) = self.providers.get(&provider_id).cloned() { (None, None) => {}
self.commit_message_model = Some(ConfiguredModel { provider, model }); _ => cx.emit(Event::CommitMessageModelChanged),
cx.emit(Event::CommitMessageModelChanged);
} else {
log::warn!("Commit message model's provider not found in registry");
}
} else {
self.commit_message_model = None;
cx.emit(Event::CommitMessageModelChanged);
} }
self.commit_message_model = model;
} }
pub fn set_thread_summary_model( pub fn set_thread_summary_model(
&mut self, &mut self,
model: Option<Arc<dyn LanguageModel>>, model: Option<ConfiguredModel>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
if let Some(model) = model { match (self.thread_summary_model.as_ref(), model.as_ref()) {
let provider_id = model.provider_id(); (Some(old), Some(new)) if old.is_same_as(new) => {}
if let Some(provider) = self.providers.get(&provider_id).cloned() { (None, None) => {}
self.thread_summary_model = Some(ConfiguredModel { provider, model }); _ => cx.emit(Event::ThreadSummaryModelChanged),
cx.emit(Event::ThreadSummaryModelChanged);
} else {
log::warn!("Thread summary model's provider not found in registry");
}
} else {
self.thread_summary_model = None;
cx.emit(Event::ThreadSummaryModelChanged);
} }
self.thread_summary_model = model;
} }
pub fn default_model(&self) -> Option<ConfiguredModel> { pub fn default_model(&self) -> Option<ConfiguredModel> {
@ -286,30 +271,6 @@ impl LanguageModelRegistry {
.or_else(|| self.default_model()) .or_else(|| self.default_model())
} }
/// Selects and sets the inline alternatives for language models based on
/// provider name and id.
pub fn select_inline_alternative_models(
&mut self,
alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
cx: &mut Context<Self>,
) {
let mut selected_alternatives = Vec::new();
for (provider_id, model_id) in alternatives {
if let Some(provider) = self.providers.get(&provider_id) {
if let Some(model) = provider
.provided_models(cx)
.iter()
.find(|m| m.id() == model_id)
{
selected_alternatives.push(model.clone());
}
}
}
self.inline_alternatives = selected_alternatives;
}
/// The models to use for inline assists. Returns the union of the active /// The models to use for inline assists. Returns the union of the active
/// model and all inline alternatives. When there are multiple models, the /// model and all inline alternatives. When there are multiple models, the
/// user will be able to cycle through results. /// user will be able to cycle through results.