Simplify language model registry + only emit change events on change (#29086)

* Now only does default fallback logic in the registry

* Only emits change events when there is actually a change

Release Notes:

- N/A
This commit is contained in:
Michael Sloan 2025-04-19 02:26:42 -06:00 committed by GitHub
parent 98ceffe026
commit d88b06a5dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 119 additions and 194 deletions

View file

@ -11,12 +11,10 @@ use clap::Parser;
use extension::ExtensionHostProxy;
use futures::{StreamExt, future};
use gpui::http_client::{Uri, read_proxy_from_env};
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
use project::project_settings::ProjectSettings;
@ -94,18 +92,25 @@ fn main() {
.telemetry()
.start(system_id, installation_id, session_id, cx);
let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
let model_provider_id = model.provider_id();
let model_provider = model_registry.provider(&model_provider_id).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
registry.set_default_model(
Some(ConfiguredModel {
provider: model_provider.clone(),
model: model.clone(),
}),
cx,
);
});
let model_provider_id = model.provider_id();
let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
let authenticate_task = model_provider.authenticate(cx);
cx.spawn(async move |cx| {
authenticate.await.unwrap();
authenticate_task.await.unwrap();
std::fs::create_dir_all(REPOS_DIR)?;
std::fs::create_dir_all(WORKTREES_DIR)?;
@ -498,8 +503,11 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
pub fn find_model(
model_name: &str,
model_registry: &LanguageModelRegistry,
cx: &App,
) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
@ -519,15 +527,6 @@ pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn Language
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}
pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
(run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
}