mod example;
mod ids;
mod tool_metrics;

pub(crate) use example::*;
pub(crate) use tool_metrics::*;

use ::fs::RealFs;
use anyhow::{Result, anyhow};
use clap::Parser;
use client::{Client, ProxySettings, UserStore};
use collections::HashSet;
use extension::ExtensionHostProxy;
use futures::{StreamExt, future};
use gpui::http_client::{Uri, read_proxy_from_env};
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
use project::project_settings::ProjectSettings;
use prompt_store::PromptBuilder;
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::usize;
use util::ResultExt as _;

/// Directory where each eval run writes its timestamped output.
pub const RUNS_DIR: &str = "./crates/eval/runs";

/// Command-line interface of the eval harness.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    // FIX: the element type was missing (bare `Vec`); the substring filters are plain
    // strings collected from positional arguments.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    // FIX: the element type was missing (bare `Vec`); values are file-extension strings
    // compared against each example's `language_extension`.
    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
    languages: Vec<String>,
    /// How many times to run each example. Note that this is currently not very efficient as N
    /// worktrees will be created for the examples.
    #[arg(long, default_value = "1")]
    repetitions: u32,
    /// How many times to run the judge on each example run.
    #[arg(long, default_value = "3")]
    judge_repetitions: u32,
    /// Maximum number of examples to run concurrently.
#[arg(long, default_value = "10")]
concurrency: usize,
}

fn main() {
    env_logger::init();
    let args = Args::parse();

    // Discover every example directory, then keep those whose name matches any of the
    // CLI substring filters (an empty filter list selects everything).
    let all_available_examples = list_all_examples().unwrap();
    let example_paths = all_available_examples
        .iter()
        .filter_map(|example_path| {
            let name = example_path.file_name()?.to_string_lossy();
            if args.examples.is_empty()
                || args
                    .examples
                    .iter()
                    .any(|name_substring| name.contains(name_substring))
            {
                Some(example_path.clone())
            } else {
                None
            }
        })
        // FIX: the turbofish parameter was stripped (`collect::>()`).
        .collect::<Vec<_>>();

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client.clone());

    app.run(move |cx| {
        let app_state = init(cx);

        // Telemetry identity: stable per-machine ids plus a fresh per-run session id.
        let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
        let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
        let session_id = uuid::Uuid::new_v4().to_string();
        app_state
            .client
            .telemetry()
            .start(system_id, installation_id, session_id, cx);

        let mut cumulative_tool_metrics = ToolMetrics::default();

        let model_registry = LanguageModelRegistry::read_global(cx);
        // FIX: the model name was hard-coded to "claude-3-7-sonnet-latest", silently
        // ignoring the `--model` flag; use the parsed argument (same default value).
        let model = find_model(&args.model, model_registry, cx).unwrap();
        let model_provider_id = model.provider_id();
        let model_provider = model_registry.provider(&model_provider_id).unwrap();

        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.set_default_model(
                Some(ConfiguredModel {
                    provider: model_provider.clone(),
                    model: model.clone(),
                }),
                cx,
            );
        });

        let authenticate_task = model_provider.authenticate(cx);

        cx.spawn(async move |cx| {
            authenticate_task.await.unwrap();

            std::fs::create_dir_all(REPOS_DIR)?;
            std::fs::create_dir_all(WORKTREES_DIR)?;

            // Each run gets its own timestamped output directory under RUNS_DIR.
            let run_dir = Path::new(RUNS_DIR).join(format!(
                "{}",
                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
            ));
            std::fs::create_dir_all(&run_dir)?;

            let mut examples = Vec::new();

            // Per-example ANSI color codes so interleaved concurrent logs stay readable.
            const COLORS: [&str; 12] = [
                "\x1b[31m", // Red
                "\x1b[32m", // Green
                "\x1b[33m", // Yellow
                "\x1b[34m", // Blue
                "\x1b[35m", // Magenta
                "\x1b[36m", // Cyan
                "\x1b[91m", // Bright Red
"\x1b[92m", // Bright Green
"\x1b[93m", // Bright Yellow
"\x1b[94m", // Bright Blue
"\x1b[95m", // Bright Magenta
"\x1b[96m", // Bright Cyan
];

let mut max_name_width = 0;
let mut skipped = Vec::new();

// Load each selected example and expand it into `repetitions` copies, skipping
// examples whose language extension is not in `--languages`. Also tracks the widest
// example name for aligned log prefixes.
for example_path in &example_paths {
    let example = Example::load_from_directory(example_path, &run_dir)?;

    if !example
        .base
        .language_extension
        .as_ref()
        .map_or(false, |lang| args.languages.contains(lang))
    {
        skipped.push(example.name);
        continue;
    }

    // TODO: This creates a worktree per repetition. Ideally these examples should
    // either be run sequentially on the same worktree, or reuse worktrees when there
    // are more examples to run than the concurrency limit.
    for repetition_number in 0..args.repetitions {
        let mut example = example.clone();
        example.set_repetition_number(repetition_number);
        let name_len = example.name.len();
        if name_len > max_name_width {
            max_name_width = example.name.len();
        }
        examples.push(example);
    }
}

println!("Skipped examples: {}\n", skipped.join(", "));

if examples.is_empty() {
    eprintln!("Filter matched no examples");
    return cx.update(|cx| cx.quit());
}

// Deduplicate repository URLs so each repo is cloned at most once.
let mut repo_urls = HashSet::default();
let mut clone_tasks = Vec::new();

for (i, example) in examples.iter_mut().enumerate() {
    let color = COLORS[i % COLORS.len()].to_string();
    example.set_log_prefix_style(&color, max_name_width);

    println!(
        "{}Logging to: {}",
        example.log_prefix,
        example.example_output_directory().display()
    );

    let repo_url = example.base.url.clone();
    if repo_urls.insert(repo_url.clone()) {
        let repo_path = repo_path_for_url(&repo_url);

        // NOTE(review): the source appears truncated/corrupted from here through the
        // `.await` below — the code that clones missing repositories and (judging by
        // `clone_tasks` above and `results` below) awaits the clones and runs the
        // examples concurrently has been lost. Preserved verbatim; reconstruct from
        // version control history before building.
        if !repo_path.join(".git").is_dir() { println!( "{:>() .await;

println!("\n\n");
print_header("EVAL RESULTS");

// Aggregate judge scores across all examples; errors are counted separately.
let mut diff_scores = Vec::new();
let mut thread_scores = Vec::new();
let mut error_count = 0;

for (example, result) in results {
    print_header(&example.name);

    match result {
        Err(err) => {
            println!("💥 {}{:?}", example.log_prefix, err);
            error_count += 1;
        }
        Ok((run_output, judge_results)) => {
cumulative_tool_metrics.merge(&run_output.tool_metrics);

// Render one table row per judge repetition; the thread score column is "N/A"
// when the judge was not asked to evaluate the thread.
println!("┌───────┬──────┬────────┐");
println!("│ Judge │ Diff │ Thread │");
println!("├───────┼──────┼────────┤");

for (i, judge_result) in judge_results.iter().enumerate() {
    match judge_result {
        Ok(judge_output) => {
            let diff_score = judge_output.diff.score;
            diff_scores.push(diff_score);

            let thread_display = if let Some(thread) = &judge_output.thread {
                let thread_score = thread.score;
                thread_scores.push(thread_score);
                format!("{}", thread_score)
            } else {
                "N/A".to_string()
            };

            // NOTE(review): the first column separator is an ASCII '|' while the
            // rest are box-drawing '│'; kept byte-identical to preserve output.
            println!(
                "|{:^7}│{:^6}│{:^8}│",
                i + 1,
                diff_score,
                thread_display
            );
        }
        Err(err) => {
            println!("|{:^7}│{:^6}│{:^8}│{:?}", i + 1, "N/A", "N/A", err);
        }
    }
}

println!("└───────┴──────┴────────┘");
println!("{}", run_output.tool_metrics);
}
}

println!(
    "{} > {}",
    " ".repeat(max_name_width),
    example.example_output_directory().display()
);
}

let diff_score_count = diff_scores.len();
// FIX: the turbofish parameter was stripped (`sum::()`); the division below
// requires f32. Note: printed only when diff_score_count > 0, so the potential
// NaN from a zero count is never shown.
let average_diff_score = diff_scores
    .into_iter()
    .map(|score| score as f32)
    .sum::<f32>()
    / (diff_score_count as f32);

if error_count > 0 {
    println!("\n{error_count} examples failed to run!");
}

if diff_score_count > 0 {
    println!("\nAverage code diff score: {average_diff_score}");
}

let thread_score_count = thread_scores.len();
// We might have gotten no thread scores if we weren't asked to judge the thread.
if thread_score_count > 0 {
    let average_thread_score = thread_scores
        .into_iter()
        .map(|score| score as f32)
        // FIX: the turbofish parameter was stripped (`sum::()`).
        .sum::<f32>()
        / (thread_score_count as f32);
    // FIX: this print was guarded by `diff_score_count > 0` — an apparent copy/paste
    // of the diff-score guard — which hid thread scores whenever no diff scores were
    // collected. We already know thread_score_count > 0 here, so print directly.
    println!("\nAverage thread score: {average_thread_score}");
}

print_header("CUMULATIVE TOOL METRICS");
println!("{}", cumulative_tool_metrics);

// Give telemetry a moment to batch before flushing and quitting.
std::thread::sleep(std::time::Duration::from_secs(2));
app_state.client.telemetry().flush_events();

cx.update(|cx| cx.quit())
})
.detach_and_log_err(cx);
});
}

/// Returns the path of every example directory under `EXAMPLES_DIR`.
///
/// # Errors
/// Propagates any I/O error from canonicalizing or reading the directory.
// FIX: the return type parameter was stripped (`Result>`); also the first two
// fallible calls used `unwrap()` despite the `Result` return — propagate with `?`.
fn list_all_examples() -> Result<Vec<PathBuf>> {
    let path = std::fs::canonicalize(EXAMPLES_DIR)?;
    let entries = std::fs::read_dir(path)?;
    let mut result_paths = Vec::new();
    for entry in entries {
        let entry = entry?;
        let path = entry.path();
        // Only directories are examples; stray files are ignored.
        if path.is_dir() {
            result_paths.push(path);
        }
    }
    Ok(result_paths)
}

/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
// FIX: all generic parameters on the field types were stripped; reconstructed from the
// values assigned in `init` below.
pub struct AgentAppState {
    pub languages: Arc<LanguageRegistry>,
    pub client: Arc<Client>,
    pub user_store: Entity<UserStore>,
    // `init` stores an `Arc<RealFs>` here; declared as the trait object so consumers
    // stay fs-agnostic. NOTE(review): assumes the `fs` crate exposes trait `Fs` — confirm.
    pub fs: Arc<dyn ::fs::Fs>,
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}

/// Builds the subset of app state the headless agent needs: settings, HTTP/RPC
/// clients, language registry, node runtime, extension host, and prompt builder.
// FIX: the return type parameter was stripped (`-> Arc`).
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );

    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    // FIX: the turbofish parameter was stripped (`parse::()`); the target is `Uri`
    // (imported above). Falls back to the proxy environment variables.
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        let _guard = Tokio::handle(cx).enter();
        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // The production client supersedes the plain reqwest client set above.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Re-derive node binary options whenever the global settings change and feed them
    // to the node runtime through a watch channel.
    let (tx, rx) = async_watch::channel(None);
    // FIX: the turbofish parameter was stripped (`observe_global::()`); the closure
    // reads `ProjectSettings`, which hangs off the global `SettingsStore`.
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    npm_path.unwrap_or_else(|| {
                        // Default npm to a sibling of the configured node binary.
                        let base_path = PathBuf::new();
node_path.parent().unwrap_or(&base_path).join("npm")
}),
)
}),
};
tx.send(Some(options)).log_err();
})
.detach();
let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

let extension_host_proxy = ExtensionHostProxy::global(cx);

// Wire up the language stack, model providers, tools, and prompt infrastructure.
language::init(cx);
language_extension::init(extension_host_proxy.clone(), languages.clone());
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
prompt_store::init(cx);

let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);

agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

// Overlay the eval runner's settings on top of the defaults loaded earlier.
SettingsStore::update_global(cx, |store, cx| {
    store.set_user_settings(include_str!("../runner_settings.json"), cx)
})
.unwrap();

Arc::new(AgentAppState {
    languages,
    client,
    user_store,
    fs,
    node_runtime,
    prompt_builder,
})
}

/// Looks up `model_name` among the models currently available in `model_registry`.
///
/// # Errors
/// Returns an error listing all available models when no model with that id exists.
// FIX: the return type parameters were stripped (`anyhow::Result>`); the function
// yields the matching model handle.
pub fn find_model(
    model_name: &str,
    model_registry: &LanguageModelRegistry,
    cx: &App,
) -> anyhow::Result<Arc<dyn LanguageModel>> {
    let model = model_registry
        .available_models(cx)
        .find(|model| model.id().0 == model_name);

    let Some(model) = model else {
        return Err(anyhow!(
            "No language model named {} was available.
Available models: {}", model_name, model_registry .available_models(cx) .map(|model| model.id().0.clone()) .collect::>() .join(", ") )); }; Ok(model) } pub async fn get_current_commit_id(repo_path: &Path) -> Option { (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok() } pub fn get_current_commit_id_sync(repo_path: &Path) -> String { futures::executor::block_on(async { get_current_commit_id(repo_path).await.unwrap_or_default() }) } async fn run_judge_repetition( example: Example, model: Arc, run_output: &RunOutput, round: u32, cx: &AsyncApp, ) -> Result { let judge_result = example.judge(model.clone(), &run_output, round, cx).await; if let Ok(judge_output) = &judge_result { let cohort_id = example .run_directory_path .file_name() .map(|name| name.to_string_lossy().to_string()) .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string()); let path = std::path::Path::new("."); let commit_id = get_current_commit_id(path).await.unwrap_or_default(); if let Some(thread) = &judge_output.thread { telemetry::event!( "Agent Eval Completed", cohort_id = cohort_id, example_name = example.name.clone(), round = round, diff_score = judge_output.diff.score, diff_analysis = judge_output.diff.analysis, thread_score = thread.score, thread_analysis = thread.analysis, tool_metrics = run_output.tool_metrics, response_count = run_output.response_count, token_usage = run_output.token_usage, model = model.telemetry_id(), model_provider = model.provider_id().to_string(), repository_url = example.base.url.clone(), repository_revision = example.base.revision.clone(), diagnostics_before = run_output.diagnostics_before, diagnostics_after = run_output.diagnostics_after, commit_id = commit_id ); } else { telemetry::event!( "Agent Eval Completed", cohort_id = cohort_id, example_name = example.name.clone(), round = round, diff_score = judge_output.diff.score, diff_analysis = judge_output.diff.analysis, tool_metrics = run_output.tool_metrics, response_count = 
run_output.response_count,
token_usage = run_output.token_usage,
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
repository_url = example.base.url.clone(),
repository_revision = example.base.revision.clone(),
diagnostics_before = run_output.diagnostics_before,
diagnostics_after = run_output.diagnostics_after,
commit_id = commit_id
);
}
}

// Hand the judge's verdict (or error) back to the caller unchanged.
judge_result
}

// Prints `header` centered inside a fixed-width (40-column) '=' banner.
fn print_header(header: &str) {
    println!("\n========================================");
    println!("{:^40}", header);
    println!("========================================\n");
}