ZIm/crates/assistant_eval/src/main.rs
Mikayla Maki 1aefa5178b
Move "async move" a few characters to the left in cx.spawn() (#26758)
This is the core change:
https://github.com/zed-industries/zed/pull/26758/files#diff-044302c0d57147af17e68a0009fee3e8dcdfb4f32c27a915e70cfa80e987f765R1052

TODO:
- [x] Use AsyncFn instead of Fn() -> Future in GPUI spawn methods
- [x] Implement it in the whole app
- [x] Implement it in the debugger 
- [x] Glance at the RPC crate, and see if those box future methods can
be switched over. Answer: It can't directly, as you can't make an
AsyncFn* into a trait object. There's ways around that, but they're all
more complex than just keeping the code as is.
- [ ] Fix platform specific code

Release Notes:

- N/A
2025-03-19 02:09:02 +00:00

243 lines
8.9 KiB
Rust

mod eval;
mod headless_assistant;
mod judge;
use clap::Parser;
use eval::{Eval, EvalOutput};
use futures::future;
use gpui::{Application, AsyncApp};
use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
use itertools::Itertools;
use judge::Judge;
use language_model::{LanguageModel, LanguageModelRegistry};
use regex::Regex;
use reqwest_client::ReqwestClient;
use std::{cmp, path::PathBuf, sync::Arc};
#[derive(Parser, Debug)]
#[command(
name = "assistant_eval",
disable_version_flag = true,
before_help = "Tool eval runner"
)]
struct Args {
/// Regexes to match the names of evals to run.
eval_name_regexes: Vec<String>,
/// Runs all evals in `evaluation_data`, causes the regex to be ignored.
#[arg(long)]
all: bool,
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
/// Name of the editor model (default: value of `--model_name`).
#[arg(long)]
editor_model_name: Option<String>,
/// Name of the judge model (default: value of `--model_name`).
#[arg(long)]
judge_model_name: Option<String>,
/// Number of evaluations to run concurrently (default: 10)
#[arg(short, long, default_value = "10")]
concurrency: usize,
}
fn main() {
env_logger::init();
let args = Args::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
let crate_dir = PathBuf::from("../zed-agent-bench");
let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
let repos_dir = crate_dir.join("repos");
if !repos_dir.exists() {
std::fs::create_dir_all(&repos_dir).unwrap();
}
let repos_dir = repos_dir.canonicalize().unwrap();
let all_evals = std::fs::read_dir(&evaluation_data_dir)
.unwrap()
.map(|path| path.unwrap().file_name().to_string_lossy().to_string())
.collect::<Vec<_>>();
let evals_to_run = if args.all {
all_evals
} else {
args.eval_name_regexes
.into_iter()
.map(|regex_string| Regex::new(&regex_string).unwrap())
.flat_map(|regex| {
all_evals
.iter()
.filter(|eval_name| regex.is_match(eval_name))
.cloned()
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
};
if evals_to_run.is_empty() {
panic!("Names of evals to run must be provided or `--all` specified");
}
println!("Will run the following evals: {evals_to_run:?}");
println!("Running up to {} evals concurrently", args.concurrency);
let editor_model_name = if let Some(model_name) = args.editor_model_name {
model_name
} else {
args.model_name.clone()
};
let judge_model_name = if let Some(model_name) = args.judge_model_name {
model_name
} else {
args.model_name.clone()
};
app.run(move |cx| {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = find_model(&editor_model_name, cx).unwrap();
let judge_model = find_model(&judge_model_name, cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_active_model(Some(model.clone()), cx);
registry.set_editor_model(Some(editor_model.clone()), cx);
});
let model_provider_id = model.provider_id();
let editor_model_provider_id = editor_model.provider_id();
let judge_model_provider_id = judge_model.provider_id();
cx.spawn(async move |cx| {
// Authenticate all model providers first
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
let eval_load_futures = evals_to_run
.into_iter()
.map(|eval_name| {
let eval_path = evaluation_data_dir.join(&eval_name);
let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
async move {
match load_future.await {
Ok(eval) => Some(eval),
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error loading {eval_name}: {err}");
None
}
}
}
})
.collect::<Vec<_>>();
let loaded_evals = future::join_all(eval_load_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
// The evals need to be loaded and grouped by URL before concurrently running, since
// evals that use the same remote URL will use the same working directory.
let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
.into_iter()
.map(|eval| (eval.eval_setup.url.clone(), eval))
.into_group_map()
.into_values()
.collect::<Vec<_>>();
// Sort groups in descending order, so that bigger groups start first.
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
let result_futures = evals_grouped_by_url
.into_iter()
.map(|evals| {
let model = model.clone();
let judge_model = judge_model.clone();
let app_state = app_state.clone();
let cx = cx.clone();
async move {
let mut results = Vec::new();
for eval in evals {
let name = eval.name.clone();
println!("Starting eval named {}", name);
let result = run_eval(
eval,
model.clone(),
judge_model.clone(),
app_state.clone(),
cx.clone(),
)
.await;
results.push((name, result));
}
results
}
})
.collect::<Vec<_>>();
let results = future::join_all(result_futures)
.await
.into_iter()
.flatten()
.collect::<Vec<_>>();
// Process results in order of completion
for (eval_name, result) in results {
match result {
Ok((eval_output, judge_output)) => {
println!("Generated diff for {eval_name}:\n");
println!("{}\n", eval_output.diff);
println!("Last message for {eval_name}:\n");
println!("{}\n", eval_output.last_message);
println!("Elapsed time: {:?}", eval_output.elapsed_time);
println!(
"Assistant response count: {}",
eval_output.assistant_response_count
);
println!("Tool use counts: {:?}", eval_output.tool_use_counts);
println!("Judge output for {eval_name}: {judge_output}");
}
Err(err) => {
// TODO: Persist errors / surface errors at the end.
println!("Error running {eval_name}: {err}");
}
}
}
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
});
println!("Done running evals");
}
async fn run_eval(
eval: Eval,
model: Arc<dyn LanguageModel>,
judge_model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
cx: AsyncApp,
) -> anyhow::Result<(EvalOutput, String)> {
let path = eval.path.clone();
let judge = Judge::load(&path, judge_model).await?;
let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
eval_output.save_to_directory(&path, judge_output.to_string())?;
Ok((eval_output, judge_output))
}