evals: Configurable judge model (#31282)
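This is needed for an apples-to-apples comparison of different agent models: the judge model can now be chosen independently of the agent model via `--judge-model`. Another change is that `cargo run -p eval` now accepts model names as `provider_id/model_id` (for both `--model` and `--judge-model`) instead of separate `--provider` and `--model` params. Release Notes: - N/A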
This commit is contained in:
parent
3a1053bf0c
commit
68a46c3627
1 changed files with 51 additions and 47 deletions
|
@ -20,7 +20,7 @@ use gpui::http_client::read_proxy_from_env;
|
|||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
|
||||
use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
|
||||
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel};
|
||||
use node_runtime::{NodeBinaryOptions, NodeRuntime};
|
||||
use project::Project;
|
||||
use project::project_settings::ProjectSettings;
|
||||
|
@ -33,6 +33,7 @@ use std::collections::VecDeque;
|
|||
use std::env;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use util::ResultExt as _;
|
||||
|
||||
|
@ -45,12 +46,12 @@ struct Args {
|
|||
/// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
|
||||
#[arg(value_name = "EXAMPLE_SUBSTRING")]
|
||||
filter: Vec<String>,
|
||||
/// ID of model to use.
|
||||
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
||||
/// provider/model to use for agent
|
||||
#[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
|
||||
model: String,
|
||||
/// Model provider to use.
|
||||
#[arg(long, default_value = "anthropic")]
|
||||
provider: String,
|
||||
/// provider/model to use for judges
|
||||
#[arg(long, default_value = "anthropic/claude-3-7-sonnet-latest")]
|
||||
judge_model: String,
|
||||
#[arg(long, value_delimiter = ',', default_value = "rs,ts,py")]
|
||||
languages: Vec<String>,
|
||||
/// How many times to run each example.
|
||||
|
@ -124,25 +125,19 @@ fn main() {
|
|||
|
||||
let mut cumulative_tool_metrics = ToolMetrics::default();
|
||||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap();
|
||||
let model_provider_id = model.provider_id();
|
||||
let model_provider = model_registry.provider(&model_provider_id).unwrap();
|
||||
let agent_model = load_model(&args.model, cx).unwrap();
|
||||
let judge_model = load_model(&args.judge_model, cx).unwrap();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: model_provider.clone(),
|
||||
model: model.clone(),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
registry.set_default_model(Some(agent_model.clone()), cx);
|
||||
});
|
||||
|
||||
let authenticate_task = model_provider.authenticate(cx);
|
||||
let auth1 = agent_model.provider.authenticate(cx);
|
||||
let auth2 = judge_model.provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
authenticate_task.await.unwrap();
|
||||
auth1.await?;
|
||||
auth2.await?;
|
||||
|
||||
let mut examples = Vec::new();
|
||||
|
||||
|
@ -273,7 +268,8 @@ fn main() {
|
|||
|
||||
future::join_all((0..args.concurrency).map(|_| {
|
||||
let app_state = app_state.clone();
|
||||
let model = model.clone();
|
||||
let model = agent_model.model.clone();
|
||||
let judge_model = judge_model.model.clone();
|
||||
let zed_commit_sha = zed_commit_sha.clone();
|
||||
let zed_branch_name = zed_branch_name.clone();
|
||||
let run_id = run_id.clone();
|
||||
|
@ -291,7 +287,7 @@ fn main() {
|
|||
.await?;
|
||||
let judge_output = judge_example(
|
||||
example.clone(),
|
||||
model.clone(),
|
||||
judge_model.clone(),
|
||||
&zed_commit_sha,
|
||||
&zed_branch_name,
|
||||
&run_id,
|
||||
|
@ -453,37 +449,45 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
|||
}
|
||||
|
||||
pub fn find_model(
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
model_name: &str,
|
||||
model_registry: &LanguageModelRegistry,
|
||||
cx: &App,
|
||||
) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||
let matching_models = model_registry
|
||||
let selected = SelectedModel::from_str(model_name).map_err(|e| anyhow::anyhow!(e))?;
|
||||
model_registry
|
||||
.available_models(cx)
|
||||
.filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id)
|
||||
.collect::<Vec<_>>();
|
||||
.find(|model| model.id() == selected.model && model.provider_id() == selected.provider)
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"No language model with ID {}/{} was available. Available models: {}",
|
||||
selected.model.0,
|
||||
selected.provider.0,
|
||||
model_registry
|
||||
.available_models(cx)
|
||||
.map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
match matching_models.as_slice() {
|
||||
[model] => Ok(model.clone()),
|
||||
[] => anyhow::bail!(
|
||||
"No language model with ID {}/{} was available. Available models: {}",
|
||||
provider_id,
|
||||
model_id,
|
||||
model_registry
|
||||
.available_models(cx)
|
||||
.map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
),
|
||||
_ => anyhow::bail!(
|
||||
"Multiple language models with ID {} available - use `--provider` to choose one of: {:?}",
|
||||
model_id,
|
||||
matching_models
|
||||
.iter()
|
||||
.map(|model| model.provider_id().0)
|
||||
.collect::<Vec<_>>()
|
||||
),
|
||||
}
|
||||
pub fn load_model(model_name: &str, cx: &mut App) -> anyhow::Result<ConfiguredModel> {
|
||||
let model = {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
find_model(model_name, model_registry, cx)?
|
||||
};
|
||||
|
||||
let provider = {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
model_registry
|
||||
.provider(&model.provider_id())
|
||||
.ok_or_else(|| anyhow::anyhow!("Provider not found: {}", model.provider_id()))?
|
||||
};
|
||||
|
||||
Ok(ConfiguredModel {
|
||||
provider: provider.clone(),
|
||||
model: model.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn commit_sha_for_path(repo_path: &Path) -> String {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue