assistant_eval: Add ACE framework (#27181)

Release Notes:

- N/A

---------

Co-authored-by: Michael Sloan <michael@zed.dev>
This commit is contained in:
Thomas Mickley-Doyle 2025-04-02 23:02:06 -05:00 committed by GitHub
parent d3e4de7c72
commit cd85b430e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1113 additions and 373 deletions

View file

@ -1,18 +1,21 @@
mod eval;
mod get_exercise;
mod git_commands;
mod headless_assistant;
mod judge;
mod templates_eval;
use clap::Parser;
use eval::{Eval, EvalOutput};
use futures::future;
use gpui::{Application, AsyncApp};
use headless_assistant::{HeadlessAppState, authenticate_model_provider, find_model};
use itertools::Itertools;
use judge::Judge;
use language_model::{LanguageModel, LanguageModelRegistry};
use regex::Regex;
use eval::{run_exercise_eval, save_eval_results};
use futures::stream::{self, StreamExt};
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
use git_commands::read_base_sha;
use gpui::Application;
use headless_assistant::{authenticate_model_provider, find_model};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use std::{cmp, path::PathBuf, sync::Arc};
use std::{path::PathBuf, sync::Arc};
use templates_eval::all_templates;
#[derive(Parser, Debug)]
#[command(
@ -21,11 +24,16 @@ use std::{cmp, path::PathBuf, sync::Arc};
before_help = "Tool eval runner"
)]
struct Args {
/// Regexes to match the names of evals to run.
eval_name_regexes: Vec<String>,
/// Runs all evals in `evaluation_data`, causes the regex to be ignored.
/// Match the names of evals to run.
#[arg(long)]
exercise_names: Vec<String>,
/// Runs all exercises, causes the exercise_names to be ignored.
#[arg(long)]
all: bool,
/// Supported language types to evaluate (default: internal).
/// Internal is data generated from the agent panel
#[arg(long, default_value = "internal")]
languages: String,
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
@ -35,72 +43,52 @@ struct Args {
/// Name of the judge model (default: value of `--model_name`).
#[arg(long)]
judge_model_name: Option<String>,
/// Number of evaluations to run concurrently (default: 10)
#[arg(short, long, default_value = "10")]
/// Number of evaluations to run concurrently (default: 3)
#[arg(short, long, default_value = "3")]
concurrency: usize,
/// Maximum number of exercises to evaluate per language
#[arg(long)]
max_exercises_per_language: Option<usize>,
}
// First, let's define the order in which templates should be executed
const TEMPLATE_EXECUTION_ORDER: [&str; 3] = [
"ProjectCreation",
"CodeModification",
"ConversationalGuidance",
];
fn main() {
env_logger::init();
let args = Args::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
let crate_dir = PathBuf::from("../zed-agent-bench");
let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
// Path to the zed-ace-framework repo
let framework_path = PathBuf::from("../zed-ace-framework")
.canonicalize()
.unwrap();
let repos_dir = crate_dir.join("repos");
if !repos_dir.exists() {
std::fs::create_dir_all(&repos_dir).unwrap();
}
let repos_dir = repos_dir.canonicalize().unwrap();
// Fix the 'languages' lifetime issue by creating owned Strings instead of slices
let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
let all_evals = std::fs::read_dir(&evaluation_data_dir)
.unwrap()
.map(|path| path.unwrap().file_name().to_string_lossy().to_string())
.collect::<Vec<_>>();
let evals_to_run = if args.all {
all_evals
} else {
args.eval_name_regexes
.into_iter()
.map(|regex_string| Regex::new(&regex_string).unwrap())
.flat_map(|regex| {
all_evals
.iter()
.filter(|eval_name| regex.is_match(eval_name))
.cloned()
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
};
if evals_to_run.is_empty() {
panic!("Names of evals to run must be provided or `--all` specified");
}
println!("Will run the following evals: {evals_to_run:?}");
println!("Running up to {} evals concurrently", args.concurrency);
let editor_model_name = if let Some(model_name) = args.editor_model_name {
model_name
} else {
args.model_name.clone()
};
let judge_model_name = if let Some(model_name) = args.judge_model_name {
model_name
} else {
args.model_name.clone()
};
println!("Using zed-ace-framework at: {:?}", framework_path);
println!("Evaluating languages: {:?}", languages);
app.run(move |cx| {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = find_model(&editor_model_name, cx).unwrap();
let judge_model = find_model(&judge_model_name, cx).unwrap();
let editor_model = if let Some(model_name) = &args.editor_model_name {
find_model(model_name, cx).unwrap()
} else {
model.clone()
};
let judge_model = if let Some(model_name) = &args.judge_model_name {
find_model(model_name, cx).unwrap()
} else {
model.clone()
};
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_active_model(Some(model.clone()), cx);
@ -111,6 +99,11 @@ fn main() {
let editor_model_provider_id = editor_model.provider_id();
let judge_model_provider_id = judge_model.provider_id();
let framework_path_clone = framework_path.clone();
let languages_clone = languages.clone();
let exercise_names = args.exercise_names.clone();
let all_flag = args.all;
cx.spawn(async move |cx| {
// Authenticate all model providers first
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
@ -126,99 +119,150 @@ fn main() {
.await
.unwrap();
let eval_load_futures = evals_to_run
// Read base SHA from setup.json
let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
// Find all exercises for the specified languages
let all_exercises = find_exercises(
&framework_path_clone,
&languages_clone
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>(),
args.max_exercises_per_language,
)
.unwrap();
println!("Found {} exercises total", all_exercises.len());
// Filter exercises if specific ones were requested
let exercises_to_run = if !exercise_names.is_empty() {
// If exercise names are specified, filter by them regardless of --all flag
all_exercises
.into_iter()
.filter(|path| {
let name = get_exercise_name(path);
exercise_names.iter().any(|filter| name.contains(filter))
})
.collect()
} else if all_flag {
// Only use all_flag if no exercise names are specified
all_exercises
} else {
// Default behavior (no filters)
all_exercises
};
println!("Will run {} exercises", exercises_to_run.len());
// Get all templates and sort them according to the execution order
let mut templates = all_templates();
templates.sort_by_key(|template| {
TEMPLATE_EXECUTION_ORDER
.iter()
.position(|&name| name == template.name)
.unwrap_or(usize::MAX)
});
// Create exercise eval tasks - each exercise is a single task that will run templates sequentially
let exercise_tasks: Vec<_> = exercises_to_run
.into_iter()
.map(|eval_name| {
let eval_path = evaluation_data_dir.join(&eval_name);
let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
.map(|exercise_path| {
let exercise_name = get_exercise_name(&exercise_path);
let templates_clone = templates.clone();
let model_clone = model.clone();
let judge_model_clone = judge_model.clone();
let app_state_clone = app_state.clone();
let base_sha_clone = base_sha.clone();
let framework_path_clone = framework_path_clone.clone();
let cx_clone = cx.clone();
async move {
match load_future.await {
Ok(eval) => Some(eval),
println!("Processing exercise: {}", exercise_name);
let mut exercise_results = Vec::new();
// Determine the language for this exercise
let language = match get_exercise_language(&exercise_path) {
Ok(lang) => lang,
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error loading {eval_name}: {err}");
None
println!(
"Error determining language for {}: {}",
exercise_name, err
);
return exercise_results;
}
};
// Run each template sequentially for this exercise
for template in templates_clone {
// For "multi" or "internal" language, only run the CodeModification template
if (language == "multi" || language == "internal")
&& template.name != "CodeModification"
{
println!(
"Skipping {} template for {} language",
template.name, language
);
continue;
}
match run_exercise_eval(
exercise_path.clone(),
template.clone(),
model_clone.clone(),
judge_model_clone.clone(),
app_state_clone.clone(),
base_sha_clone.clone(),
framework_path_clone.clone(),
cx_clone.clone(),
)
.await
{
Ok(result) => {
println!(
"Completed {} with template {} - score: {}",
exercise_name, template.name, result.score
);
exercise_results.push(result);
}
Err(err) => {
println!(
"Error running {} with template {}: {}",
exercise_name, template.name, err
);
}
}
}
}
})
.collect::<Vec<_>>();
let loaded_evals = future::join_all(eval_load_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
// The evals need to be loaded and grouped by URL before concurrently running, since
// evals that use the same remote URL will use the same working directory.
let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
.into_iter()
.map(|eval| (eval.eval_setup.url.clone(), eval))
.into_group_map()
.into_values()
.collect::<Vec<_>>();
// Sort groups in descending order, so that bigger groups start first.
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
let result_futures = evals_grouped_by_url
.into_iter()
.map(|evals| {
let model = model.clone();
let judge_model = judge_model.clone();
let app_state = app_state.clone();
let cx = cx.clone();
async move {
let mut results = Vec::new();
for eval in evals {
let name = eval.name.clone();
println!("Starting eval named {}", name);
let result = run_eval(
eval,
model.clone(),
judge_model.clone(),
app_state.clone(),
cx.clone(),
)
.await;
results.push((name, result));
// Save results for this exercise
if !exercise_results.is_empty() {
if let Err(err) =
save_eval_results(&exercise_path, exercise_results.clone()).await
{
println!("Error saving results for {}: {}", exercise_name, err);
} else {
println!("Saved results for {}", exercise_name);
}
}
results
exercise_results
}
})
.collect::<Vec<_>>();
.collect();
let results = future::join_all(result_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
println!(
"Running {} exercises with concurrency: {}",
exercise_tasks.len(),
args.concurrency
);
// Process results in order of completion
for (eval_name, result) in results {
match result {
Ok((eval_output, judge_output)) => {
println!("Generated diff for {eval_name}:\n");
println!("{}\n", eval_output.diff);
println!("Last message for {eval_name}:\n");
println!("{}\n", eval_output.last_message);
println!("Elapsed time: {:?}", eval_output.elapsed_time);
println!(
"Assistant response count: {}",
eval_output.assistant_response_count
);
println!("Tool use counts: {:?}", eval_output.tool_use_counts);
println!("Judge output for {eval_name}: {judge_output}");
}
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error running {eval_name}: {err}");
}
}
}
// Run exercises concurrently, with each exercise running its templates sequentially
let all_results = stream::iter(exercise_tasks)
.buffer_unordered(args.concurrency)
.flat_map(stream::iter)
.collect::<Vec<_>>()
.await;
println!("Completed {} evaluation runs", all_results.len());
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
@ -226,18 +270,3 @@ fn main() {
println!("Done running evals");
}
async fn run_eval(
eval: Eval,
model: Arc<dyn LanguageModel>,
judge_model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
cx: AsyncApp,
) -> anyhow::Result<(EvalOutput, String)> {
let path = eval.path.clone();
let judge = Judge::load(&path, judge_model).await?;
let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
eval_output.save_to_directory(&path, judge_output.to_string())?;
Ok((eval_output, judge_output))
}