Simplify language model registry + only emit change events on change (#29086)
* Now only does default fallback logic in the registry * Only emits change events when there is actually a change Release Notes: - N/A
This commit is contained in:
parent
98ceffe026
commit
d88b06a5dc
5 changed files with 119 additions and 194 deletions
|
@ -11,12 +11,10 @@ use clap::Parser;
|
|||
use extension::ExtensionHostProxy;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::http_client::{Uri, read_proxy_from_env};
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
|
||||
use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
|
||||
use node_runtime::{NodeBinaryOptions, NodeRuntime};
|
||||
use project::Project;
|
||||
use project::project_settings::ProjectSettings;
|
||||
|
@ -94,18 +92,25 @@ fn main() {
|
|||
.telemetry()
|
||||
.start(system_id, installation_id, session_id, cx);
|
||||
|
||||
let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
|
||||
let model_provider_id = model.provider_id();
|
||||
let model_provider = model_registry.provider(&model_provider_id).unwrap();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(Some(model.clone()), cx);
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: model_provider.clone(),
|
||||
model: model.clone(),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
let model_provider_id = model.provider_id();
|
||||
|
||||
let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
|
||||
let authenticate_task = model_provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
authenticate.await.unwrap();
|
||||
authenticate_task.await.unwrap();
|
||||
|
||||
std::fs::create_dir_all(REPOS_DIR)?;
|
||||
std::fs::create_dir_all(WORKTREES_DIR)?;
|
||||
|
@ -498,8 +503,11 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
pub fn find_model(
|
||||
model_name: &str,
|
||||
model_registry: &LanguageModelRegistry,
|
||||
cx: &App,
|
||||
) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||
let model = model_registry
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == model_name);
|
||||
|
@ -519,15 +527,6 @@ pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn Language
|
|||
Ok(model)
|
||||
}
|
||||
|
||||
pub fn authenticate_model_provider(
|
||||
provider_id: LanguageModelProviderId,
|
||||
cx: &mut App,
|
||||
) -> Task<std::result::Result<(), AuthenticateError>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model_provider = model_registry.provider(&provider_id).unwrap();
|
||||
model_provider.authenticate(cx)
|
||||
}
|
||||
|
||||
pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
|
||||
(run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue