From ce1a674ebad2a4ecc48dea029f73660483c74a64 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Tue, 22 Apr 2025 23:58:58 -0300 Subject: [PATCH] eval: Fine-grained assertions (#29246) - Support programmatic examples ([example](https://github.com/zed-industries/zed/blob/17feb260a0919e102ae4e220669c467885bf4b57/crates/eval/src/examples/file_search.rs)) - Combine data-driven example declarations into a single `.toml` file ([example](https://github.com/zed-industries/zed/blob/17feb260a0919e102ae4e220669c467885bf4b57/crates/eval/src/examples/find_and_replace_diff_card.toml)) - Run judge on individual assertions (previously called "criteria") - Report judge and programmatic assertions in one combined table Note: We still need to work on concept naming Release Notes: - N/A --------- Co-authored-by: Richard Feldman Co-authored-by: Max Brunsfeld Co-authored-by: Thomas Mickley-Doyle --- Cargo.lock | 4 +- crates/agent/src/thread.rs | 19 +- crates/assistant_tools/src/assistant_tools.rs | 2 + crates/eval/Cargo.toml | 5 +- .../find_and_replace_diff_card/base.toml | 3 - .../diff_criteria.md | 2 - .../find_and_replace_diff_card/prompt.md | 3 - .../thread_criteria.md | 3 - crates/eval/src/assertions.rs | 157 ++ crates/eval/src/eval.rs | 351 +++-- crates/eval/src/example.rs | 1362 ++++------------- crates/eval/src/examples/file_search.rs | 53 + .../examples/find_and_replace_diff_card.toml | 43 + crates/eval/src/examples/mod.rs | 128 ++ crates/eval/src/instance.rs | 1023 +++++++++++++ crates/eval/src/judge_diff_prompt.hbs | 18 +- crates/eval/src/judge_thread_prompt.hbs | 16 +- crates/eval/src/tool_metrics.rs | 6 +- 18 files changed, 1969 insertions(+), 1229 deletions(-) delete mode 100644 crates/eval/examples/find_and_replace_diff_card/base.toml delete mode 100644 crates/eval/examples/find_and_replace_diff_card/diff_criteria.md delete mode 100644 crates/eval/examples/find_and_replace_diff_card/prompt.md delete mode 100644 
crates/eval/examples/find_and_replace_diff_card/thread_criteria.md create mode 100644 crates/eval/src/assertions.rs create mode 100644 crates/eval/src/examples/file_search.rs create mode 100644 crates/eval/src/examples/find_and_replace_diff_card.toml create mode 100644 crates/eval/src/examples/mod.rs create mode 100644 crates/eval/src/instance.rs diff --git a/Cargo.lock b/Cargo.lock index 75594ed5ac..04dde837b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4895,6 +4895,7 @@ dependencies = [ "anyhow", "assistant_tool", "assistant_tools", + "async-trait", "async-watch", "chrono", "clap", @@ -4915,13 +4916,14 @@ dependencies = [ "language_models", "languages", "node_runtime", - "parking_lot", "paths", "project", "prompt_store", + "regex", "release_channel", "reqwest_client", "serde", + "serde_json", "settings", "shellexpand 2.1.2", "smol", diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 2e5198a462..25170840d4 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -315,6 +315,7 @@ pub struct Thread { request_callback: Option< Box])>, >, + remaining_turns: u32, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -368,6 +369,7 @@ impl Thread { message_feedback: HashMap::default(), last_auto_capture_at: None, request_callback: None, + remaining_turns: u32::MAX, } } @@ -442,6 +444,7 @@ impl Thread { message_feedback: HashMap::default(), last_auto_capture_at: None, request_callback: None, + remaining_turns: u32::MAX, } } @@ -522,7 +525,7 @@ impl Thread { self.messages.iter().find(|message| message.id == id) } - pub fn messages(&self) -> impl Iterator { + pub fn messages(&self) -> impl ExactSizeIterator { self.messages.iter() } @@ -958,7 +961,21 @@ impl Thread { }) } + pub fn remaining_turns(&self) -> u32 { + self.remaining_turns + } + + pub fn set_remaining_turns(&mut self, remaining_turns: u32) { + self.remaining_turns = remaining_turns; + } + pub fn send_to_model(&mut self, model: Arc, cx: &mut Context) { + if 
self.remaining_turns == 0 { + return; + } + + self.remaining_turns -= 1; + let mut request = self.to_completion_request(cx); if model.supports_tools() { request.tools = { diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 408d5655b0..86e000e3b2 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -56,6 +56,8 @@ use crate::symbol_info_tool::SymbolInfoTool; use crate::terminal_tool::TerminalTool; use crate::thinking_tool::ThinkingTool; +pub use path_search_tool::PathSearchToolInput; + pub fn init(http_client: Arc, cx: &mut App) { assistant_tool::init(cx); diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index 23780222a7..b0f989b6ad 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -9,6 +9,7 @@ agent.workspace = true anyhow.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true +async-trait.workspace = true async-watch.workspace = true chrono.workspace = true clap.workspace = true @@ -29,13 +30,14 @@ language_model.workspace = true language_models.workspace = true languages = { workspace = true, features = ["load-grammars"] } node_runtime.workspace = true -parking_lot.workspace = true paths.workspace = true project.workspace = true prompt_store.workspace = true +regex.workspace = true release_channel.workspace = true reqwest_client.workspace = true serde.workspace = true +serde_json.workspace = true settings.workspace = true shellexpand.workspace = true smol.workspace = true @@ -45,7 +47,6 @@ unindent.workspace = true util.workspace = true uuid = { version = "1.6", features = ["v4"] } workspace-hack.workspace = true - [[bin]] name = "eval" path = "src/eval.rs" diff --git a/crates/eval/examples/find_and_replace_diff_card/base.toml b/crates/eval/examples/find_and_replace_diff_card/base.toml deleted file mode 100644 index c88298997d..0000000000 --- 
a/crates/eval/examples/find_and_replace_diff_card/base.toml +++ /dev/null @@ -1,3 +0,0 @@ -url = "https://github.com/zed-industries/zed.git" -revision = "38fcadf9481d018543c65f36ac3bafeba190179b" -language_extension = "rs" diff --git a/crates/eval/examples/find_and_replace_diff_card/diff_criteria.md b/crates/eval/examples/find_and_replace_diff_card/diff_criteria.md deleted file mode 100644 index 12290f66ed..0000000000 --- a/crates/eval/examples/find_and_replace_diff_card/diff_criteria.md +++ /dev/null @@ -1,2 +0,0 @@ -- The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct. The struct should contain an `output` field that is the same as the task we were returning before, and a new `card` field that contains a view for the card. -- The card should be a view that displays a diff. Each line in the diff should be colored according to whether it was added, removed or unchanged. diff --git a/crates/eval/examples/find_and_replace_diff_card/prompt.md b/crates/eval/examples/find_and_replace_diff_card/prompt.md deleted file mode 100644 index a4c2cfdb0c..0000000000 --- a/crates/eval/examples/find_and_replace_diff_card/prompt.md +++ /dev/null @@ -1,3 +0,0 @@ -Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should implement the `Render` trait. - -The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line. 
diff --git a/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md b/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md deleted file mode 100644 index 79144bb1a0..0000000000 --- a/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md +++ /dev/null @@ -1,3 +0,0 @@ -- The first tool call should be to path search including "find_replace_file_tool.rs" in the string. (*Not* grep, for example, or reading the file based on a guess at the path.) This is because we gave the model a filename and it needs to turn that into a real path. -- After obtaining the correct path of "zed/crates/assistant_tools/src/find_replace_file_tool.rs", it should read the contents of that path. -- When trying to find information about the Render trait, it should *not* begin with a path search, because it doesn't yet have any information on what path the Render trait might be in. diff --git a/crates/eval/src/assertions.rs b/crates/eval/src/assertions.rs new file mode 100644 index 0000000000..c021694401 --- /dev/null +++ b/crates/eval/src/assertions.rs @@ -0,0 +1,157 @@ +use serde::{Deserialize, Serialize}; +use std::fmt::Write; +use std::fmt::{self}; + +#[derive(Default, Debug, Serialize, Deserialize, Clone)] +pub struct AssertionsReport { + pub ran: Vec, + pub max: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RanAssertion { + pub id: String, + pub result: Result, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RanAssertionResult { + pub analysis: Option, + pub passed: bool, +} + +impl AssertionsReport { + pub fn new(max: Option) -> Self { + AssertionsReport { + ran: Vec::new(), + max, + } + } + + pub fn is_empty(&self) -> bool { + self.ran.is_empty() + } + + pub fn total_count(&self) -> usize { + self.run_count().max(self.max.unwrap_or(0)) + } + + pub fn run_count(&self) -> usize { + self.ran.len() + } + + pub fn passed_count(&self) -> usize { + self.ran + .iter() + .filter(|a| a.result.as_ref().map_or(false, 
|result| result.passed)) + .count() + } + + pub fn passed_percentage(&self) -> f32 { + if self.total_count() == 0 { + 0.0 + } else { + (self.passed_count() as f32 / self.total_count() as f32) * 100.0 + } + } +} + +const ROUND_WIDTH: usize = "Round".len(); +const ASSERTIONS_WIDTH: usize = 42; +const RESULTS_WIDTH: usize = 8; + +pub fn print_table_header() { + println!( + "┌─{}─┬─{}─┬─{}─┐", + "─".repeat(ROUND_WIDTH), + "─".repeat(ASSERTIONS_WIDTH), + "─".repeat(RESULTS_WIDTH) + ); + + println!( + "│ {:^ROUND_WIDTH$} │ {:^ASSERTIONS_WIDTH$} │ {:^RESULTS_WIDTH$} │", + "Round", "Assertion", "Result" + ); + + println!( + "├─{}─┼─{}─┼─{}─┤", + "─".repeat(ROUND_WIDTH), + "─".repeat(ASSERTIONS_WIDTH), + "─".repeat(RESULTS_WIDTH) + ) +} + +pub fn display_error_row(f: &mut String, round: usize, error: String) -> fmt::Result { + let last_two_columns = ASSERTIONS_WIDTH + RESULTS_WIDTH; + writeln!( + f, + "│ {:^ROUND_WIDTH$} │ {: fmt::Result { + let result = match &assertion.result { + Ok(result) if result.passed => "\x1b[32m✔︎ Passed\x1b[0m", + Ok(_) => "\x1b[31m✗ Failed\x1b[0m", + Err(_) => "\x1b[31m💥 Judge Error\x1b[0m", + }; + + writeln!( + f, + "│ {:^ROUND_WIDTH$} │ {:RESULTS_WIDTH$} │", + round, + truncate(&assertion.id, ASSERTIONS_WIDTH), + result + ) +} + +pub fn print_table_round_summary<'a>( + round: &str, + reports: impl Iterator, +) { + let mut passed = 0; + let mut total = 0; + for report in reports { + passed += report.passed_count(); + total += report.total_count(); + } + + println!( + "│ {:^ROUND_WIDTH$} │ {:RESULTS_WIDTH$} │", + round, + "total", + format!("{}%", (passed as f32 / total as f32 * 100.0).floor()) + ) +} + +pub fn print_table_footer() { + println!( + "└─{}─┴─{}─┴─{}─┘", + "─".repeat(ROUND_WIDTH), + "─".repeat(ASSERTIONS_WIDTH), + "─".repeat(RESULTS_WIDTH) + ) +} + +pub fn print_table_divider() { + println!( + "├─{}─┼─{}─┼─{}─┤", + "─".repeat(ROUND_WIDTH), + "─".repeat(ASSERTIONS_WIDTH), + "─".repeat(RESULTS_WIDTH) + ) +} + +fn truncate(assertion: 
&str, max_width: usize) -> String { + if assertion.len() <= max_width { + assertion.to_string() + } else { + let mut end_ix = max_width - 1; + while !assertion.is_char_boundary(end_ix) { + end_ix -= 1; + } + format!("{}…", &assertion[..end_ix]) + } +} diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 516c34b53a..1873adbb61 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -1,13 +1,16 @@ +mod assertions; mod example; +mod examples; mod ids; +mod instance; mod tool_metrics; -pub(crate) use example::*; -use parking_lot::Mutex; +use assertions::display_error_row; +use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git}; pub(crate) use tool_metrics::*; use ::fs::RealFs; -use anyhow::{Result, anyhow}; +use anyhow::anyhow; use clap::Parser; use client::{Client, ProxySettings, UserStore}; use collections::{HashMap, HashSet}; @@ -25,18 +28,20 @@ use prompt_store::PromptBuilder; use release_channel::AppVersion; use reqwest_client::ReqwestClient; use settings::{Settings, SettingsStore}; +use std::cell::RefCell; use std::collections::VecDeque; use std::env; use std::path::{Path, PathBuf}; +use std::rc::Rc; use std::sync::Arc; use util::ResultExt as _; #[derive(Parser, Debug)] #[command(name = "eval", disable_version_flag = true)] struct Args { - /// Runs all examples that contain these substrings. If unspecified, all examples are run. + /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run. 
#[arg(value_name = "EXAMPLE_SUBSTRING")] - examples: Vec, + filter: Vec, /// Model to use (default: "claude-3-7-sonnet-latest") #[arg(long, default_value = "claude-3-7-sonnet-latest")] model: String, @@ -66,43 +71,30 @@ fn main() { .parent() .unwrap() .parent() + .unwrap() + .canonicalize() .unwrap(); - let eval_crate_dir = root_dir.join("crates/eval"); + let eval_crate_dir = root_dir.join("crates").join("eval"); let repos_dir = eval_crate_dir.join("repos"); let worktrees_dir = eval_crate_dir.join("worktrees"); - let examples_dir = eval_crate_dir.join("examples"); - let runs_dir = eval_crate_dir.join("runs"); - let run_dir = runs_dir.join(format!("{}", run_timestamp)); + let examples_dir = eval_crate_dir.join("src").join("examples"); + let run_dir = eval_crate_dir + .join("runs") + .join(format!("{}", run_timestamp)); std::fs::create_dir_all(&run_dir).unwrap(); std::fs::create_dir_all(&repos_dir).unwrap(); std::fs::create_dir_all(&worktrees_dir).unwrap(); std::fs::create_dir_all(&examples_dir).unwrap(); std::fs::create_dir_all(&paths::config_dir()).unwrap(); - let zed_commit_sha = commit_sha_for_path(root_dir); - let zed_branch_name = git_branch_for_path(root_dir); + let zed_commit_sha = commit_sha_for_path(&root_dir); + let zed_branch_name = git_branch_for_path(&root_dir); let args = Args::parse(); - let all_available_examples = list_all_examples(&examples_dir).unwrap(); - - let example_paths = all_available_examples - .iter() - .filter_map(|example_path| { - let name = example_path.file_name()?.to_string_lossy(); - if args.examples.is_empty() - || args - .examples - .iter() - .any(|name_substring| name.contains(name_substring)) - { - Some(example_path.clone()) - } else { - None - } - }) - .collect::>(); + let languages: HashSet = args.languages.into_iter().collect(); let http_client = Arc::new(ReqwestClient::new()); let app = Application::headless().with_http_client(http_client.clone()); + let all_threads = examples::all(&examples_dir); app.run(move |cx| { let 
app_state = init(cx); @@ -163,28 +155,40 @@ fn main() { let mut skipped = Vec::new(); - for example_path in &example_paths { - let example = Example::load_from_directory( - example_path, - &run_dir, - &worktrees_dir, - &repos_dir, - )?; - - if !example - .base - .language_extension - .as_ref() - .map_or(false, |lang| args.languages.contains(lang)) + for thread in all_threads { + let meta = thread.meta(); + if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub)) { - skipped.push(example.name); + skipped.push(meta.name); continue; } - examples.extend(example.repeat(args.repetitions)); + if meta.language_server.map_or(false, |language| { + !languages.contains(&language.file_extension) + }) { + skipped.push(meta.name); + continue; + } + + // TODO: This creates a worktree per repetition. Ideally these examples should + // either be run sequentially on the same worktree, or reuse worktrees when there + // are more examples to run than the concurrency limit. + for repetition_number in 0..args.repetitions { + let example_instance = ExampleInstance::new( + thread.clone(), + &repos_dir, + &run_dir, + &worktrees_dir, + repetition_number, + ); + + examples.push(example_instance); + } } - println!("Skipped examples: {}\n", skipped.join(", ")); + if !skipped.is_empty() { + println!("Skipped threads: {}", skipped.join(", ")); + } if examples.is_empty() { eprintln!("Filter matched no examples"); @@ -196,22 +200,23 @@ fn main() { let max_name_width = examples .iter() - .map(|e| e.repetition_name().len()) + .map(|e| e.worktree_name().len()) .max() .unwrap_or(0); - for (i, example) in examples.iter_mut().enumerate() { + + for (i, example_instance) in examples.iter_mut().enumerate() { let color = COLORS[i % COLORS.len()].to_string(); - example.set_log_prefix_style(&color, max_name_width); + example_instance.set_log_prefix_style(&color, max_name_width); println!( "{}Logging to: {}", - example.log_prefix, - example.run_directory_path().display() + 
example_instance.log_prefix, + example_instance.run_directory.display() ); - let repo_url = example.base.url.clone(); + let repo_url = example_instance.repo_url(); if repo_urls.insert(repo_url.clone()) { - let repo_path = example.repo_path.clone(); + let repo_path = example_instance.repo_path.clone(); if !repo_path.join(".git").is_dir() { println!( @@ -251,12 +256,12 @@ fn main() { future::join_all(clone_tasks).await; - for example in examples.iter_mut() { - example.fetch().await?; + for example_instance in examples.iter_mut() { + example_instance.fetch().await?; } - let examples = Arc::new(Mutex::new(VecDeque::from(examples))); - let results_by_example_name = Arc::new(Mutex::new(HashMap::default())); + let examples = Rc::new(RefCell::new(VecDeque::from(examples))); + let results_by_example_name = Rc::new(RefCell::new(HashMap::default())); future::join_all((0..args.concurrency).map(|_| { let app_state = app_state.clone(); @@ -268,7 +273,7 @@ fn main() { let results = results_by_example_name.clone(); cx.spawn(async move |cx| { loop { - let Some(mut example) = examples.lock().pop_front() else { + let Some(mut example) = examples.borrow_mut().pop_front() else { break; }; let result = async { @@ -291,7 +296,7 @@ fn main() { } .await; results - .lock() + .borrow_mut() .entry(example.name.clone()) .or_insert(Vec::new()) .push((example.clone(), result)); @@ -300,98 +305,156 @@ fn main() { })) .await; - println!("\n\n"); - print_header("EVAL RESULTS"); + print_h1("EVAL RESULTS"); let mut diff_scores = Vec::new(); let mut thread_scores = Vec::new(); + let mut programmatic_scores = Vec::new(); let mut error_count = 0; - for (example_name, results) in results_by_example_name.lock().iter_mut() { - print_header(&example_name); + for (example_name, results) in results_by_example_name.borrow_mut().iter_mut() { + print_h2(&example_name); results.sort_unstable_by_key(|(example, _)| example.repetition); let mut example_cumulative_tool_metrics = ToolMetrics::default(); - 
println!("┌───────┬──────┬────────┐"); - println!("│ Round │ Diff │ Thread │"); - println!("├───────┼──────┼────────┤"); - for (example, result) in results { - let run_dir_path = example.run_directory_path(); - let relative_run_dir_path = run_dir_path.strip_prefix(root_dir).unwrap(); + let mut table_rows = String::new(); + for (example, result) in results.iter() { match result { Err(err) => { - println!( - "|{:^7}│{:^6}│{:^8}│ {:?}{}", + display_error_row( + &mut table_rows, example.repetition, - "N/A", - "N/A", - err, - relative_run_dir_path.display() - ); + err.to_string(), + )?; error_count += 1; } - Ok((run_output, judge_result)) => { + Ok((run_output, judge_output)) => { cumulative_tool_metrics.merge(&run_output.tool_metrics); example_cumulative_tool_metrics.merge(&run_output.tool_metrics); - match judge_result { - Ok(judge_output) => { - diff_scores.push(judge_output.diff.score()); - thread_scores.push(judge_output.thread.score()); - println!( - "|{:^7}│{:^6}│{:^8}│ {}", + if !run_output.programmatic_assertions.total_count() > 0 { + for assertion in &run_output.programmatic_assertions.ran { + assertions::display_table_row( + &mut table_rows, example.repetition, - format!("{}%", judge_output.diff.score()), - format!("{}%", judge_output.thread.score()), - relative_run_dir_path.display() - ); + assertion, + )?; } - Err(err) => { - println!( - "|{:^7}│{:^6}│{:^8}│{:?}│ {}", + + programmatic_scores + .push(run_output.programmatic_assertions.passed_percentage()) + } + + if !judge_output.diff.is_empty() { + diff_scores.push(judge_output.diff.passed_percentage()); + + for assertion in &judge_output.diff.ran { + assertions::display_table_row( + &mut table_rows, example.repetition, - "N/A", - "N/A", - err, - relative_run_dir_path.display() - ); + assertion, + )?; + } + } + + if !judge_output.thread.is_empty() { + thread_scores.push(judge_output.thread.passed_percentage()); + + for assertion in &judge_output.thread.ran { + assertions::display_table_row( + &mut 
table_rows, + example.repetition, + assertion, + )?; } } } } } - println!("└───────┴──────┴────────┘"); - println!("{}", example_cumulative_tool_metrics); + if !table_rows.is_empty() { + assertions::print_table_header(); + print!("{}", table_rows); + + assertions::print_table_divider(); + + for (example, result) in results.iter() { + if let Ok((run_output, judge_output)) = result { + assertions::print_table_round_summary( + &example.repetition.to_string(), + [ + &run_output.programmatic_assertions, + &judge_output.diff, + &judge_output.thread, + ] + .into_iter(), + ) + } + } + + assertions::print_table_divider(); + + assertions::print_table_round_summary( + "avg", + results.iter().flat_map(|(_, result)| { + result.iter().flat_map(|(run_output, judge_output)| { + [ + &run_output.programmatic_assertions, + &judge_output.diff, + &judge_output.thread, + ] + .into_iter() + }) + }), + ); + + assertions::print_table_footer(); + } + + if !example_cumulative_tool_metrics.is_empty() { + println!("{}", &example_cumulative_tool_metrics); + } } - let diff_score_count = diff_scores.len(); - let average_diff_score = diff_scores - .into_iter() - .map(|score| score as f32) - .sum::() - / (diff_score_count as f32); + if results_by_example_name.borrow().len() > 1 { + print_h1("AGGREGATE"); - if error_count > 0 { - println!("\n{error_count} examples failed to run!"); + if error_count > 0 { + println!("\n{error_count} examples failed to run!"); + } + + let programmatic_score_count = programmatic_scores.len(); + if programmatic_score_count > 0 { + let average_programmatic_score = (programmatic_scores.into_iter().sum::() + / (programmatic_score_count as f32)) + .floor(); + println!("Average programmatic score: {average_programmatic_score}%"); + } + + let diff_score_count = diff_scores.len(); + if diff_score_count > 0 { + let average_diff_score = + (diff_scores.into_iter().sum::() / (diff_score_count as f32)).floor(); + println!("Average diff score: {average_diff_score}%"); + } + + let 
thread_score_count = thread_scores.len(); + + if thread_score_count > 0 { + let average_thread_score = (thread_scores.into_iter().sum::() + / (thread_score_count as f32)) + .floor(); + println!("Average thread score: {average_thread_score}%"); + } + + println!(""); + + print_h2("CUMULATIVE TOOL METRICS"); + println!("{}", cumulative_tool_metrics); } - println!("\nAverage code diff score: {average_diff_score}"); - - let thread_score_count = thread_scores.len(); - let average_thread_score = thread_scores - .into_iter() - .map(|score| score as f32) - .sum::() - / (thread_score_count as f32); - - println!("\nAverage thread score: {average_thread_score}"); - - print_header("CUMULATIVE TOOL METRICS"); - println!("{}", cumulative_tool_metrics); - app_state.client.telemetry().flush_events().await; cx.update(|cx| cx.quit()) @@ -400,20 +463,6 @@ fn main() { }); } -fn list_all_examples(examples_dir: &Path) -> Result> { - let path = std::fs::canonicalize(examples_dir).unwrap(); - let entries = std::fs::read_dir(path).unwrap(); - let mut result_paths = Vec::new(); - for entry in entries { - let entry = entry?; - let path = entry.path(); - if path.is_dir() { - result_paths.push(path); - } - } - Ok(result_paths) -} - /// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields. 
pub struct AgentAppState { pub languages: Arc, @@ -570,7 +619,7 @@ pub fn git_branch_for_path(repo_path: &Path) -> String { } async fn judge_example( - example: Example, + example: ExampleInstance, model: Arc, zed_commit_sha: &str, zed_branch_name: &str, @@ -578,19 +627,9 @@ async fn judge_example( run_output: &RunOutput, enable_telemetry: bool, cx: &AsyncApp, -) -> Result { +) -> JudgeOutput { let judge_output = example.judge(model.clone(), &run_output, cx).await; - let diff_evaluation; - let thread_evaluation; - if let Ok(output) = judge_output.as_ref() { - diff_evaluation = Some(output.diff.clone()); - thread_evaluation = Some(output.thread.clone()); - } else { - diff_evaluation = None; - thread_evaluation = None; - } - if enable_telemetry { telemetry::event!( "Agent Example Evaluated", @@ -599,15 +638,15 @@ async fn judge_example( run_id = run_id, example_name = example.name.clone(), example_repetition = example.repetition, - diff_evaluation = diff_evaluation, - thread_evaluation = thread_evaluation, + diff_evaluation = judge_output.diff.clone(), + thread_evaluation = judge_output.thread.clone(), tool_metrics = run_output.tool_metrics, response_count = run_output.response_count, token_usage = run_output.token_usage, model = model.telemetry_id(), model_provider = model.provider_id().to_string(), - repository_url = example.base.url.clone(), - repository_revision = example.base.revision.clone(), + repository_url = example.repo_url(), + repository_revision = example.revision(), diagnostic_summary_before = run_output.diagnostic_summary_before, diagnostic_summary_after = run_output.diagnostic_summary_after, diagnostics_before = run_output.diagnostics_before, @@ -618,8 +657,16 @@ async fn judge_example( judge_output } -fn print_header(header: &str) { - println!("\n========================================"); - println!("{:^40}", header); - println!("========================================\n"); +const HEADER_WIDTH: usize = 65; + +fn print_h1(header: &str) { + 
println!("\n\n{:=^HEADER_WIDTH$}", ""); + println!("{:^HEADER_WIDTH$}", header); + println!("{:=^HEADER_WIDTH$}\n", ""); +} + +fn print_h2(header: &str) { + println!("\n{:-^HEADER_WIDTH$}", ""); + println!("{:^HEADER_WIDTH$}", header); + println!("{:-^HEADER_WIDTH$}\n", ""); } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 5cd38929a2..445057ebe0 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -1,53 +1,57 @@ -use crate::{AgentAppState, ToolMetrics}; -use agent::{ThreadEvent, ThreadStore}; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::ToolWorkingSet; -use client::proto::LspWorkProgress; -use futures::channel::mpsc; -use futures::{FutureExt, StreamExt as _, select_biased}; -use gpui::{App, AppContext as _, AsyncApp, Entity, Task}; -use handlebars::Handlebars; -use language::{Buffer, DiagnosticSeverity, OffsetRangeExt}; -use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, - MessageContent, Role, StopReason, TokenUsage, -}; -use project::{DiagnosticSummary, Project, ProjectPath}; -use serde::{Deserialize, Serialize}; -use std::cell::RefCell; -use std::fmt::Write as _; -use std::fs::File; -use std::io::Write as _; -use std::rc::Rc; -use std::sync::{Arc, Mutex}; -use std::time::Duration; use std::{ - fs, - path::{Path, PathBuf}, + error::Error, + fmt::{self, Debug}, + sync::{Arc, Mutex}, + time::Duration, }; -use unindent::Unindent as _; -use util::ResultExt as _; -use util::command::new_smol_command; -use util::markdown::MarkdownString; -use util::serde::default_true; -const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); +use crate::{ + ToolMetrics, + assertions::{AssertionsReport, RanAssertion, RanAssertionResult}, +}; +use agent::ThreadEvent; +use anyhow::{Result, anyhow}; +use async_trait::async_trait; +use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased}; +use gpui::{AppContext, AsyncApp, 
Entity}; +use language_model::{LanguageModel, Role, StopReason}; -const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git"; +pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2); -#[derive(Clone, Debug, Deserialize)] -pub struct ExampleBase { +#[async_trait(?Send)] +pub trait Example { + fn meta(&self) -> ExampleMetadata; + async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>; + fn diff_assertions(&self) -> Vec { + Vec::new() + } + fn thread_assertions(&self) -> Vec { + Vec::new() + } +} + +#[derive(Clone, Debug)] +pub struct JudgeAssertion { + pub id: String, + pub description: String, +} + +#[derive(Clone, Debug)] +pub struct ExampleMetadata { + pub name: String, pub url: String, pub revision: String, - pub language_extension: Option, - pub insert_id: Option, - #[serde(default = "default_true")] - pub require_lsp: bool, - #[serde(default)] + pub language_server: Option, + pub max_assertions: Option, +} + +#[derive(Clone, Debug)] +pub struct LanguageServer { + pub file_extension: String, pub allow_preexisting_diagnostics: bool, } -impl ExampleBase { +impl ExampleMetadata { pub fn repo_name(&self) -> String { self.url .split('/') @@ -58,1042 +62,310 @@ impl ExampleBase { } } -#[derive(Clone, Debug)] -pub struct Example { - pub name: String, - /// Content of `base.toml` - pub base: ExampleBase, - /// Content of `prompt.md` - pub prompt: String, - /// Content of `diff_criteria.md` - pub diff_criteria: String, - /// Content of `thread_criteria.md`, if that file exists (it's optional) - pub thread_criteria: String, - /// Prefix used for logging that identifies this example - pub log_prefix: String, - /// The repetition number for this example (0-based) - /// When running multiple repetitions of the same example, each instance is assigned a unique repetition number. - /// This affects the worktree path and log prefix to avoid clobbering results between runs. 
- pub repetition: usize, - pub repo_path: PathBuf, - /// Path to the directory containing the requests and responses for the agentic loop - run_dir_path: PathBuf, - worktrees_dir: PathBuf, +pub struct FailedAssertion(pub String); + +impl fmt::Debug for FailedAssertion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Assertion failure: {}", self.0) + } } -#[derive(Debug, Serialize, Clone)] -pub struct RunOutput { - pub repository_diff: String, - pub ran_diagnostics_check: bool, - pub diagnostic_summary_before: DiagnosticSummary, - pub diagnostic_summary_after: DiagnosticSummary, - pub diagnostics_before: Option, - pub diagnostics_after: Option, - pub response_count: usize, - pub token_usage: TokenUsage, - pub tool_metrics: ToolMetrics, - pub last_request: LanguageModelRequest, +impl fmt::Display for FailedAssertion { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JudgeDiffInput { - pub repository_diff: String, - pub criteria: String, +impl Error for FailedAssertion {} + +pub struct ExampleContext { + meta: ExampleMetadata, + log_prefix: String, + agent_thread: Entity, + app: AsyncApp, + model: Arc, + pub assertions: AssertionsReport, + pub tool_metrics: Arc>, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JudgeThreadInput { - pub messages: String, - pub criteria: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JudgeResponse { - pub analysis: String, - pub passing_criteria: u32, - pub total_criteria: u32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JudgeOutput { - pub thread: JudgeResponse, - pub diff: JudgeResponse, -} - -impl Example { - /// Load an example from a directory containing base.toml, prompt.md, and criteria.md - pub fn load_from_directory( - dir_path: &Path, - run_dir: &Path, - worktrees_dir: &Path, - repos_dir: &Path, - ) -> Result { - let name = 
dir_path.file_name().unwrap().to_string_lossy().to_string(); - let base_path = dir_path.join("base.toml"); - let prompt_path = dir_path.join("prompt.md"); - let diff_criteria_path = dir_path.join("diff_criteria.md"); - let thread_criteria_path = dir_path.join("thread_criteria.md"); - let base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?; - - let repo_path = repo_path_for_url(repos_dir, &base.url); - - Ok(Example { - name: name.clone(), - base, - prompt: fs::read_to_string(prompt_path.clone())?, - thread_criteria: fs::read_to_string(thread_criteria_path.clone())?, - diff_criteria: fs::read_to_string(diff_criteria_path.clone())?, - run_dir_path: run_dir.to_path_buf(), - worktrees_dir: worktrees_dir.to_path_buf(), - repo_path, - log_prefix: name, - repetition: 0, - }) - } - - pub fn repetition_name(&self) -> String { - format!("{}-{}", self.name, self.repetition) - } - - pub fn worktree_path(&self) -> PathBuf { - self.worktrees_dir - .canonicalize() - .unwrap() - .join(self.repetition_name()) - .join(&self.base.repo_name()) - } - - pub fn run_directory_path(&self) -> PathBuf { - self.run_dir_path.join(self.repetition_name()) - } - - /// Create an iterator that returns copies of this example with different repetition numbers - /// Each copy will have a different repetition number and worktree path based on the repetition - pub fn repeat(self, repetitions: usize) -> impl Iterator { - (0..repetitions).map(move |repetition| { - let mut example = self.clone(); - example.repetition = repetition; - example - }) - } - - pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) { - self.log_prefix = format!( - "{}{: Result<()> { - let revision_exists = run_git( - &self.repo_path, - &["rev-parse", &format!("{}^{{commit}}", self.base.revision)], - ) - .await - .is_ok(); - - if !revision_exists { - println!( - "{}Fetching revision {}", - self.log_prefix, &self.base.revision - ); - run_git( - &self.repo_path, - &["fetch", "--depth", "1", "origin", 
&self.base.revision], - ) - .await?; - } - Ok(()) - } - - /// Set up the example by checking out the specified Git revision - pub async fn setup(&mut self) -> Result<()> { - let worktree_path = self.worktree_path(); - if worktree_path.is_dir() { - println!("{}Resetting existing worktree", self.log_prefix); - - // TODO: consider including "-x" to remove ignored files. The downside of this is that - // it will also remove build artifacts, and so prevent incremental reuse there. - run_git(&worktree_path, &["clean", "--force", "-d"]).await?; - run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; - run_git(&worktree_path, &["checkout", &self.base.revision]).await?; - } else { - println!("{}Creating worktree", self.log_prefix); - - let worktree_path_string = worktree_path.to_string_lossy().to_string(); - - run_git( - &self.repo_path, - &[ - "worktree", - "add", - "-f", - &worktree_path_string, - &self.base.revision, - ], - ) - .await?; - } - - if self.base.url == ZED_REPO_URL { - std::fs::write(worktree_path.join(".rules"), std::fs::read(".rules")?)?; - } - - std::fs::create_dir_all(self.run_directory_path())?; - - Ok(()) - } - - pub fn run( - &self, +impl ExampleContext { + pub fn new( + meta: ExampleMetadata, + log_prefix: String, + agent_thread: Entity, model: Arc, - app_state: Arc, - cx: &mut App, - ) -> Task> { - let project = Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - cx, - ); + app: AsyncApp, + ) -> Self { + let assertions = AssertionsReport::new(meta.max_assertions); - let worktree = project.update(cx, |project, cx| { - project.create_worktree(self.worktree_path(), true, cx) + Self { + meta, + log_prefix, + agent_thread, + assertions, + model, + app, + tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())), + } + } + + pub fn push_user_message(&mut self, text: impl ToString) { + self.app + 
.update_entity(&self.agent_thread, |thread, cx| { + thread.insert_user_message(text.to_string(), vec![], None, cx); + }) + .unwrap(); + } + + pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> { + let message = message.to_string(); + self.log_assertion( + if expected { + Ok(()) + } else { + Err(anyhow::Error::from(FailedAssertion(message.clone()))) + }, + message, + ) + } + + pub fn assert_some(&mut self, option: Option, message: impl ToString) -> Result { + let message = message.to_string(); + self.log_assertion( + match option { + Some(value) => Ok(value), + None => Err(anyhow::Error::from(FailedAssertion(message.clone()))), + }, + message, + ) + } + + #[allow(dead_code)] + pub fn assert_eq( + &mut self, + left: T, + right: T, + message: impl ToString, + ) -> Result<()> { + let message = message.to_string(); + self.log_assertion( + if left == right { + Ok(()) + } else { + println!("{}{:#?} != {:#?}", self.log_prefix, left, right); + Err(anyhow::Error::from(FailedAssertion(message.clone()))) + }, + message, + ) + } + + fn log_assertion(&mut self, result: Result, message: String) -> Result { + if let Some(max) = self.meta.max_assertions { + if self.assertions.run_count() > max { + return Err(anyhow!( + "More assertions were run than the stated max_assertions of {}", + max + )); + } + } + + self.assertions.ran.push(RanAssertion { + id: message.clone(), + result: Ok(RanAssertionResult { + analysis: None, + passed: result.is_ok(), + }), }); - let tools = cx.new(|_| ToolWorkingSet::default()); - let thread_store = - ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx); - let this = self.clone(); - - cx.spawn(async move |cx| { - let worktree = worktree.await?; - - // Wait for worktree scan to finish before choosing a file to open. - worktree - .update(cx, |worktree, _cx| { - worktree.as_local().unwrap().scan_complete() - })? 
- .await; - - let lsp = if this.base.require_lsp { - let language_extension = this.base.language_extension.as_deref().context( - "language_extension field is required in base.toml when `require_lsp == true`", - )?; - - // Open a file that matches the language to cause LSP to start. - let language_file = worktree.read_with(cx, |worktree, _cx| { - worktree - .files(false, 0) - .find_map(|e| { - if e.path.clone().extension().and_then(|ext| ext.to_str()) - == Some(language_extension) - { - Some(ProjectPath { - worktree_id: worktree.id(), - path: e.path.clone(), - }) - } else { - None - } - }) - .context("Failed to find a file for example language") - })??; - - let open_language_file_buffer_task = project.update(cx, |project, cx| { - project.open_buffer(language_file.clone(), cx) - })?; - - let language_file_buffer = open_language_file_buffer_task.await?; - - let lsp_open_handle = project.update(cx, |project, cx| { - project.register_buffer_with_language_servers(&language_file_buffer, cx) - })?; - - wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?; - - Some((lsp_open_handle, language_file_buffer)) - } else { - None - }; - - let diagnostic_summary_before = project.read_with(cx, |project, cx| { - project.diagnostic_summary(false, cx) - })?; - let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?; - if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics { - return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`")); - } - - if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() { - return Err(anyhow!("Setup only mode")); - } - - let example_output_dir = this.run_directory_path(); - let last_diff_file_path = example_output_dir.join("last.diff"); - - // Write an empty "last.diff" so that it can be opened in Zed for convenient view of the - // history using undo/redo. 
- std::fs::write(&last_diff_file_path, "")?; - - let thread_store = thread_store.await?; - let thread = - thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; - let last_request = Rc::new(RefCell::new(None)); - - thread.update(cx, |thread, _cx| { - let mut request_count = 0; - let last_request = Rc::clone(&last_request); - let previous_diff = Rc::new(RefCell::new("".to_string())); - let example_output_dir = example_output_dir.clone(); - let last_diff_file_path = last_diff_file_path.clone(); - let this = this.clone(); - thread.set_request_callback(move |request, response_events| { - *last_request.borrow_mut() = Some(request.clone()); - - request_count += 1; - let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md")); - let diff_file_path = example_output_dir.join(format!("{request_count}.diff")); - let last_messages_file_path = example_output_dir.join("last.messages.md"); - let request_markdown = RequestMarkdown::new(request); - let response_events_markdown = response_events_to_markdown(response_events); - - let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown); - fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file"); - fs::write(&last_messages_file_path, messages).expect("failed to write last messages file"); - - let diff_result = smol::block_on(this.repository_diff()); - match diff_result { - Ok(diff) => { - if diff != previous_diff.borrow().clone() { - fs::write(&diff_file_path, &diff).expect("failed to write diff file"); - fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file"); - *previous_diff.borrow_mut() = diff; - } - } - Err(err) => { - let error_message = format!("{err:?}"); - fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file"); - fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file"); - } - } - - if request_count == 1 { - let tools_file_path 
= example_output_dir.join("tools.md"); - fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file"); - } - }); - })?; - - let tool_metrics = Arc::new(Mutex::new(ToolMetrics::default())); - - let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded(); - - let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| { - thread_event_tx.unbounded_send(event.clone()).log_err(); - }); - - let event_handler_task = cx.spawn({ - let log_prefix = this.log_prefix.clone(); - let tool_metrics = tool_metrics.clone(); - let thread = thread.downgrade(); - async move |cx| { - loop { - let event = select_biased! { - event = thread_event_rx.next() => event, - _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => { - return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT)); - } - }; - let Some(event) = event else { - return Err(anyhow!("ThreadEvent channel ended early")); - }; - - match event { - ThreadEvent::Stopped(reason) => match reason { - Ok(StopReason::EndTurn) => { - return Ok(()); - } - Ok(StopReason::MaxTokens) => { - return Err(anyhow!("Exceeded maximum tokens")); - } - Ok(StopReason::ToolUse) => { - if std::env::var("ZED_EVAL_DEBUG").is_ok() { - println!("{}StopReason: Tool use", log_prefix); - } - } - Err(error) => { - return Err(anyhow!(error.clone())); - } - }, - ThreadEvent::ShowError(thread_error) => { - break Err(anyhow!(thread_error.clone())); - } - ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => { - } - ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use, - .. 
- } => { - thread.update(cx, |thread, _cx| { - if let Some(tool_use) = pending_tool_use { - let mut tool_metrics = tool_metrics.lock().unwrap(); - if let Some(tool_result) = thread.tool_result(&tool_use_id) { - let message = if tool_result.is_error { - format!("TOOL FAILED: {}", tool_use.name) - } else { - format!("TOOL FINISHED: {}", tool_use.name) - }; - println!("{log_prefix}{message}"); - tool_metrics.insert(tool_result.tool_name.clone(), !tool_result.is_error); - } else { - let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name); - println!("{log_prefix}{message}"); - tool_metrics.insert(tool_use.name.clone(), true); - } - } - })?; - } - ThreadEvent::ToolConfirmationNeeded => { - panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix); - }, - ThreadEvent::StreamedToolUse { .. } | - ThreadEvent::StreamedCompletion | - ThreadEvent::MessageAdded(_) | - ThreadEvent::MessageEdited(_) | - ThreadEvent::MessageDeleted(_) | - ThreadEvent::SummaryChanged | - ThreadEvent::SummaryGenerated | - ThreadEvent::CheckpointChanged | - ThreadEvent::ReceivedTextChunk | - ThreadEvent::UsageUpdated(_) => { - if std::env::var("ZED_EVAL_DEBUG").is_ok() { - println!("{}Event: {:#?}", log_prefix, event); - } - } - } - } - } - }); - - thread.update(cx, |thread, cx| { - let context = vec![]; - thread.insert_user_message(this.prompt.clone(), context, None, cx); - thread.send_to_model(model, cx); - })?; - - event_handler_task.await?; - - println!("{}Stopped", this.log_prefix); - - if let Some((_, language_file_buffer)) = lsp.as_ref() { - wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?; - } - - println!("{}Getting repository diff", this.log_prefix); - let repository_diff = this.repository_diff().await?; - std::fs::write(last_diff_file_path, &repository_diff)?; - - println!("{}Getting diagnostics", this.log_prefix); - let diagnostic_summary_after = project.read_with(cx, |project, cx| { - 
project.diagnostic_summary(false, cx) - })?; - let diagnostics_after = cx - .update(move |cx| { - cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await) - })? - .await?; - println!("{}Got diagnostics", this.log_prefix); - - let Some(last_request) = last_request.borrow_mut().take() else { - return Err(anyhow!("No requests ran.")); - }; - - drop(subscription); - drop(lsp); - - if let Some(diagnostics_before) = &diagnostics_before { - fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?; - } - - if let Some(diagnostics_after) = &diagnostics_after { - fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?; - } - - - thread.update(cx, |thread, _cx| { - let response_count = thread - .messages() - .filter(|message| message.role == language_model::Role::Assistant) - .count(); - RunOutput { - repository_diff, - ran_diagnostics_check: this.base.require_lsp, - diagnostic_summary_before, - diagnostic_summary_after, - diagnostics_before, - diagnostics_after, - response_count, - token_usage: thread.cumulative_token_usage(), - tool_metrics: tool_metrics.lock().unwrap().clone(), - last_request, - } - }) - }) - } - - async fn judge_diff( - &self, - model: Arc, - run_output: &RunOutput, - cx: &AsyncApp, - ) -> Result<(String, JudgeResponse)> { - let judge_diff_prompt = include_str!("judge_diff_prompt.hbs"); - let judge_diff_prompt_name = "judge_diff_prompt"; - let mut hbs = Handlebars::new(); - hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?; - - let diff_prompt = hbs.render( - judge_diff_prompt_name, - &JudgeDiffInput { - repository_diff: run_output.repository_diff.clone(), - criteria: self.diff_criteria.clone(), - }, - )?; - - let request = LanguageModelRequest { - thread_id: None, - prompt_id: None, - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text(diff_prompt)], - cache: false, - }], - temperature: None, - tools: Vec::new(), - stop: 
Vec::new(), - }; - - let diff_response = send_language_model_request(model, request, cx).await?; - let diff_output = JudgeResponse::parse(&diff_response)?; - - println!( - "{}Judge - Diff score: {}%", - self.log_prefix, - diff_output.score() - ); - - Ok((diff_response, diff_output)) - } - - async fn judge_thread( - &self, - model: Arc, - run_output: &RunOutput, - cx: &AsyncApp, - ) -> Result<(String, JudgeResponse)> { - let judge_thread_prompt = include_str!("judge_thread_prompt.hbs"); - let judge_thread_prompt_name = "judge_thread_prompt"; - let mut hbs = Handlebars::new(); - hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?; - - let request_markdown = RequestMarkdown::new(&run_output.last_request); - let thread_prompt = hbs.render( - judge_thread_prompt_name, - &JudgeThreadInput { - messages: request_markdown.messages, - criteria: self.thread_criteria.clone(), - }, - )?; - - let request = LanguageModelRequest { - thread_id: None, - prompt_id: None, - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text(thread_prompt)], - cache: false, - }], - temperature: None, - tools: Vec::new(), - stop: Vec::new(), - }; - - let thread_response = send_language_model_request(model, request, cx).await?; - let thread_output = JudgeResponse::parse(&thread_response)?; - - println!( - "{}Judge - Thread score: {}%", - self.log_prefix, - thread_output.score() - ); - - Ok((thread_response, thread_output)) - } - - pub async fn judge( - &self, - model: Arc, - run_output: &RunOutput, - cx: &AsyncApp, - ) -> Result { - let mut output_file = File::create(self.run_directory_path().join("judge.md")) - .expect("failed to create judge.md"); - - println!("{}Running judge", self.log_prefix); - - let diff_task = self.judge_diff(model.clone(), &run_output, cx); - let thread_task = self.judge_thread(model.clone(), &run_output, cx); - - let (diff_result, thread_result) = futures::join!(diff_task, thread_task); - - let 
(diff_response, diff_output) = diff_result?; - let (thread_response, thread_output) = thread_result?; - - writeln!( - &mut output_file, - "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}", - ) - .log_err(); - - Ok(JudgeOutput { - thread: thread_output, - diff: diff_output, - }) - } - - async fn repository_diff(&self) -> Result { - let worktree_path = self.worktree_path(); - - run_git(&worktree_path, &["add", "."]).await?; - let mut diff_args = vec!["diff", "--staged"]; - if self.base.url == ZED_REPO_URL { - diff_args.push(":(exclude).rules"); + if result.is_ok() { + println!("{}✅ {}", self.log_prefix, message); + } else { + println!("{}❌ {}", self.log_prefix, message); } - run_git(&worktree_path, &diff_args).await - } -} -fn wait_for_lang_server( - project: &Entity, - buffer: &Entity, - log_prefix: String, - cx: &mut AsyncApp, -) -> Task> { - println!("{}⏵ Waiting for language server", log_prefix); - - let (mut tx, mut rx) = mpsc::channel(1); - - let lsp_store = project - .update(cx, |project, _| project.lsp_store()) - .unwrap(); - - let has_lang_server = buffer - .update(cx, |buffer, cx| { - lsp_store.update(cx, |lsp_store, cx| { - lsp_store - .language_servers_for_local_buffer(&buffer, cx) - .next() - .is_some() - }) - }) - .unwrap_or(false); - - if has_lang_server { - project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) - .unwrap() - .detach(); - } - - let subscriptions = - [ - cx.subscribe(&lsp_store, { - let log_prefix = log_prefix.clone(); - move |_, event, _| match event { - project::LspStoreEvent::LanguageServerUpdate { - message: - client::proto::update_language_server::Variant::WorkProgress( - LspWorkProgress { - message: Some(message), - .. - }, - ), - .. 
- } => println!("{}⟲ {message}", log_prefix), - _ => {} - } - }), - cx.subscribe(&project, { - let buffer = buffer.clone(); - move |project, event, cx| match event { - project::Event::LanguageServerAdded(_, _, _) => { - let buffer = buffer.clone(); - project - .update(cx, |project, cx| project.save_buffer(buffer, cx)) - .detach(); - } - project::Event::DiskBasedDiagnosticsFinished { .. } => { - tx.try_send(()).ok(); - } - _ => {} - } - }), - ]; - - cx.spawn(async move |cx| { - let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); - let result = futures::select! { - _ = rx.next() => { - println!("{}⚑ Language server idle", log_prefix); - anyhow::Ok(()) - }, - _ = timeout.fuse() => { - Err(anyhow!("LSP wait timed out after 5 minutes")) - } - }; - drop(subscriptions); result - }) -} - -async fn query_lsp_diagnostics( - project: Entity, - cx: &mut AsyncApp, -) -> Result> { - let paths_with_diagnostics = project.update(cx, |project, cx| { - project - .diagnostic_summaries(true, cx) - .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0) - .map(|(project_path, _, _)| project_path) - .collect::>() - })?; - - if paths_with_diagnostics.is_empty() { - return Ok(None); } - let mut output = String::new(); - for project_path in paths_with_diagnostics { - let buffer = project - .update(cx, |project, cx| project.open_buffer(project_path, cx))? 
- .await?; - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - - for (_, group) in snapshot.diagnostic_groups(None) { - let entry = &group.entries[group.primary_ix]; - let range = entry.range.to_point(&snapshot); - let severity = match entry.diagnostic.severity { - DiagnosticSeverity::ERROR => "error", - DiagnosticSeverity::WARNING => "warning", - _ => continue, - }; - - writeln!( - output, - "{} at line {}: {}", - severity, - range.start.row + 1, - entry.diagnostic.message - )?; - } - } - anyhow::Ok(Some(output)) -} - -impl JudgeResponse { - fn parse(response: &str) -> Result { - let analysis = get_tag("analysis", response)?.to_string(); - let passing_criteria = get_tag("passing_criteria", response)? - .parse() - .context("error parsing score")?; - let total_criteria = get_tag("total_criteria", response)? - .parse() - .context("error parsing score")?; - Ok(Self { - analysis, - total_criteria, - passing_criteria, - }) + pub async fn run_to_end(&mut self) -> Result { + self.run_turns(u32::MAX).await } - pub fn score(&self) -> u32 { - (100.0 * self.passing_criteria as f32 / self.total_criteria as f32).round() as u32 + pub async fn run_turn(&mut self) -> Result { + self.run_turns(1).await } -} -fn get_tag(name: &'static str, response: &str) -> Result { - let start_tag = format!("<{}>", name); - let end_tag = format!("", name); + pub async fn run_turns(&mut self, iterations: u32) -> Result { + let (mut tx, mut rx) = mpsc::channel(1); - let start_ix = response - .find(&start_tag) - .context(format!("{} start tag not found", name))?; - let content_start_ix = start_ix + start_tag.len(); - - let end_ix = content_start_ix - + response[content_start_ix..] 
- .find(&end_tag) - .context(format!("{} end tag not found", name))?; - - let content = response[content_start_ix..end_ix].trim().unindent(); - - anyhow::Ok(content) -} - -pub fn repo_path_for_url(repos_dir: &Path, repo_url: &str) -> PathBuf { - let repo_name = repo_url - .trim_start_matches("https://") - .replace(|c: char| !c.is_alphanumeric(), "-"); - Path::new(repos_dir) - .canonicalize() - .context(format!("No such directory {}", repos_dir.display())) - .unwrap() - .join(repo_name) -} - -pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result { - let output = new_smol_command("git") - .current_dir(repo_path) - .args(args) - .output() - .await?; - - if output.status.success() { - Ok(String::from_utf8(output.stdout)?.trim().to_string()) - } else { - Err(anyhow!( - "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}", - args.join(" "), - repo_path.display(), - output.status, - String::from_utf8_lossy(&output.stderr), - String::from_utf8_lossy(&output.stdout), - )) - } -} - -pub async fn send_language_model_request( - model: Arc, - request: LanguageModelRequest, - cx: &AsyncApp, -) -> anyhow::Result { - match model.stream_completion_text(request, &cx).await { - Ok(mut stream) => { - let mut full_response = String::new(); - while let Some(chunk_result) = stream.stream.next().await { - match chunk_result { - Ok(chunk_str) => { - full_response.push_str(&chunk_str); + let tool_metrics = self.tool_metrics.clone(); + let log_prefix = self.log_prefix.clone(); + let _subscription = self.app.subscribe( + &self.agent_thread, + move |thread, event: &ThreadEvent, cx| match event { + ThreadEvent::ShowError(thread_error) => { + tx.try_send(Err(anyhow!(thread_error.clone()))).ok(); + } + ThreadEvent::Stopped(reason) => match reason { + Ok(StopReason::EndTurn) => { + tx.close_channel(); + } + Ok(StopReason::ToolUse) => { + if thread.read(cx).remaining_turns() == 0 { + tx.close_channel(); + } + } + Ok(StopReason::MaxTokens) => { + 
tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok(); } Err(err) => { - return Err(anyhow!( - "Error receiving response from language model: {err}" - )); + tx.try_send(Err(anyhow!(err.clone()))).ok(); } - } - } - Ok(full_response) - } - Err(err) => Err(anyhow!( - "Failed to get response from language model. Error was: {err}" - )), - } -} - -struct RequestMarkdown { - tools: String, - messages: String, -} - -impl RequestMarkdown { - fn new(request: &LanguageModelRequest) -> Self { - let mut tools = String::new(); - let mut messages = String::new(); - let mut assistant_message_number: u32 = 1; - - // Print the tools - if !request.tools.is_empty() { - for tool in &request.tools { - write!(&mut tools, "# {}\n\n", tool.name).unwrap(); - write!(&mut tools, "{}\n\n", tool.description).unwrap(); - write!( - &mut tools, - "{}\n", - MarkdownString::code_block("json", &format!("{:#}", tool.input_schema)) - ) - .unwrap(); - } - } - - // Print the messages - for message in &request.messages { - match message.role { - Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"), - Role::User => messages.push_str("# 👤 USER\n\n"), - Role::Assistant => { - messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n")); - assistant_message_number += 1; - } - }; - - for content in &message.content { - match content { - MessageContent::Text(text) => { - messages.push_str(text); - messages.push_str("\n\n"); - } - MessageContent::Image(_) => { - messages.push_str("[IMAGE DATA]\n\n"); - } - MessageContent::Thinking { text, signature } => { - messages.push_str("**Thinking**:\n\n"); - if let Some(sig) = signature { - messages.push_str(&format!("Signature: {}\n\n", sig)); + }, + ThreadEvent::StreamedAssistantText(_, _) + | ThreadEvent::StreamedAssistantThinking(_, _) + | ThreadEvent::UsePendingTools { .. } => {} + ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use, + .. 
+ } => { + thread.update(cx, |thread, _cx| { + if let Some(tool_use) = pending_tool_use { + let mut tool_metrics = tool_metrics.lock().unwrap(); + if let Some(tool_result) = thread.tool_result(&tool_use_id) { + let message = if tool_result.is_error { + format!("TOOL FAILED: {}", tool_use.name) + } else { + format!("TOOL FINISHED: {}", tool_use.name) + }; + println!("{log_prefix}{message}"); + tool_metrics + .insert(tool_result.tool_name.clone(), !tool_result.is_error); + } else { + let message = + format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name); + println!("{log_prefix}{message}"); + tool_metrics.insert(tool_use.name.clone(), true); + } } - messages.push_str(text); - messages.push_str("\n"); - } - MessageContent::RedactedThinking(items) => { - messages.push_str(&format!( - "**Redacted Thinking**: {} item(s)\n\n", - items.len() - )); - } - MessageContent::ToolUse(tool_use) => { - messages.push_str(&format!( - "**Tool Use**: {} (ID: {})\n", - tool_use.name, tool_use.id - )); - messages.push_str(&format!( - "{}\n", - MarkdownString::code_block("json", &format!("{:#}", tool_use.input)) - )); - } - MessageContent::ToolResult(tool_result) => { - messages.push_str(&format!( - "**Tool Result**: {} (ID: {})\n\n", - tool_result.tool_name, tool_result.tool_use_id - )); - if tool_result.is_error { - messages.push_str("**ERROR:**\n"); - } - messages.push_str(&format!("{}\n\n", tool_result.content)); + }); + } + ThreadEvent::ToolConfirmationNeeded => { + panic!( + "{}Bug: Tool confirmation should not be required in eval", + log_prefix + ); + } + ThreadEvent::StreamedCompletion + | ThreadEvent::MessageAdded(_) + | ThreadEvent::MessageEdited(_) + | ThreadEvent::MessageDeleted(_) + | ThreadEvent::SummaryChanged + | ThreadEvent::SummaryGenerated + | ThreadEvent::ReceivedTextChunk + | ThreadEvent::StreamedToolUse { .. 
} + | ThreadEvent::CheckpointChanged + | ThreadEvent::UsageUpdated(_) => { + tx.try_send(Ok(())).ok(); + if std::env::var("ZED_EVAL_DEBUG").is_ok() { + println!("{}Event: {:#?}", log_prefix, event); } } - } - } - - Self { tools, messages } - } -} - -fn response_events_to_markdown( - response_events: &[std::result::Result], -) -> String { - let mut response = String::new(); - // Print the response events if any - response.push_str("# Response\n\n"); - let mut text_buffer = String::new(); - let mut thinking_buffer = String::new(); - - let flush_buffers = - |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| { - if !text_buffer.is_empty() { - output.push_str(&format!("**Text**:\n{}\n\n", text_buffer)); - text_buffer.clear(); - } - if !thinking_buffer.is_empty() { - output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer)); - thinking_buffer.clear(); - } - }; - - for event in response_events { - match event { - Ok(LanguageModelCompletionEvent::Text(text)) => { - text_buffer.push_str(text); - } - Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => { - thinking_buffer.push_str(text); - } - Ok(LanguageModelCompletionEvent::Stop(reason)) => { - flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); - response.push_str(&format!("**Stop**: {:?}\n\n", reason)); - } - Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { - flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); - response.push_str(&format!( - "**Tool Use**: {} (ID: {})\n", - tool_use.name, tool_use.id - )); - response.push_str(&format!( - "{}\n", - MarkdownString::code_block("json", &format!("{:#}", tool_use.input)) - )); - } - Ok( - LanguageModelCompletionEvent::UsageUpdate(_) - | LanguageModelCompletionEvent::StartMessage { .. 
}, - ) => {} - Err(error) => { - flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); - response.push_str(&format!("**Error**: {}\n\n", error)); - } - } - } - - flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); - - response -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_parse_judge_output() { - let response = r#" - The model did a good job but there were still compilations errors. - 3 - 5 - "# - .unindent(); - - let output = JudgeResponse::parse(&response).unwrap(); - assert_eq!( - output.analysis, - "The model did a good job but there were still compilations errors." + }, ); - assert_eq!(output.passing_criteria, 3); - assert_eq!(output.total_criteria, 5); - let response = r#" - Text around ignored + let model = self.model.clone(); - - Failed to compile: - - Error 1 - - Error 2 - + let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| { + thread.set_remaining_turns(iterations); + thread.send_to_model(model, cx); + thread.messages().len() + })?; - 1 + loop { + select_biased! 
{ + result = rx.next() => { + if let Some(result) = result { + result?; + } else { + break; + } + } + _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => { + return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT)); + } + } + } - 3 - "# - .unindent(); + let messages = self.app.read_entity(&self.agent_thread, |thread, cx| { + let mut messages = Vec::new(); + for message in thread.messages().skip(message_count_before) { + messages.push(Message { + _role: message.role, + _text: message.to_string(), + tool_use: thread + .tool_uses_for_message(message.id, cx) + .into_iter() + .map(|tool_use| ToolUse { + name: tool_use.name.to_string(), + value: tool_use.input, + }) + .collect(), + }); + } + messages + })?; - let output = JudgeResponse::parse(&response).unwrap(); - assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2"); - assert_eq!(output.passing_criteria, 1); - assert_eq!(output.total_criteria, 3); + let response = Response::new(messages); + + Ok(response) + } +} + +#[derive(Debug)] +pub struct Response { + messages: Vec, +} + +impl Response { + pub fn new(messages: Vec) -> Self { + Self { messages } + } + + pub fn expect_tool( + &self, + tool_name: &'static str, + cx: &mut ExampleContext, + ) -> Result<&ToolUse> { + let result = self.messages.iter().find_map(|msg| { + msg.tool_use + .iter() + .find(|tool_use| tool_use.name == tool_name) + }); + cx.assert_some(result, format!("called `{}`", tool_name)) + } +} + +#[derive(Debug)] +pub struct Message { + _role: Role, + _text: String, + tool_use: Vec, +} + +#[derive(Debug)] +pub struct ToolUse { + name: String, + value: serde_json::Value, +} + +impl ToolUse { + pub fn expect_input(&self, cx: &mut ExampleContext) -> Result + where + Input: for<'de> serde::Deserialize<'de>, + { + let result = + serde_json::from_value::(self.value.clone()).map_err(|err| anyhow!(err)); + cx.log_assertion(result, format!("valid `{}` input", &self.name)) } } diff 
--git a/crates/eval/src/examples/file_search.rs b/crates/eval/src/examples/file_search.rs new file mode 100644 index 0000000000..2649c87506 --- /dev/null +++ b/crates/eval/src/examples/file_search.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use assistant_tools::PathSearchToolInput; +use async_trait::async_trait; +use regex::Regex; + +use crate::example::{Example, ExampleContext, ExampleMetadata}; + +pub struct FileSearchExample; + +#[async_trait(?Send)] +impl Example for FileSearchExample { + fn meta(&self) -> ExampleMetadata { + ExampleMetadata { + name: "file_search".to_string(), + url: "https://github.com/zed-industries/zed.git".to_string(), + revision: "03ecb88fe30794873f191ddb728f597935b3101c".to_string(), + language_server: None, + max_assertions: Some(4), + } + } + + async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { + const FILENAME: &str = "find_replace_file_tool.rs"; + cx.push_user_message(format!( + r#" + Look at the `{FILENAME}`. I want to implement a card for it. The card should implement the `Render` trait. + + The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for + markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green + background for lines that were added. We should have a div per diff line. 
+ "# + )); + + let response = cx.run_turn().await?; + let tool_use = response.expect_tool("path_search", cx)?; + let input = tool_use.expect_input::(cx)?; + + let glob = input.glob; + cx.assert( + glob.ends_with(FILENAME), + format!("glob ends with `{FILENAME}`"), + )?; + + let without_filename = glob.replace(FILENAME, ""); + let matches = Regex::new("(\\*\\*|zed)/(\\*\\*?/)?") + .unwrap() + .is_match(&without_filename); + + cx.assert(matches, "glob starts with either `**` or `zed`")?; + + Ok(()) + } +} diff --git a/crates/eval/src/examples/find_and_replace_diff_card.toml b/crates/eval/src/examples/find_and_replace_diff_card.toml new file mode 100644 index 0000000000..0e1b9c3972 --- /dev/null +++ b/crates/eval/src/examples/find_and_replace_diff_card.toml @@ -0,0 +1,43 @@ +url = "https://github.com/zed-industries/zed.git" +revision = "38fcadf9481d018543c65f36ac3bafeba190179b" +language_extension = "rs" + +prompt = """ +Look at the `find_replace_file_tool.rs`. I want to implement a card for it. +The card should implement the `Render` trait. + +The card should show a diff. It should be a beautifully presented diff. +The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). +I want to see a red background for lines that were deleted and a green background for lines +that were added. We should have a div per diff line. +""" + +[diff_assertions] + +modify_find_and_replace_tool = """ +The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct. +The struct should contain an `output` field that is the same as the task we were returning before, +and a new `card` field that contains a view for the card. +""" + +card_implementation = """ +The card should be a view that displays a diff. +Each line in the diff should be colored according to whether it was added, removed or unchanged. 
+""" + +[thread_assertions] + +path_search = """ +The first tool call should be to path search including "find_replace_file_tool.rs" in the string. +(*Not* grep, for example, or reading the file based on a guess at the path.) +This is because we gave the model a filename and it needs to turn that into a real path. +""" + +read_file_from_path_search = """ +After obtaining the correct path of "zed/crates/assistant_tools/src/find_replace_file_tool.rs", it should read the contents of that path. +""" + +symbol_search = """ +When trying to find information about the Render trait, it should *not* begin with a path search, because it doesn't yet have any information +on what path the Render trait might be in. +""" diff --git a/crates/eval/src/examples/mod.rs b/crates/eval/src/examples/mod.rs new file mode 100644 index 0000000000..5e30f8ec80 --- /dev/null +++ b/crates/eval/src/examples/mod.rs @@ -0,0 +1,128 @@ +use anyhow::Result; +use async_trait::async_trait; +use serde::Deserialize; +use std::collections::BTreeMap; +use std::fs; +use std::{ + path::{Path, PathBuf}, + rc::Rc, +}; +use util::serde::default_true; + +use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion}; + +mod file_search; + +pub fn all(examples_dir: &Path) -> Vec> { + let mut threads: Vec> = vec![Rc::new(file_search::FileSearchExample)]; + + for example_path in list_declarative_examples(examples_dir).unwrap() { + threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap())); + } + + threads +} + +struct DeclarativeExample { + metadata: ExampleMetadata, + prompt: String, + diff_assertions: Vec, + thread_assertions: Vec, +} + +impl DeclarativeExample { + pub fn load(example_path: &Path) -> Result { + let name = Self::name_from_path(example_path); + let base: ExampleToml = toml::from_str(&fs::read_to_string(&example_path)?)?; + + let language_server = if base.require_lsp { + Some(crate::example::LanguageServer { + file_extension: base + .language_extension + 
.expect("Language extension is required when require_lsp = true"), + allow_preexisting_diagnostics: base.allow_preexisting_diagnostics, + }) + } else { + None + }; + + let metadata = ExampleMetadata { + name, + url: base.url, + revision: base.revision, + language_server, + max_assertions: None, + }; + + Ok(DeclarativeExample { + metadata, + prompt: base.prompt, + thread_assertions: base + .thread_assertions + .into_iter() + .map(|(id, description)| JudgeAssertion { id, description }) + .collect(), + diff_assertions: base + .diff_assertions + .into_iter() + .map(|(id, description)| JudgeAssertion { id, description }) + .collect(), + }) + } + + pub fn name_from_path(path: &Path) -> String { + path.file_stem().unwrap().to_string_lossy().to_string() + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct ExampleToml { + pub url: String, + pub revision: String, + pub language_extension: Option, + pub insert_id: Option, + #[serde(default = "default_true")] + pub require_lsp: bool, + #[serde(default)] + pub allow_preexisting_diagnostics: bool, + pub prompt: String, + #[serde(default)] + pub diff_assertions: BTreeMap, + #[serde(default)] + pub thread_assertions: BTreeMap, +} + +#[async_trait(?Send)] +impl Example for DeclarativeExample { + fn meta(&self) -> ExampleMetadata { + self.metadata.clone() + } + + async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> { + cx.push_user_message(&self.prompt); + let _ = cx.run_to_end().await; + Ok(()) + } + + fn diff_assertions(&self) -> Vec { + self.diff_assertions.clone() + } + + fn thread_assertions(&self) -> Vec { + self.thread_assertions.clone() + } +} + +fn list_declarative_examples(examples_dir: &Path) -> Result> { + let path = std::fs::canonicalize(examples_dir).unwrap(); + let entries = std::fs::read_dir(path).unwrap(); + let mut result_paths = Vec::new(); + for entry in entries { + let entry = entry?; + let path = entry.path(); + if path.extension() == Some("toml".as_ref()) { + result_paths.push(path); + } + 
} + Ok(result_paths) +} diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs new file mode 100644 index 0000000000..91e5f24c0b --- /dev/null +++ b/crates/eval/src/instance.rs @@ -0,0 +1,1023 @@ +use agent::ThreadStore; +use anyhow::{Context, Result, anyhow, bail}; +use assistant_tool::ToolWorkingSet; +use client::proto::LspWorkProgress; +use futures::channel::mpsc; +use futures::{FutureExt as _, StreamExt as _, future}; +use gpui::{App, AppContext as _, AsyncApp, Entity, Task}; +use handlebars::Handlebars; +use language::{Buffer, DiagnosticSeverity, OffsetRangeExt as _}; +use language_model::{ + LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, + MessageContent, Role, TokenUsage, +}; +use project::lsp_store::OpenLspBufferHandle; +use project::{DiagnosticSummary, Project, ProjectPath}; +use serde::{Deserialize, Serialize}; +use std::cell::RefCell; +use std::fmt::Write as _; +use std::fs; +use std::fs::File; +use std::io::Write as _; +use std::path::Path; +use std::path::PathBuf; +use std::rc::Rc; +use std::sync::Arc; +use std::time::Duration; +use unindent::Unindent as _; +use util::ResultExt as _; +use util::command::new_smol_command; +use util::markdown::MarkdownString; + +use crate::assertions::{AssertionsReport, RanAssertion, RanAssertionResult}; +use crate::example::{Example, ExampleContext, FailedAssertion, JudgeAssertion}; +use crate::{AgentAppState, ToolMetrics}; + +pub const ZED_REPO_URL: &str = "https://github.com/zed-industries/zed.git"; + +#[derive(Clone)] +pub struct ExampleInstance { + pub thread: Rc, + pub name: String, + pub run_directory: PathBuf, + pub log_prefix: String, + /// The repetition number for this example (0-based) + /// When running multiple repetitions of the same example, each instance is assigned a unique repetition number. + /// This affects the worktree path and log prefix to avoid clobbering results between runs. 
+ pub repetition: usize, + pub repo_path: PathBuf, + /// Path to the directory containing the requests and responses for the agentic loop + worktrees_dir: PathBuf, +} + +#[derive(Debug, Serialize, Clone)] +pub struct RunOutput { + pub repository_diff: String, + pub diagnostic_summary_before: DiagnosticSummary, + pub diagnostic_summary_after: DiagnosticSummary, + pub diagnostics_before: Option, + pub diagnostics_after: Option, + pub response_count: usize, + pub token_usage: TokenUsage, + pub tool_metrics: ToolMetrics, + pub last_request: LanguageModelRequest, + pub programmatic_assertions: AssertionsReport, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeDiffInput { + pub repository_diff: String, + pub assertion: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeThreadInput { + pub messages: String, + pub assertion: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeOutput { + pub thread: AssertionsReport, + pub diff: AssertionsReport, +} + +impl ExampleInstance { + pub fn new( + thread: Rc, + repos_dir: &Path, + run_dir: &Path, + worktrees_dir: &Path, + repetition: usize, + ) -> Self { + let name = thread.meta().name.to_string(); + let run_directory = run_dir + .join(&name) + .join(repetition.to_string()) + .to_path_buf(); + + let repo_path = repo_path_for_url(repos_dir, &thread.meta().url); + + Self { + name, + thread, + log_prefix: String::new(), + run_directory, + repetition, + repo_path, + worktrees_dir: worktrees_dir.to_path_buf(), + } + } + + pub fn repo_url(&self) -> String { + self.thread.meta().url + } + + pub fn revision(&self) -> String { + self.thread.meta().revision + } + + pub fn worktree_name(&self) -> String { + format!("{}-{}", self.name, self.repetition) + } + + pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) { + self.log_prefix = format!( + "{}{: Result<()> { + let meta = self.thread.meta(); + + let revision_exists = run_git( + &self.repo_path, 
+ &["rev-parse", &format!("{}^{{commit}}", &meta.revision)], + ) + .await + .is_ok(); + + if !revision_exists { + println!("{}Fetching revision {}", self.log_prefix, &meta.revision); + run_git( + &self.repo_path, + &["fetch", "--depth", "1", "origin", &meta.revision], + ) + .await?; + } + Ok(()) + } + + /// Set up the example by checking out the specified Git revision + pub async fn setup(&mut self) -> Result<()> { + let worktree_path = self.worktree_path(); + let meta = self.thread.meta(); + if worktree_path.is_dir() { + println!("{}Resetting existing worktree", self.log_prefix); + + // TODO: consider including "-x" to remove ignored files. The downside of this is that + // it will also remove build artifacts, and so prevent incremental reuse there. + run_git(&worktree_path, &["clean", "--force", "-d"]).await?; + run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; + run_git(&worktree_path, &["checkout", &meta.revision]).await?; + } else { + println!("{}Creating worktree", self.log_prefix); + + let worktree_path_string = worktree_path.to_string_lossy().to_string(); + + run_git( + &self.repo_path, + &[ + "worktree", + "add", + "-f", + &worktree_path_string, + &meta.revision, + ], + ) + .await?; + } + + if meta.url == ZED_REPO_URL { + std::fs::write(worktree_path.join(".rules"), std::fs::read(".rules")?)?; + } + + std::fs::create_dir_all(&self.run_directory)?; + + Ok(()) + } + + pub fn worktree_path(&self) -> PathBuf { + self.worktrees_dir + .join(self.worktree_name()) + .join(self.thread.meta().repo_name()) + } + + pub fn run( + &self, + model: Arc, + app_state: Arc, + cx: &mut App, + ) -> Task> { + let project = Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ); + + let worktree = project.update(cx, |project, cx| { + project.create_worktree(self.worktree_path(), true, cx) + }); + + let tools = cx.new(|_| 
ToolWorkingSet::default()); + let thread_store = + ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx); + let meta = self.thread.meta(); + let this = self.clone(); + + cx.spawn(async move |cx| { + let worktree = worktree.await?; + + // Wait for worktree scan to finish before choosing a file to open. + worktree + .update(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + })? + .await; + + struct LanguageServerState { + _lsp_open_handle: OpenLspBufferHandle, + language_file_buffer: Entity, + } + + let mut diagnostics_before = None; + let mut diagnostic_summary_before = DiagnosticSummary::default(); + + let lsp = if let Some(language_server) = &meta.language_server { + // Open a file that matches the language to cause LSP to start. + let language_file = worktree.read_with(cx, |worktree, _cx| { + worktree + .files(false, 0) + .find_map(|e| { + if e.path.clone().extension().and_then(|ext| ext.to_str()) + == Some(&language_server.file_extension) + { + Some(ProjectPath { + worktree_id: worktree.id(), + path: e.path.clone(), + }) + } else { + None + } + }) + .context("Failed to find a file for example language") + })??; + + let open_language_file_buffer_task = project.update(cx, |project, cx| { + project.open_buffer(language_file.clone(), cx) + })?; + + let language_file_buffer = open_language_file_buffer_task.await?; + + let lsp_open_handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&language_file_buffer, cx) + })?; + + wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?; + + diagnostic_summary_before = project.read_with(cx, |project, cx| { + project.diagnostic_summary(false, cx) + })?; + + diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?; + if diagnostics_before.is_some() && !language_server.allow_preexisting_diagnostics { // bail only when the example has NOT opted in to pre-existing diagnostics + return Err(anyhow!("Example has pre-existing diagnostics. 
If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`")); + } + + Some(LanguageServerState { + _lsp_open_handle: lsp_open_handle, + language_file_buffer, + }) + } else { + None + }; + + if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() { + return Err(anyhow!("Setup only mode")); + } + + let last_diff_file_path = this.run_directory.join("last.diff"); + + // Write an empty "last.diff" so that it can be opened in Zed for convenient view of the + // history using undo/redo. + std::fs::write(&last_diff_file_path, "")?; + + let thread_store = thread_store.await?; + let thread = + thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; + let last_request = Rc::new(RefCell::new(None)); + + thread.update(cx, |thread, _cx| { + let mut request_count = 0; + let last_request = Rc::clone(&last_request); + let previous_diff = Rc::new(RefCell::new("".to_string())); + let example_output_dir = this.run_directory.clone(); + let last_diff_file_path = last_diff_file_path.clone(); + let this = this.clone(); + thread.set_request_callback(move |request, response_events| { + *last_request.borrow_mut() = Some(request.clone()); + + request_count += 1; + let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md")); + let diff_file_path = example_output_dir.join(format!("{request_count}.diff")); + let last_messages_file_path = example_output_dir.join("last.messages.md"); + let request_markdown = RequestMarkdown::new(request); + let response_events_markdown = response_events_to_markdown(response_events); + + let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown); + fs::write(&messages_file_path, messages.clone()).expect("failed to write messages file"); + fs::write(&last_messages_file_path, messages).expect("failed to write last messages file"); + + let diff_result = smol::block_on(this.repository_diff()); + match diff_result { + Ok(diff) => { + if diff != 
previous_diff.borrow().clone() { + fs::write(&diff_file_path, &diff).expect("failed to write diff file"); + fs::write(&last_diff_file_path, &diff).expect("failed to write last diff file"); + *previous_diff.borrow_mut() = diff; + } + } + Err(err) => { + let error_message = format!("{err:?}"); + fs::write(&diff_file_path, &error_message).expect("failed to write diff error to file"); + fs::write(&last_diff_file_path, &error_message).expect("failed to write last diff file"); + } + } + + if request_count == 1 { + let tools_file_path = example_output_dir.join("tools.md"); + fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file"); + } + }); + })?; + + let mut example_cx = ExampleContext::new(meta.clone(), this.log_prefix.clone(), thread.clone(), model.clone(), cx.clone()); + let result = this.thread.conversation(&mut example_cx).await; + + if let Err(err) = result { + if !err.is::() { + return Err(err); + } + } + + println!("{}Stopped", this.log_prefix); + + println!("{}Getting repository diff", this.log_prefix); + let repository_diff = this.repository_diff().await?; + + std::fs::write(last_diff_file_path, &repository_diff)?; + + + let mut diagnostics_after = None; + let mut diagnostic_summary_after = Default::default(); + + if let Some(language_server_state) = lsp { + wait_for_lang_server(&project, &language_server_state.language_file_buffer, this.log_prefix.clone(), cx).await?; + + println!("{}Getting diagnostics", this.log_prefix); + diagnostics_after = cx + .update(|cx| { + let project = project.clone(); + cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await) + })? 
+ .await?; + println!("{}Got diagnostics", this.log_prefix); + + diagnostic_summary_after = project.read_with(cx, |project, cx| { + project.diagnostic_summary(false, cx) + })?; + + } + + let Some(last_request) = last_request.borrow_mut().take() else { + return Err(anyhow!("No requests ran.")); + }; + + if let Some(diagnostics_before) = &diagnostics_before { + fs::write(this.run_directory.join("diagnostics_before.txt"), diagnostics_before)?; + } + + if let Some(diagnostics_after) = &diagnostics_after { + fs::write(this.run_directory.join("diagnostics_after.txt"), diagnostics_after)?; + } + + thread.update(cx, |thread, _cx| { + let response_count = thread + .messages() + .filter(|message| message.role == language_model::Role::Assistant) + .count(); + RunOutput { + repository_diff, + diagnostic_summary_before, + diagnostic_summary_after, + diagnostics_before, + diagnostics_after, + response_count, + token_usage: thread.cumulative_token_usage(), + tool_metrics: example_cx.tool_metrics.lock().unwrap().clone(), + last_request, + programmatic_assertions: example_cx.assertions, + } + }) + }) + } + + async fn repository_diff(&self) -> Result { + let worktree_path = self.worktree_path(); + run_git(&worktree_path, &["add", "."]).await?; + let mut diff_args = vec!["diff", "--staged"]; + if self.thread.meta().url == ZED_REPO_URL { + diff_args.push(":(exclude).rules"); + } + run_git(&worktree_path, &diff_args).await + } + + pub async fn judge( + &self, + model: Arc, + run_output: &RunOutput, + cx: &AsyncApp, + ) -> JudgeOutput { + let mut output_file = + File::create(self.run_directory.join("judge.md")).expect("failed to create judge.md"); + + let diff_task = self.judge_diff(model.clone(), &run_output, cx); + let thread_task = self.judge_thread(model.clone(), &run_output, cx); + + let (diff_result, thread_result) = futures::join!(diff_task, thread_task); + + let (diff_response, diff_output) = diff_result; + let (thread_response, thread_output) = thread_result; + + writeln!( + 
&mut output_file, + "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}", + ) + .log_err(); + + JudgeOutput { + thread: thread_output, + diff: diff_output, + } + } + + async fn judge_diff( + &self, + model: Arc, + run_output: &RunOutput, + cx: &AsyncApp, + ) -> (String, AssertionsReport) { + let diff_assertions = self.thread.diff_assertions(); + + if diff_assertions.is_empty() { + return ( + "No diff assertions".to_string(), + AssertionsReport::default(), + ); + } + + println!("{}Running diff judge", self.log_prefix); + + let judge_diff_prompt = include_str!("judge_diff_prompt.hbs"); + let judge_diff_prompt_name = "judge_diff_prompt"; + let mut hbs = Handlebars::new(); + hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt) + .unwrap(); + + let to_prompt = |assertion: String| { + hbs.render( + judge_diff_prompt_name, + &JudgeDiffInput { + repository_diff: run_output.repository_diff.clone(), + assertion, + }, + ) + .unwrap() + }; + + let (responses, report) = self + .judge_assertions(model, diff_assertions, to_prompt, cx) + .await; + + println!( + "{}Judge - Diff score: {}%", + self.log_prefix, + report.passed_percentage() + ); + + (responses, report) + } + + async fn judge_thread( + &self, + model: Arc, + run_output: &RunOutput, + cx: &AsyncApp, + ) -> (String, AssertionsReport) { + let thread_assertions = self.thread.thread_assertions(); + + if thread_assertions.is_empty() { + return ( + "No thread assertions".to_string(), + AssertionsReport::default(), + ); + } + + let judge_thread_prompt = include_str!("judge_thread_prompt.hbs"); + let judge_diff_prompt_name = "judge_thread_prompt"; + let mut hbs = Handlebars::new(); + hbs.register_template_string(judge_diff_prompt_name, judge_thread_prompt) + .unwrap(); + + let request_markdown = RequestMarkdown::new(&run_output.last_request); + let to_prompt = |assertion: String| { + hbs.render( + judge_diff_prompt_name, + &JudgeThreadInput { + messages: 
request_markdown.messages.clone(), + assertion, + }, + ) + .unwrap() + }; + + let (responses, report) = self + .judge_assertions(model, thread_assertions, to_prompt, cx) + .await; + + println!( + "{}Judge - Thread score: {}%", + self.log_prefix, + report.passed_percentage() + ); + + (responses, report) + } + + async fn judge_assertions( + &self, + model: Arc, + assertions: Vec, + to_prompt: impl Fn(String) -> String, + cx: &AsyncApp, + ) -> (String, AssertionsReport) { + let assertions = assertions.into_iter().map(|assertion| { + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(to_prompt(assertion.description))], + cache: false, + }], + temperature: None, + tools: Vec::new(), + stop: Vec::new(), + }; + + let model = model.clone(); + let log_prefix = self.log_prefix.clone(); + async move { + let response = send_language_model_request(model, request, cx).await; + + let (response, result) = match response { + Ok(response) => ( + response.clone(), + parse_assertion_result(&response).map_err(|err| err.to_string()), + ), + Err(err) => (err.to_string(), Err(err.to_string())), + }; + + if result.is_ok() { + println!("{}✅ {}", log_prefix, assertion.id); + } else { + println!("{}❌ {}", log_prefix, assertion.id); + } + + ( + response, + RanAssertion { + id: assertion.id, + result, + }, + ) + } + }); + + let mut responses = String::new(); + let mut report = AssertionsReport::default(); + + for (response, assertion) in future::join_all(assertions).await { + writeln!(&mut responses, "# {}", assertion.id).unwrap(); + writeln!(&mut responses, "{}\n\n", response).unwrap(); + report.ran.push(assertion); + } + + (responses, report) + } +} + +pub fn wait_for_lang_server( + project: &Entity, + buffer: &Entity, + log_prefix: String, + cx: &mut AsyncApp, +) -> Task> { + if std::env::var("ZED_EVAL_SKIP_LS").is_ok() { + return Task::ready(Ok(())); + } + + 
println!("{}⏵ Waiting for language server", log_prefix); + + let (mut tx, mut rx) = mpsc::channel(1); + + let lsp_store = project + .update(cx, |project, _| project.lsp_store()) + .unwrap(); + + let has_lang_server = buffer + .update(cx, |buffer, cx| { + lsp_store.update(cx, |lsp_store, cx| { + lsp_store + .language_servers_for_local_buffer(&buffer, cx) + .next() + .is_some() + }) + }) + .unwrap_or(false); + + if has_lang_server { + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .unwrap() + .detach(); + } + + let subscriptions = + [ + cx.subscribe(&lsp_store, { + let log_prefix = log_prefix.clone(); + move |_, event, _| match event { + project::LspStoreEvent::LanguageServerUpdate { + message: + client::proto::update_language_server::Variant::WorkProgress( + LspWorkProgress { + message: Some(message), + .. + }, + ), + .. + } => println!("{}⟲ {message}", log_prefix), + _ => {} + } + }), + cx.subscribe(&project, { + let buffer = buffer.clone(); + move |project, event, cx| match event { + project::Event::LanguageServerAdded(_, _, _) => { + let buffer = buffer.clone(); + project + .update(cx, |project, cx| project.save_buffer(buffer, cx)) + .detach(); + } + project::Event::DiskBasedDiagnosticsFinished { .. } => { + tx.try_send(()).ok(); + } + _ => {} + } + }), + ]; + + cx.spawn(async move |cx| { + let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); + let result = futures::select! 
{ + _ = rx.next() => { + println!("{}⚑ Language server idle", log_prefix); + anyhow::Ok(()) + }, + _ = timeout.fuse() => { + Err(anyhow!("LSP wait timed out after 5 minutes")) + } + }; + drop(subscriptions); + result + }) +} + +pub async fn query_lsp_diagnostics( + project: Entity<Project>, + cx: &mut AsyncApp, +) -> Result<Option<String>> { + let paths_with_diagnostics = project.update(cx, |project, cx| { + project + .diagnostic_summaries(true, cx) + .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0) + .map(|(project_path, _, _)| project_path) + .collect::<Vec<_>>() + })?; + + if paths_with_diagnostics.is_empty() { + return Ok(None); + } + + let mut output = String::new(); + for project_path in paths_with_diagnostics { + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + for (_, group) in snapshot.diagnostic_groups(None) { + let entry = &group.entries[group.primary_ix]; + let range = entry.range.to_point(&snapshot); + let severity = match entry.diagnostic.severity { + DiagnosticSeverity::ERROR => "error", + DiagnosticSeverity::WARNING => "warning", + _ => continue, + }; + + writeln!( + output, + "{} at line {}: {}", + severity, + range.start.row + 1, + entry.diagnostic.message + )?; + } + } + anyhow::Ok(Some(output)) +} + +fn parse_assertion_result(response: &str) -> Result<RanAssertionResult> { + let analysis = get_tag("analysis", response)?.to_string(); + let passed = match get_tag("passed", response)?.to_lowercase().as_str() { + "true" => true, + "false" => false, + value @ _ => bail!("invalid judge `passed` tag: {value}"), + }; + Ok(RanAssertionResult { + analysis: Some(analysis), + passed, + }) +} + +fn get_tag(name: &'static str, response: &str) -> Result<String> { + let start_tag = format!("<{}>", name); + let end_tag = format!("</{}>", name); + + let start_ix = response + .find(&start_tag) + .context(format!("{} start tag not found", name))?; + let content_start_ix = 
start_ix + start_tag.len(); + + let end_ix = content_start_ix + + response[content_start_ix..] + .find(&end_tag) + .context(format!("{} end tag not found", name))?; + + let content = response[content_start_ix..end_ix].trim().unindent(); + + anyhow::Ok(content) +} + +pub fn repo_path_for_url(repos_dir: &Path, repo_url: &str) -> PathBuf { + let repo_name = repo_url + .trim_start_matches("https://") + .replace(|c: char| !c.is_alphanumeric(), "-"); + Path::new(repos_dir).join(repo_name) +} + +pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result { + let output = new_smol_command("git") + .current_dir(repo_path) + .args(args) + .output() + .await?; + + if output.status.success() { + Ok(String::from_utf8(output.stdout)?.trim().to_string()) + } else { + Err(anyhow!( + "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}", + args.join(" "), + repo_path.display(), + output.status, + String::from_utf8_lossy(&output.stderr), + String::from_utf8_lossy(&output.stdout), + )) + } +} + +pub async fn send_language_model_request( + model: Arc, + request: LanguageModelRequest, + cx: &AsyncApp, +) -> anyhow::Result { + match model.stream_completion_text(request, &cx).await { + Ok(mut stream) => { + let mut full_response = String::new(); + while let Some(chunk_result) = stream.stream.next().await { + match chunk_result { + Ok(chunk_str) => { + full_response.push_str(&chunk_str); + } + Err(err) => { + return Err(anyhow!( + "Error receiving response from language model: {err}" + )); + } + } + } + Ok(full_response) + } + Err(err) => Err(anyhow!( + "Failed to get response from language model. 
Error was: {err}" + )), + } +} + +pub struct RequestMarkdown { + pub tools: String, + pub messages: String, +} + +impl RequestMarkdown { + pub fn new(request: &LanguageModelRequest) -> Self { + let mut tools = String::new(); + let mut messages = String::new(); + let mut assistant_message_number: u32 = 1; + + // Print the tools + if !request.tools.is_empty() { + for tool in &request.tools { + write!(&mut tools, "# {}\n\n", tool.name).unwrap(); + write!(&mut tools, "{}\n\n", tool.description).unwrap(); + write!( + &mut tools, + "{}\n", + MarkdownString::code_block("json", &format!("{:#}", tool.input_schema)) + ) + .unwrap(); + } + } + + // Print the messages + for message in &request.messages { + match message.role { + Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"), + Role::User => messages.push_str("# 👤 USER\n\n"), + Role::Assistant => { + messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n")); + assistant_message_number += 1; + } + }; + + for content in &message.content { + match content { + MessageContent::Text(text) => { + messages.push_str(text); + messages.push_str("\n\n"); + } + MessageContent::Image(_) => { + messages.push_str("[IMAGE DATA]\n\n"); + } + MessageContent::Thinking { text, signature } => { + messages.push_str("**Thinking**:\n\n"); + if let Some(sig) = signature { + messages.push_str(&format!("Signature: {}\n\n", sig)); + } + messages.push_str(text); + messages.push_str("\n"); + } + MessageContent::RedactedThinking(items) => { + messages.push_str(&format!( + "**Redacted Thinking**: {} item(s)\n\n", + items.len() + )); + } + MessageContent::ToolUse(tool_use) => { + messages.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + messages.push_str(&format!( + "{}\n", + MarkdownString::code_block("json", &format!("{:#}", tool_use.input)) + )); + } + MessageContent::ToolResult(tool_result) => { + messages.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, 
tool_result.tool_use_id + )); + if tool_result.is_error { + messages.push_str("**ERROR:**\n"); + } + messages.push_str(&format!("{}\n\n", tool_result.content)); + } + } + } + } + + Self { tools, messages } + } +} + +pub fn response_events_to_markdown( + response_events: &[std::result::Result], +) -> String { + let mut response = String::new(); + // Print the response events if any + response.push_str("# Response\n\n"); + let mut text_buffer = String::new(); + let mut thinking_buffer = String::new(); + + let flush_buffers = + |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| { + if !text_buffer.is_empty() { + output.push_str(&format!("**Text**:\n{}\n\n", text_buffer)); + text_buffer.clear(); + } + if !thinking_buffer.is_empty() { + output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer)); + thinking_buffer.clear(); + } + }; + + for event in response_events { + match event { + Ok(LanguageModelCompletionEvent::Text(text)) => { + text_buffer.push_str(text); + } + Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => { + thinking_buffer.push_str(text); + } + Ok(LanguageModelCompletionEvent::Stop(reason)) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!("**Stop**: {:?}\n\n", reason)); + } + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + response.push_str(&format!( + "{}\n", + MarkdownString::code_block("json", &format!("{:#}", tool_use.input)) + )); + } + Ok( + LanguageModelCompletionEvent::UsageUpdate(_) + | LanguageModelCompletionEvent::StartMessage { .. 
}, + ) => {} + Err(error) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!("**Error**: {}\n\n", error)); + } + } + } + + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + + response +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse_judge_output() { + let response = r#" + The model did a good job but there were still compilations errors. + true + "# + .unindent(); + + let output = parse_assertion_result(&response).unwrap(); + assert_eq!( + output.analysis, + Some("The model did a good job but there were still compilations errors.".into()) + ); + assert_eq!(output.passed, true); + + let response = r#" + Text around ignored + + + Failed to compile: + - Error 1 + - Error 2 + + + false + "# + .unindent(); + + let output = parse_assertion_result(&response).unwrap(); + assert_eq!( + output.analysis, + Some("Failed to compile:\n- Error 1\n- Error 2".into()) + ); + assert_eq!(output.passed, false); + } +} diff --git a/crates/eval/src/judge_diff_prompt.hbs b/crates/eval/src/judge_diff_prompt.hbs index bf51f933ee..24ef9ac97e 100644 --- a/crates/eval/src/judge_diff_prompt.hbs +++ b/crates/eval/src/judge_diff_prompt.hbs @@ -1,5 +1,5 @@ -You are an expert software developer. Your task is to evaluate a diff produced by an AI agent in response to a prompt. -Here is the prompt and the diff: +You are an expert software developer. Your task is to evaluate a diff produced by an AI agent +in response to a prompt. 
Here is the prompt and the diff: {{{prompt}}} @@ -9,17 +9,17 @@ Here is the prompt and the diff: {{{repository_diff}}} -Evaluate how many of the following criteria were satisfied by the diff: +Evaluate whether or not the diff passes the following assertion: - -{{criteria}} -- There are no changes unrelated to the prompt - + +{{assertion}} + Analyze the diff hunk by hunk, and structure your answer in the following XML format: ``` {YOUR ANALYSIS HERE} -{THE TOTAL NUMBER OF CRITERIA THAT WERE LISTED} -{THE NUMBER OF CRITERIA THAT ARE MET BY THE DIFF} +{PASSED_ASSERTION} ``` + +Where `PASSED_ASSERTION` is either `true` or `false`. diff --git a/crates/eval/src/judge_thread_prompt.hbs b/crates/eval/src/judge_thread_prompt.hbs index aa1b6d6f6e..e80bafcce1 100644 --- a/crates/eval/src/judge_thread_prompt.hbs +++ b/crates/eval/src/judge_thread_prompt.hbs @@ -1,19 +1,21 @@ -You are an expert software developer. Your task is to evaluate an AI agent's messages and tool calls in this conversation: +You are an expert software developer. +Your task is to evaluate an AI agent's messages and tool calls in this conversation: {{{messages}}} -You must count how many of the following criteria were satisfied by the messages: +Evaluate whether or not the sequence of messages passes the following assertion: - -{{{criteria}}} - + +{{{assertion}}} + Analyze the messages one by one, and structure your answer in the following XML format: ``` {YOUR ANALYSIS HERE} -{THE TOTAL NUMBER OF CRITERIA THAT WERE LISTED} -{THE NUMBER OF CRITERIA THAT ARE MET BY THE MESSAGES} +{PASSED_ASSERTION} ``` + +Where `PASSED_ASSERTION` is either `true` or `false`. 
diff --git a/crates/eval/src/tool_metrics.rs b/crates/eval/src/tool_metrics.rs index e576cca822..63d8a4f2bc 100644 --- a/crates/eval/src/tool_metrics.rs +++ b/crates/eval/src/tool_metrics.rs @@ -24,6 +24,10 @@ impl ToolMetrics { *self.failure_counts.entry(tool_name.clone()).or_insert(0) += failure_count; } } + + pub fn is_empty(&self) -> bool { + self.use_counts.is_empty() && self.failure_counts.is_empty() + } } impl Display for ToolMetrics { @@ -79,7 +83,7 @@ impl Display for ToolMetrics { let failure_count = self.failure_counts.get(&tool_name).cloned().unwrap_or(0); writeln!( f, - "│{:^30}│{:^10}│{:^10}│{:^10}│", + "│{:<30}│{:^10}│{:^10}│{:^10}│", tool_name, use_count, failure_count,