diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 064a0c688e..41dbe25d96 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -20,7 +20,7 @@ use gpui::http_client::read_proxy_from_env; use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal}; use gpui_tokio::Tokio; use language::LanguageRegistry; -use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry}; +use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel}; use node_runtime::{NodeBinaryOptions, NodeRuntime}; use project::Project; use project::project_settings::ProjectSettings; @@ -33,6 +33,7 @@ use std::collections::VecDeque; use std::env; use std::path::{Path, PathBuf}; use std::rc::Rc; +use std::str::FromStr; use std::sync::{Arc, LazyLock}; use util::ResultExt as _; @@ -45,12 +46,12 @@ struct Args { /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run. #[arg(value_name = "EXAMPLE_SUBSTRING")] filter: Vec, - /// ID of model to use. - #[arg(long, default_value = "claude-3-7-sonnet-latest")] + /// provider/model to use for agent + #[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")] model: String, - /// Model provider to use. - #[arg(long, default_value = "anthropic")] - provider: String, + /// provider/model to use for judges + #[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")] + judge_model: String, #[arg(long, value_delimiter = ',', default_value = "rs,ts,py")] languages: Vec, /// How many times to run each example. 
@@ -124,25 +125,19 @@ fn main() { let mut cumulative_tool_metrics = ToolMetrics::default(); - let model_registry = LanguageModelRegistry::read_global(cx); - let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap(); - let model_provider_id = model.provider_id(); - let model_provider = model_registry.provider(&model_provider_id).unwrap(); + let agent_model = load_model(&args.model, cx).unwrap(); + let judge_model = load_model(&args.judge_model, cx).unwrap(); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_default_model( - Some(ConfiguredModel { - provider: model_provider.clone(), - model: model.clone(), - }), - cx, - ); + registry.set_default_model(Some(agent_model.clone()), cx); }); - let authenticate_task = model_provider.authenticate(cx); + let auth1 = agent_model.provider.authenticate(cx); + let auth2 = judge_model.provider.authenticate(cx); cx.spawn(async move |cx| { - authenticate_task.await.unwrap(); + auth1.await?; + auth2.await?; let mut examples = Vec::new(); @@ -273,7 +268,8 @@ fn main() { future::join_all((0..args.concurrency).map(|_| { let app_state = app_state.clone(); - let model = model.clone(); + let model = agent_model.model.clone(); + let judge_model = judge_model.model.clone(); let zed_commit_sha = zed_commit_sha.clone(); let zed_branch_name = zed_branch_name.clone(); let run_id = run_id.clone(); @@ -291,7 +287,7 @@ fn main() { .await?; let judge_output = judge_example( example.clone(), - model.clone(), + judge_model.clone(), &zed_commit_sha, &zed_branch_name, &run_id, @@ -453,37 +449,45 @@ pub fn init(cx: &mut App) -> Arc { } pub fn find_model( - provider_id: &str, - model_id: &str, + model_name: &str, model_registry: &LanguageModelRegistry, cx: &App, ) -> anyhow::Result> { - let matching_models = model_registry + let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!(e))?; + model_registry .available_models(cx) - .filter(|model| model.id().0 == model_id && 
model.provider_id().0 == provider_id) - .collect::>(); + .find(|model| model.id() == selected.model && model.provider_id() == selected.provider) + .ok_or_else(|| { + anyhow::anyhow!( + "No language model with ID {}/{} was available. Available models: {}", + selected.provider.0, + selected.model.0, + model_registry + .available_models(cx) + .map(|model| format!("{}/{}", model.provider_id().0, model.id().0)) + .collect::>() + .join(", ") + ) + }) +} - match matching_models.as_slice() { - [model] => Ok(model.clone()), - [] => anyhow::bail!( - "No language model with ID {}/{} was available. Available models: {}", - provider_id, - model_id, - model_registry - .available_models(cx) - .map(|model| format!("{}/{}", model.provider_id().0, model.id().0)) - .collect::>() - .join(", ") - ), - _ => anyhow::bail!( - "Multiple language models with ID {} available - use `--provider` to choose one of: {:?}", - model_id, - matching_models - .iter() - .map(|model| model.provider_id().0) - .collect::>() - ), - } +pub fn load_model(model_name: &str, cx: &mut App) -> anyhow::Result { + let model = { + let model_registry = LanguageModelRegistry::read_global(cx); + find_model(model_name, model_registry, cx)? + }; + + let provider = { + let model_registry = LanguageModelRegistry::read_global(cx); + model_registry + .provider(&model.provider_id()) + .ok_or_else(|| anyhow::anyhow!("Provider not found: {}", model.provider_id()))? + }; + + Ok(ConfiguredModel { + provider: provider.clone(), + model: model.clone(), + }) } pub fn commit_sha_for_path(repo_path: &Path) -> String {