eval: Fine-grained assertions (#29246)
- Support programmatic examples ([example](17feb260a0/crates/eval/src/examples/file_search.rs))
- Combine data-driven example declarations into a single `.toml` file ([example](17feb260a0/crates/eval/src/examples/find_and_replace_diff_card.toml))
- Run judge on individual assertions (previously called "criteria")
- Report judge and programmatic assertions in one combined table

Note: We still need to work on concept naming

<img width=400 src="https://github.com/user-attachments/assets/fc719c93-467f-412b-8d47-68821bd8a5f5">

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Thomas Mickley-Doyle <tmickleydoyle@gmail.com>
This commit is contained in:
parent
0d3fe474db
commit
ce1a674eba
18 changed files with 1969 additions and 1229 deletions
|
@ -1,13 +1,16 @@
|
|||
mod assertions;
|
||||
mod example;
|
||||
mod examples;
|
||||
mod ids;
|
||||
mod instance;
|
||||
mod tool_metrics;
|
||||
|
||||
pub(crate) use example::*;
|
||||
use parking_lot::Mutex;
|
||||
use assertions::display_error_row;
|
||||
use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
|
||||
pub(crate) use tool_metrics::*;
|
||||
|
||||
use ::fs::RealFs;
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::anyhow;
|
||||
use clap::Parser;
|
||||
use client::{Client, ProxySettings, UserStore};
|
||||
use collections::{HashMap, HashSet};
|
||||
|
@ -25,18 +28,20 @@ use prompt_store::PromptBuilder;
|
|||
use release_channel::AppVersion;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::VecDeque;
|
||||
use std::env;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "eval", disable_version_flag = true)]
|
||||
struct Args {
|
||||
/// Runs all examples that contain these substrings. If unspecified, all examples are run.
|
||||
/// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
|
||||
#[arg(value_name = "EXAMPLE_SUBSTRING")]
|
||||
examples: Vec<String>,
|
||||
filter: Vec<String>,
|
||||
/// Model to use (default: "claude-3-7-sonnet-latest")
|
||||
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
||||
model: String,
|
||||
|
@ -66,43 +71,30 @@ fn main() {
|
|||
.parent()
|
||||
.unwrap()
|
||||
.parent()
|
||||
.unwrap()
|
||||
.canonicalize()
|
||||
.unwrap();
|
||||
let eval_crate_dir = root_dir.join("crates/eval");
|
||||
let eval_crate_dir = root_dir.join("crates").join("eval");
|
||||
let repos_dir = eval_crate_dir.join("repos");
|
||||
let worktrees_dir = eval_crate_dir.join("worktrees");
|
||||
let examples_dir = eval_crate_dir.join("examples");
|
||||
let runs_dir = eval_crate_dir.join("runs");
|
||||
let run_dir = runs_dir.join(format!("{}", run_timestamp));
|
||||
let examples_dir = eval_crate_dir.join("src").join("examples");
|
||||
let run_dir = eval_crate_dir
|
||||
.join("runs")
|
||||
.join(format!("{}", run_timestamp));
|
||||
std::fs::create_dir_all(&run_dir).unwrap();
|
||||
std::fs::create_dir_all(&repos_dir).unwrap();
|
||||
std::fs::create_dir_all(&worktrees_dir).unwrap();
|
||||
std::fs::create_dir_all(&examples_dir).unwrap();
|
||||
std::fs::create_dir_all(&paths::config_dir()).unwrap();
|
||||
|
||||
let zed_commit_sha = commit_sha_for_path(root_dir);
|
||||
let zed_branch_name = git_branch_for_path(root_dir);
|
||||
let zed_commit_sha = commit_sha_for_path(&root_dir);
|
||||
let zed_branch_name = git_branch_for_path(&root_dir);
|
||||
let args = Args::parse();
|
||||
let all_available_examples = list_all_examples(&examples_dir).unwrap();
|
||||
|
||||
let example_paths = all_available_examples
|
||||
.iter()
|
||||
.filter_map(|example_path| {
|
||||
let name = example_path.file_name()?.to_string_lossy();
|
||||
if args.examples.is_empty()
|
||||
|| args
|
||||
.examples
|
||||
.iter()
|
||||
.any(|name_substring| name.contains(name_substring))
|
||||
{
|
||||
Some(example_path.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let languages: HashSet<String> = args.languages.into_iter().collect();
|
||||
|
||||
let http_client = Arc::new(ReqwestClient::new());
|
||||
let app = Application::headless().with_http_client(http_client.clone());
|
||||
let all_threads = examples::all(&examples_dir);
|
||||
|
||||
app.run(move |cx| {
|
||||
let app_state = init(cx);
|
||||
|
@ -163,28 +155,40 @@ fn main() {
|
|||
|
||||
let mut skipped = Vec::new();
|
||||
|
||||
for example_path in &example_paths {
|
||||
let example = Example::load_from_directory(
|
||||
example_path,
|
||||
&run_dir,
|
||||
&worktrees_dir,
|
||||
&repos_dir,
|
||||
)?;
|
||||
|
||||
if !example
|
||||
.base
|
||||
.language_extension
|
||||
.as_ref()
|
||||
.map_or(false, |lang| args.languages.contains(lang))
|
||||
for thread in all_threads {
|
||||
let meta = thread.meta();
|
||||
if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
|
||||
{
|
||||
skipped.push(example.name);
|
||||
skipped.push(meta.name);
|
||||
continue;
|
||||
}
|
||||
|
||||
examples.extend(example.repeat(args.repetitions));
|
||||
if meta.language_server.map_or(false, |language| {
|
||||
!languages.contains(&language.file_extension)
|
||||
}) {
|
||||
skipped.push(meta.name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO: This creates a worktree per repetition. Ideally these examples should
|
||||
// either be run sequentially on the same worktree, or reuse worktrees when there
|
||||
// are more examples to run than the concurrency limit.
|
||||
for repetition_number in 0..args.repetitions {
|
||||
let example_instance = ExampleInstance::new(
|
||||
thread.clone(),
|
||||
&repos_dir,
|
||||
&run_dir,
|
||||
&worktrees_dir,
|
||||
repetition_number,
|
||||
);
|
||||
|
||||
examples.push(example_instance);
|
||||
}
|
||||
}
|
||||
|
||||
println!("Skipped examples: {}\n", skipped.join(", "));
|
||||
if !skipped.is_empty() {
|
||||
println!("Skipped threads: {}", skipped.join(", "));
|
||||
}
|
||||
|
||||
if examples.is_empty() {
|
||||
eprintln!("Filter matched no examples");
|
||||
|
@ -196,22 +200,23 @@ fn main() {
|
|||
|
||||
let max_name_width = examples
|
||||
.iter()
|
||||
.map(|e| e.repetition_name().len())
|
||||
.map(|e| e.worktree_name().len())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
for (i, example) in examples.iter_mut().enumerate() {
|
||||
|
||||
for (i, example_instance) in examples.iter_mut().enumerate() {
|
||||
let color = COLORS[i % COLORS.len()].to_string();
|
||||
example.set_log_prefix_style(&color, max_name_width);
|
||||
example_instance.set_log_prefix_style(&color, max_name_width);
|
||||
|
||||
println!(
|
||||
"{}Logging to: {}",
|
||||
example.log_prefix,
|
||||
example.run_directory_path().display()
|
||||
example_instance.log_prefix,
|
||||
example_instance.run_directory.display()
|
||||
);
|
||||
|
||||
let repo_url = example.base.url.clone();
|
||||
let repo_url = example_instance.repo_url();
|
||||
if repo_urls.insert(repo_url.clone()) {
|
||||
let repo_path = example.repo_path.clone();
|
||||
let repo_path = example_instance.repo_path.clone();
|
||||
|
||||
if !repo_path.join(".git").is_dir() {
|
||||
println!(
|
||||
|
@ -251,12 +256,12 @@ fn main() {
|
|||
|
||||
future::join_all(clone_tasks).await;
|
||||
|
||||
for example in examples.iter_mut() {
|
||||
example.fetch().await?;
|
||||
for example_instance in examples.iter_mut() {
|
||||
example_instance.fetch().await?;
|
||||
}
|
||||
|
||||
let examples = Arc::new(Mutex::new(VecDeque::from(examples)));
|
||||
let results_by_example_name = Arc::new(Mutex::new(HashMap::default()));
|
||||
let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
|
||||
let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
future::join_all((0..args.concurrency).map(|_| {
|
||||
let app_state = app_state.clone();
|
||||
|
@ -268,7 +273,7 @@ fn main() {
|
|||
let results = results_by_example_name.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
loop {
|
||||
let Some(mut example) = examples.lock().pop_front() else {
|
||||
let Some(mut example) = examples.borrow_mut().pop_front() else {
|
||||
break;
|
||||
};
|
||||
let result = async {
|
||||
|
@ -291,7 +296,7 @@ fn main() {
|
|||
}
|
||||
.await;
|
||||
results
|
||||
.lock()
|
||||
.borrow_mut()
|
||||
.entry(example.name.clone())
|
||||
.or_insert(Vec::new())
|
||||
.push((example.clone(), result));
|
||||
|
@ -300,98 +305,156 @@ fn main() {
|
|||
}))
|
||||
.await;
|
||||
|
||||
println!("\n\n");
|
||||
print_header("EVAL RESULTS");
|
||||
print_h1("EVAL RESULTS");
|
||||
|
||||
let mut diff_scores = Vec::new();
|
||||
let mut thread_scores = Vec::new();
|
||||
let mut programmatic_scores = Vec::new();
|
||||
let mut error_count = 0;
|
||||
|
||||
for (example_name, results) in results_by_example_name.lock().iter_mut() {
|
||||
print_header(&example_name);
|
||||
for (example_name, results) in results_by_example_name.borrow_mut().iter_mut() {
|
||||
print_h2(&example_name);
|
||||
|
||||
results.sort_unstable_by_key(|(example, _)| example.repetition);
|
||||
let mut example_cumulative_tool_metrics = ToolMetrics::default();
|
||||
|
||||
println!("┌───────┬──────┬────────┐");
|
||||
println!("│ Round │ Diff │ Thread │");
|
||||
println!("├───────┼──────┼────────┤");
|
||||
for (example, result) in results {
|
||||
let run_dir_path = example.run_directory_path();
|
||||
let relative_run_dir_path = run_dir_path.strip_prefix(root_dir).unwrap();
|
||||
let mut table_rows = String::new();
|
||||
|
||||
for (example, result) in results.iter() {
|
||||
match result {
|
||||
Err(err) => {
|
||||
println!(
|
||||
"|{:^7}│{:^6}│{:^8}│ {:?}{}",
|
||||
display_error_row(
|
||||
&mut table_rows,
|
||||
example.repetition,
|
||||
"N/A",
|
||||
"N/A",
|
||||
err,
|
||||
relative_run_dir_path.display()
|
||||
);
|
||||
err.to_string(),
|
||||
)?;
|
||||
error_count += 1;
|
||||
}
|
||||
Ok((run_output, judge_result)) => {
|
||||
Ok((run_output, judge_output)) => {
|
||||
cumulative_tool_metrics.merge(&run_output.tool_metrics);
|
||||
example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
|
||||
|
||||
match judge_result {
|
||||
Ok(judge_output) => {
|
||||
diff_scores.push(judge_output.diff.score());
|
||||
thread_scores.push(judge_output.thread.score());
|
||||
println!(
|
||||
"|{:^7}│{:^6}│{:^8}│ {}",
|
||||
if !run_output.programmatic_assertions.total_count() > 0 {
|
||||
for assertion in &run_output.programmatic_assertions.ran {
|
||||
assertions::display_table_row(
|
||||
&mut table_rows,
|
||||
example.repetition,
|
||||
format!("{}%", judge_output.diff.score()),
|
||||
format!("{}%", judge_output.thread.score()),
|
||||
relative_run_dir_path.display()
|
||||
);
|
||||
assertion,
|
||||
)?;
|
||||
}
|
||||
Err(err) => {
|
||||
println!(
|
||||
"|{:^7}│{:^6}│{:^8}│{:?}│ {}",
|
||||
|
||||
programmatic_scores
|
||||
.push(run_output.programmatic_assertions.passed_percentage())
|
||||
}
|
||||
|
||||
if !judge_output.diff.is_empty() {
|
||||
diff_scores.push(judge_output.diff.passed_percentage());
|
||||
|
||||
for assertion in &judge_output.diff.ran {
|
||||
assertions::display_table_row(
|
||||
&mut table_rows,
|
||||
example.repetition,
|
||||
"N/A",
|
||||
"N/A",
|
||||
err,
|
||||
relative_run_dir_path.display()
|
||||
);
|
||||
assertion,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
if !judge_output.thread.is_empty() {
|
||||
thread_scores.push(judge_output.thread.passed_percentage());
|
||||
|
||||
for assertion in &judge_output.thread.ran {
|
||||
assertions::display_table_row(
|
||||
&mut table_rows,
|
||||
example.repetition,
|
||||
assertion,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("└───────┴──────┴────────┘");
|
||||
println!("{}", example_cumulative_tool_metrics);
|
||||
if !table_rows.is_empty() {
|
||||
assertions::print_table_header();
|
||||
print!("{}", table_rows);
|
||||
|
||||
assertions::print_table_divider();
|
||||
|
||||
for (example, result) in results.iter() {
|
||||
if let Ok((run_output, judge_output)) = result {
|
||||
assertions::print_table_round_summary(
|
||||
&example.repetition.to_string(),
|
||||
[
|
||||
&run_output.programmatic_assertions,
|
||||
&judge_output.diff,
|
||||
&judge_output.thread,
|
||||
]
|
||||
.into_iter(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
assertions::print_table_divider();
|
||||
|
||||
assertions::print_table_round_summary(
|
||||
"avg",
|
||||
results.iter().flat_map(|(_, result)| {
|
||||
result.iter().flat_map(|(run_output, judge_output)| {
|
||||
[
|
||||
&run_output.programmatic_assertions,
|
||||
&judge_output.diff,
|
||||
&judge_output.thread,
|
||||
]
|
||||
.into_iter()
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
assertions::print_table_footer();
|
||||
}
|
||||
|
||||
if !example_cumulative_tool_metrics.is_empty() {
|
||||
println!("{}", &example_cumulative_tool_metrics);
|
||||
}
|
||||
}
|
||||
|
||||
let diff_score_count = diff_scores.len();
|
||||
let average_diff_score = diff_scores
|
||||
.into_iter()
|
||||
.map(|score| score as f32)
|
||||
.sum::<f32>()
|
||||
/ (diff_score_count as f32);
|
||||
if results_by_example_name.borrow().len() > 1 {
|
||||
print_h1("AGGREGATE");
|
||||
|
||||
if error_count > 0 {
|
||||
println!("\n{error_count} examples failed to run!");
|
||||
if error_count > 0 {
|
||||
println!("\n{error_count} examples failed to run!");
|
||||
}
|
||||
|
||||
let programmatic_score_count = programmatic_scores.len();
|
||||
if programmatic_score_count > 0 {
|
||||
let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
|
||||
/ (programmatic_score_count as f32))
|
||||
.floor();
|
||||
println!("Average programmatic score: {average_programmatic_score}%");
|
||||
}
|
||||
|
||||
let diff_score_count = diff_scores.len();
|
||||
if diff_score_count > 0 {
|
||||
let average_diff_score =
|
||||
(diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
|
||||
println!("Average diff score: {average_diff_score}%");
|
||||
}
|
||||
|
||||
let thread_score_count = thread_scores.len();
|
||||
|
||||
if thread_score_count > 0 {
|
||||
let average_thread_score = (thread_scores.into_iter().sum::<f32>()
|
||||
/ (thread_score_count as f32))
|
||||
.floor();
|
||||
println!("Average thread score: {average_thread_score}%");
|
||||
}
|
||||
|
||||
println!("");
|
||||
|
||||
print_h2("CUMULATIVE TOOL METRICS");
|
||||
println!("{}", cumulative_tool_metrics);
|
||||
}
|
||||
|
||||
println!("\nAverage code diff score: {average_diff_score}");
|
||||
|
||||
let thread_score_count = thread_scores.len();
|
||||
let average_thread_score = thread_scores
|
||||
.into_iter()
|
||||
.map(|score| score as f32)
|
||||
.sum::<f32>()
|
||||
/ (thread_score_count as f32);
|
||||
|
||||
println!("\nAverage thread score: {average_thread_score}");
|
||||
|
||||
print_header("CUMULATIVE TOOL METRICS");
|
||||
println!("{}", cumulative_tool_metrics);
|
||||
|
||||
app_state.client.telemetry().flush_events().await;
|
||||
|
||||
cx.update(|cx| cx.quit())
|
||||
|
@ -400,20 +463,6 @@ fn main() {
|
|||
});
|
||||
}
|
||||
|
||||
fn list_all_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
|
||||
let path = std::fs::canonicalize(examples_dir).unwrap();
|
||||
let entries = std::fs::read_dir(path).unwrap();
|
||||
let mut result_paths = Vec::new();
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
result_paths.push(path);
|
||||
}
|
||||
}
|
||||
Ok(result_paths)
|
||||
}
|
||||
|
||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||
pub struct AgentAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
|
@ -570,7 +619,7 @@ pub fn git_branch_for_path(repo_path: &Path) -> String {
|
|||
}
|
||||
|
||||
async fn judge_example(
|
||||
example: Example,
|
||||
example: ExampleInstance,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
zed_commit_sha: &str,
|
||||
zed_branch_name: &str,
|
||||
|
@ -578,19 +627,9 @@ async fn judge_example(
|
|||
run_output: &RunOutput,
|
||||
enable_telemetry: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
) -> JudgeOutput {
|
||||
let judge_output = example.judge(model.clone(), &run_output, cx).await;
|
||||
|
||||
let diff_evaluation;
|
||||
let thread_evaluation;
|
||||
if let Ok(output) = judge_output.as_ref() {
|
||||
diff_evaluation = Some(output.diff.clone());
|
||||
thread_evaluation = Some(output.thread.clone());
|
||||
} else {
|
||||
diff_evaluation = None;
|
||||
thread_evaluation = None;
|
||||
}
|
||||
|
||||
if enable_telemetry {
|
||||
telemetry::event!(
|
||||
"Agent Example Evaluated",
|
||||
|
@ -599,15 +638,15 @@ async fn judge_example(
|
|||
run_id = run_id,
|
||||
example_name = example.name.clone(),
|
||||
example_repetition = example.repetition,
|
||||
diff_evaluation = diff_evaluation,
|
||||
thread_evaluation = thread_evaluation,
|
||||
diff_evaluation = judge_output.diff.clone(),
|
||||
thread_evaluation = judge_output.thread.clone(),
|
||||
tool_metrics = run_output.tool_metrics,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.base.url.clone(),
|
||||
repository_revision = example.base.revision.clone(),
|
||||
repository_url = example.repo_url(),
|
||||
repository_revision = example.revision(),
|
||||
diagnostic_summary_before = run_output.diagnostic_summary_before,
|
||||
diagnostic_summary_after = run_output.diagnostic_summary_after,
|
||||
diagnostics_before = run_output.diagnostics_before,
|
||||
|
@ -618,8 +657,16 @@ async fn judge_example(
|
|||
judge_output
|
||||
}
|
||||
|
||||
fn print_header(header: &str) {
|
||||
println!("\n========================================");
|
||||
println!("{:^40}", header);
|
||||
println!("========================================\n");
|
||||
const HEADER_WIDTH: usize = 65;
|
||||
|
||||
fn print_h1(header: &str) {
|
||||
println!("\n\n{:=^HEADER_WIDTH$}", "");
|
||||
println!("{:^HEADER_WIDTH$}", header);
|
||||
println!("{:=^HEADER_WIDTH$}\n", "");
|
||||
}
|
||||
|
||||
fn print_h2(header: &str) {
|
||||
println!("\n{:-^HEADER_WIDTH$}", "");
|
||||
println!("{:^HEADER_WIDTH$}", header);
|
||||
println!("{:-^HEADER_WIDTH$}\n", "");
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue