evals: Configurable judge model (#31282)

This is needed for apples-to-apples comparison of different agent
models.

Additionally, `cargo run -p eval` now accepts model names in the form
`provider_id/model_id` instead of separate `--provider` and `--model`
params.


Release Notes:

- N/A
This commit is contained in:
Oleksiy Syvokon 2025-05-23 18:03:09 +03:00 committed by GitHub
parent 3a1053bf0c
commit 68a46c3627
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -20,7 +20,7 @@ use gpui::http_client::read_proxy_from_env;
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal}; use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
use gpui_tokio::Tokio; use gpui_tokio::Tokio;
use language::LanguageRegistry; use language::LanguageRegistry;
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry}; use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel};
use node_runtime::{NodeBinaryOptions, NodeRuntime}; use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project; use project::Project;
use project::project_settings::ProjectSettings; use project::project_settings::ProjectSettings;
@ -33,6 +33,7 @@ use std::collections::VecDeque;
use std::env; use std::env;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::rc::Rc; use std::rc::Rc;
use std::str::FromStr;
use std::sync::{Arc, LazyLock}; use std::sync::{Arc, LazyLock};
use util::ResultExt as _; use util::ResultExt as _;
@ -45,12 +46,12 @@ struct Args {
/// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run. /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
#[arg(value_name = "EXAMPLE_SUBSTRING")] #[arg(value_name = "EXAMPLE_SUBSTRING")]
filter: Vec<String>, filter: Vec<String>,
/// ID of model to use. /// provider/model to use for agent
#[arg(long, default_value = "claude-3-7-sonnet-latest")] #[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
model: String, model: String,
/// Model provider to use. /// provider/model to use for judges
#[arg(long, default_value = "anthropic")] #[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
provider: String, judge_model: String,
#[arg(long, value_delimiter = ',', default_value = "rs,ts,py")] #[arg(long, value_delimiter = ',', default_value = "rs,ts,py")]
languages: Vec<String>, languages: Vec<String>,
/// How many times to run each example. /// How many times to run each example.
@ -124,25 +125,19 @@ fn main() {
let mut cumulative_tool_metrics = ToolMetrics::default(); let mut cumulative_tool_metrics = ToolMetrics::default();
let model_registry = LanguageModelRegistry::read_global(cx); let agent_model = load_model(&args.model, cx).unwrap();
let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap(); let judge_model = load_model(&args.judge_model, cx).unwrap();
let model_provider_id = model.provider_id();
let model_provider = model_registry.provider(&model_provider_id).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| { LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model( registry.set_default_model(Some(agent_model.clone()), cx);
Some(ConfiguredModel {
provider: model_provider.clone(),
model: model.clone(),
}),
cx,
);
}); });
let authenticate_task = model_provider.authenticate(cx); let auth1 = agent_model.provider.authenticate(cx);
let auth2 = judge_model.provider.authenticate(cx);
cx.spawn(async move |cx| { cx.spawn(async move |cx| {
authenticate_task.await.unwrap(); auth1.await?;
auth2.await?;
let mut examples = Vec::new(); let mut examples = Vec::new();
@ -273,7 +268,8 @@ fn main() {
future::join_all((0..args.concurrency).map(|_| { future::join_all((0..args.concurrency).map(|_| {
let app_state = app_state.clone(); let app_state = app_state.clone();
let model = model.clone(); let model = agent_model.model.clone();
let judge_model = judge_model.model.clone();
let zed_commit_sha = zed_commit_sha.clone(); let zed_commit_sha = zed_commit_sha.clone();
let zed_branch_name = zed_branch_name.clone(); let zed_branch_name = zed_branch_name.clone();
let run_id = run_id.clone(); let run_id = run_id.clone();
@ -291,7 +287,7 @@ fn main() {
.await?; .await?;
let judge_output = judge_example( let judge_output = judge_example(
example.clone(), example.clone(),
model.clone(), judge_model.clone(),
&zed_commit_sha, &zed_commit_sha,
&zed_branch_name, &zed_branch_name,
&run_id, &run_id,
@ -453,37 +449,45 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
} }
pub fn find_model( pub fn find_model(
provider_id: &str, model_name: &str,
model_id: &str,
model_registry: &LanguageModelRegistry, model_registry: &LanguageModelRegistry,
cx: &App, cx: &App,
) -> anyhow::Result<Arc<dyn LanguageModel>> { ) -> anyhow::Result<Arc<dyn LanguageModel>> {
let matching_models = model_registry let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!(e))?;
model_registry
.available_models(cx) .available_models(cx)
.filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id) .find(|model| model.id() == selected.model && model.provider_id() == selected.provider)
.collect::<Vec<_>>(); .ok_or_else(|| {
anyhow::anyhow!(
"No language model with ID {}/{} was available. Available models: {}",
selected.model.0,
selected.provider.0,
model_registry
.available_models(cx)
.map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
.collect::<Vec<_>>()
.join(", ")
)
})
}
match matching_models.as_slice() { pub fn load_model(model_name: &str, cx: &mut App) -> anyhow::Result<ConfiguredModel> {
[model] => Ok(model.clone()), let model = {
[] => anyhow::bail!( let model_registry = LanguageModelRegistry::read_global(cx);
"No language model with ID {}/{} was available. Available models: {}", find_model(model_name, model_registry, cx)?
provider_id, };
model_id,
model_registry let provider = {
.available_models(cx) let model_registry = LanguageModelRegistry::read_global(cx);
.map(|model| format!("{}/{}", model.provider_id().0, model.id().0)) model_registry
.collect::<Vec<_>>() .provider(&model.provider_id())
.join(", ") .ok_or_else(|| anyhow::anyhow!("Provider not found: {}", model.provider_id()))?
), };
_ => anyhow::bail!(
"Multiple language models with ID {} available - use `--provider` to choose one of: {:?}", Ok(ConfiguredModel {
model_id, provider: provider.clone(),
matching_models model: model.clone(),
.iter() })
.map(|model| model.provider_id().0)
.collect::<Vec<_>>()
),
}
} }
pub fn commit_sha_for_path(repo_path: &Path) -> String { pub fn commit_sha_for_path(repo_path: &Path) -> String {