evals: Configurable judge model (#31282)

This is needed for apples-to-apples comparison of different agent
models.

Another change is that `cargo run -p eval` now accepts model names as
`provider_id/model_id` instead of separate `--provider` and `--model`
params.


Release Notes:

- N/A
This commit is contained in:
Oleksiy Syvokon 2025-05-23 18:03:09 +03:00 committed by GitHub
parent 3a1053bf0c
commit 68a46c3627
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -20,7 +20,7 @@ use gpui::http_client::read_proxy_from_env;
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel};
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
use project::project_settings::ProjectSettings;
@ -33,6 +33,7 @@ use std::collections::VecDeque;
use std::env;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use util::ResultExt as _;
@ -45,12 +46,12 @@ struct Args {
/// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
#[arg(value_name = "EXAMPLE_SUBSTRING")]
filter: Vec<String>,
/// ID of model to use.
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
/// provider/model to use for agent
#[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
model: String,
/// Model provider to use.
#[arg(long, default_value = "anthropic")]
provider: String,
/// provider/model to use for judges
#[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
judge_model: String,
#[arg(long, value_delimiter = ',', default_value = "rs,ts,py")]
languages: Vec<String>,
/// How many times to run each example.
@ -124,25 +125,19 @@ fn main() {
let mut cumulative_tool_metrics = ToolMetrics::default();
let model_registry = LanguageModelRegistry::read_global(cx);
let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap();
let model_provider_id = model.provider_id();
let model_provider = model_registry.provider(&model_provider_id).unwrap();
let agent_model = load_model(&args.model, cx).unwrap();
let judge_model = load_model(&args.judge_model, cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(
Some(ConfiguredModel {
provider: model_provider.clone(),
model: model.clone(),
}),
cx,
);
registry.set_default_model(Some(agent_model.clone()), cx);
});
let authenticate_task = model_provider.authenticate(cx);
let auth1 = agent_model.provider.authenticate(cx);
let auth2 = judge_model.provider.authenticate(cx);
cx.spawn(async move |cx| {
authenticate_task.await.unwrap();
auth1.await?;
auth2.await?;
let mut examples = Vec::new();
@ -273,7 +268,8 @@ fn main() {
future::join_all((0..args.concurrency).map(|_| {
let app_state = app_state.clone();
let model = model.clone();
let model = agent_model.model.clone();
let judge_model = judge_model.model.clone();
let zed_commit_sha = zed_commit_sha.clone();
let zed_branch_name = zed_branch_name.clone();
let run_id = run_id.clone();
@ -291,7 +287,7 @@ fn main() {
.await?;
let judge_output = judge_example(
example.clone(),
model.clone(),
judge_model.clone(),
&zed_commit_sha,
&zed_branch_name,
&run_id,
@ -453,37 +449,45 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
}
pub fn find_model(
provider_id: &str,
model_id: &str,
model_name: &str,
model_registry: &LanguageModelRegistry,
cx: &App,
) -> anyhow::Result<Arc<dyn LanguageModel>> {
let matching_models = model_registry
let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!(e))?;
model_registry
.available_models(cx)
.filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id)
.collect::<Vec<_>>();
match matching_models.as_slice() {
[model] => Ok(model.clone()),
[] => anyhow::bail!(
.find(|model| model.id() == selected.model && model.provider_id() == selected.provider)
.ok_or_else(|| {
anyhow::anyhow!(
"No language model with ID {}/{} was available. Available models: {}",
provider_id,
model_id,
selected.model.0,
selected.provider.0,
model_registry
.available_models(cx)
.map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
.collect::<Vec<_>>()
.join(", ")
),
_ => anyhow::bail!(
"Multiple language models with ID {} available - use `--provider` to choose one of: {:?}",
model_id,
matching_models
.iter()
.map(|model| model.provider_id().0)
.collect::<Vec<_>>()
),
}
)
})
}
/// Resolves `model_name` into a [`ConfiguredModel`] (a model plus its provider).
///
/// The model is located with [`find_model`] in the global
/// [`LanguageModelRegistry`]; its provider is then looked up in the same
/// registry using the model's provider ID.
///
/// # Errors
///
/// Propagates any error from [`find_model`], and fails with
/// `Provider not found: …` when the registry has no provider registered under
/// the model's provider ID.
pub fn load_model(model_name: &str, cx: &mut App) -> anyhow::Result<ConfiguredModel> {
    // A single shared borrow of the registry serves both lookups.
    let model_registry = LanguageModelRegistry::read_global(cx);
    let model = find_model(model_name, model_registry, cx)?;
    let provider = model_registry
        .provider(&model.provider_id())
        .ok_or_else(|| anyhow::anyhow!("Provider not found: {}", model.provider_id()))?;
    // Both values are owned and not used afterwards, so move them instead of cloning.
    Ok(ConfiguredModel { provider, model })
}
pub fn commit_sha_for_path(repo_path: &Path) -> String {