From 7a888de9f5f29fbaa27b32f4b527f21f4253fb71 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Fri, 14 Mar 2025 17:10:25 -0600 Subject: [PATCH] Add initial implementation of evaluating changes generated by the assistant (#26799) Release Notes: - N/A --------- Co-authored-by: Richard Feldman Co-authored-by: Thomas --- Cargo.lock | 35 +++ Cargo.toml | 2 + crates/assistant2/src/active_thread.rs | 1 + crates/assistant2/src/assistant.rs | 3 + crates/assistant2/src/thread.rs | 58 ++-- crates/assistant2/src/tool_use.rs | 7 + crates/assistant_eval/Cargo.toml | 44 +++ crates/assistant_eval/LICENSE-GPL | 1 + crates/assistant_eval/README.md | 77 ++++++ crates/assistant_eval/build.rs | 61 +++++ crates/assistant_eval/src/eval.rs | 252 ++++++++++++++++++ .../assistant_eval/src/headless_assistant.rs | 241 +++++++++++++++++ crates/assistant_eval/src/judge.rs | 121 +++++++++ crates/assistant_eval/src/main.rs | 234 ++++++++++++++++ 14 files changed, 1113 insertions(+), 24 deletions(-) create mode 100644 crates/assistant_eval/Cargo.toml create mode 120000 crates/assistant_eval/LICENSE-GPL create mode 100644 crates/assistant_eval/README.md create mode 100644 crates/assistant_eval/build.rs create mode 100644 crates/assistant_eval/src/eval.rs create mode 100644 crates/assistant_eval/src/headless_assistant.rs create mode 100644 crates/assistant_eval/src/judge.rs create mode 100644 crates/assistant_eval/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 750e3e0ea8..55ba70e06d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -566,6 +566,41 @@ dependencies = [ "workspace", ] +[[package]] +name = "assistant_eval" +version = "0.1.0" +dependencies = [ + "anyhow", + "assistant2", + "assistant_tool", + "assistant_tools", + "clap", + "client", + "collections", + "context_server", + "env_logger 0.11.7", + "fs", + "futures 0.3.31", + "gpui", + "gpui_tokio", + "itertools 0.14.0", + "language", + "language_model", + "language_models", + "node_runtime", + "project", + "prompt_store", + "regex", + "release_channel", + "reqwest_client", + "serde", + "serde_json", + "serde_json_lenient", + "settings", + "smol", + "util", +] + [[package]] name = "assistant_settings" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index face468c70..d49fdbfcb1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/assistant", "crates/assistant2", "crates/assistant_context_editor", + "crates/assistant_eval", "crates/assistant_settings", "crates/assistant_slash_command", "crates/assistant_slash_commands", @@ -206,6 +207,7 @@ assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } assistant2 = { path = "crates/assistant2" } assistant_context_editor = { path = "crates/assistant_context_editor" } +assistant_eval = { path = "crates/assistant_eval" } assistant_settings = { path = "crates/assistant_settings" } assistant_slash_command = { path = "crates/assistant_slash_command" } assistant_slash_commands = { path = "crates/assistant_slash_commands" } diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 212ef3496f..e3e55333b6 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -298,6 +298,7 @@ impl ActiveThread { ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => { self.save_thread(cx); } + ThreadEvent::DoneStreaming => {} ThreadEvent::StreamedAssistantText(message_id, text) => { if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) { markdown.update(cx, |markdown, cx| { diff --git a/crates/assistant2/src/assistant.rs b/crates/assistant2/src/assistant.rs index 33da6d7a75..5224b097cc 100644 --- a/crates/assistant2/src/assistant.rs +++ b/crates/assistant2/src/assistant.rs @@ -31,8 +31,11 @@ use gpui::{actions, App}; use prompt_store::PromptBuilder; use settings::Settings as _; +pub use crate::active_thread::ActiveThread; pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate}; pub use crate::inline_assistant::InlineAssistant; +pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent}; +pub use crate::thread_store::ThreadStore; actions!( assistant2, diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 9c744ab123..b9338c6e6a 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -284,6 +284,10 @@ impl Thread { self.tool_use.tool_results_for_message(id) } + pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> { + self.tool_use.tool_result(id) + } + pub fn scripting_tool_results_for_message( &self, id: MessageId, @@ -652,32 +656,37 @@ impl Thread { let result = stream_completion.await; thread - .update(&mut cx, |thread, cx| match result.as_ref() { - Ok(stop_reason) => match stop_reason { - StopReason::ToolUse => { - cx.emit(ThreadEvent::UsePendingTools); - } - StopReason::EndTurn => {} - StopReason::MaxTokens => {} - }, - Err(error) => { - if error.is::() { - cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); - } else if error.is::() { - cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached)); - } else { - let error_message = error - .chain() - .map(|err| err.to_string()) - .collect::>() - .join("\n"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message( - SharedString::from(error_message.clone()), - ))); - } + .update(&mut cx, |thread, cx| { + match result.as_ref() { + Ok(stop_reason) => match stop_reason { + StopReason::ToolUse => { + cx.emit(ThreadEvent::UsePendingTools); + } + StopReason::EndTurn => {} + StopReason::MaxTokens => {} + }, + Err(error) => { + if error.is::() { + cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); + } else if error.is::() { + cx.emit(ThreadEvent::ShowError( + ThreadError::MaxMonthlySpendReached, + )); + } else { + let error_message = error + .chain() + .map(|err| err.to_string()) + .collect::>() + .join("\n"); + cx.emit(ThreadEvent::ShowError(ThreadError::Message( + SharedString::from(error_message.clone()), + ))); + } - thread.cancel_last_completion(); + thread.cancel_last_completion(); + } } + cx.emit(ThreadEvent::DoneStreaming); }) .ok(); }); @@ -1094,6 +1103,7 @@ pub enum ThreadEvent { ShowError(ThreadError), StreamedCompletion, StreamedAssistantText(MessageId, String), + DoneStreaming, MessageAdded(MessageId), MessageEdited(MessageId), MessageDeleted(MessageId), diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs index 34821d45f9..5170e77438 100644 --- a/crates/assistant2/src/tool_use.rs +++ b/crates/assistant2/src/tool_use.rs @@ -182,6 +182,13 @@ impl ToolUseState { .map_or(false, |results| !results.is_empty()) } + pub fn tool_result( + &self, + tool_use_id: &LanguageModelToolUseId, + ) -> Option<&LanguageModelToolResult> { + self.tool_results.get(tool_use_id) + } + pub fn request_tool_use( &mut self, assistant_message_id: MessageId, diff --git a/crates/assistant_eval/Cargo.toml b/crates/assistant_eval/Cargo.toml new file mode 100644 index 0000000000..d5e30d339d --- /dev/null +++ b/crates/assistant_eval/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "assistant_eval" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[[bin]] +name = "assistant_eval" +path = "src/main.rs" + +[dependencies] +anyhow.workspace = true +assistant2.workspace = true +assistant_tool.workspace = true +assistant_tools.workspace = true +clap.workspace = true +client.workspace = true +collections.workspace = true +context_server.workspace = true +env_logger.workspace = true +fs.workspace = true +futures.workspace = true +gpui.workspace = true +gpui_tokio.workspace = true +itertools.workspace = true +language.workspace = true +language_model.workspace = true +language_models.workspace = true +node_runtime.workspace = true +project.workspace = true +prompt_store.workspace = true +regex.workspace = true +release_channel.workspace = true +reqwest_client.workspace = true +serde.workspace = true +serde_json.workspace = true +serde_json_lenient.workspace = true +settings.workspace = true +smol.workspace = true +util.workspace = true diff --git a/crates/assistant_eval/LICENSE-GPL b/crates/assistant_eval/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/assistant_eval/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant_eval/README.md b/crates/assistant_eval/README.md new file mode 100644 index 0000000000..86c6a6cfbf --- /dev/null +++ b/crates/assistant_eval/README.md @@ -0,0 +1,77 @@ +# Tool Evals + +A framework for evaluating and benchmarking AI assistant performance in the Zed editor. + +## Overview + +Tool Evals provides a headless environment for running assistants evaluations on code repositories. It automates the process of: + +1. Cloning and setting up test repositories +2. Sending prompts to language models +3. Allowing the assistant to use tools to modify code +4. Collecting metrics on performance +5. Evaluating results against known good solutions + +## How It Works + +The system consists of several key components: + +- **Eval**: Loads test cases from the evaluation_data directory, clones repos, and executes evaluations +- **HeadlessAssistant**: Provides a headless environment for running the AI assistant +- **Judge**: Compares AI-generated diffs with reference solutions and scores their functional similarity + +The evaluation flow: +1. An evaluation is loaded from the evaluation_data directory +2. The target repository is cloned and checked out at a specific commit +3. A HeadlessAssistant instance is created with the specified language model +4. The user prompt is sent to the assistant +5. The assistant responds and uses tools to modify code +6. Upon completion, a diff is generated from the changes +7. Results are saved including the diff, assistant's response, and performance metrics +8. If a reference solution exists, a Judge evaluates the similarity of the solution + +## Setup Requirements + +### Prerequisites + +- Rust and Cargo +- Git +- Network access to clone repositories +- Appropriate API keys for language models and git services (Anthropic, GitHub, etc.) + +### Environment Variables + +Ensure you have the required API keys set, either from a dev run of Zed or via these environment variables: +- `ZED_ANTHROPIC_API_KEY` for Claude models +- `ZED_OPENAI_API_KEY` for OpenAI models +- `ZED_GITHUB_API_KEY` for GitHub API (or similar) + +## Usage + +### Running a Single Evaluation + +To run a specific evaluation: + +```bash +cargo run -p assistant_eval -- bubbletea-add-set-window-title +``` + +The arguments are regex patterns for the evaluation names to run, so to run all evaluations that contain `bubbletea`, run: + +```bash +cargo run -p assistant_eval -- bubbletea +``` + +To run all evaluations: + +```bash +cargo run -p assistant_eval -- --all +``` + +## Evaluation Data Structure + +Each evaluation should be placed in the `evaluation_data` directory with the following structure: + +* `prompt.txt`: The user's prompt. +* `original.diff`: The `git diff` of the change anticipated for this prompt. +* `setup.json`: Information about the repo used for the evaluation. diff --git a/crates/assistant_eval/build.rs b/crates/assistant_eval/build.rs new file mode 100644 index 0000000000..6268b66f9c --- /dev/null +++ b/crates/assistant_eval/build.rs @@ -0,0 +1,61 @@ +// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows. + +use std::process::Command; + +fn main() { + if cfg!(target_os = "macos") { + println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7"); + + println!("cargo:rerun-if-env-changed=ZED_BUNDLE"); + if std::env::var("ZED_BUNDLE").ok().as_deref() == Some("true") { + // Find WebRTC.framework in the Frameworks folder when running as part of an application bundle. + println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../Frameworks"); + } else { + // Find WebRTC.framework as a sibling of the executable when running outside of an application bundle. + println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path"); + } + + // Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+. + println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit"); + + // Seems to be required to enable Swift concurrency + println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift"); + + // Register exported Objective-C selectors, protocols, etc + println!("cargo:rustc-link-arg=-Wl,-ObjC"); + } + + // Populate git sha environment variable if git is available + println!("cargo:rerun-if-changed=../../.git/logs/HEAD"); + println!( + "cargo:rustc-env=TARGET={}", + std::env::var("TARGET").unwrap() + ); + if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() { + if output.status.success() { + let git_sha = String::from_utf8_lossy(&output.stdout); + let git_sha = git_sha.trim(); + + println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}"); + + if let Ok(build_profile) = std::env::var("PROFILE") { + if build_profile == "release" { + // This is currently the best way to make `cargo build ...`'s build script + // to print something to stdout without extra verbosity. + println!( + "cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var" + ); + } + } + } + } + + #[cfg(target_os = "windows")] + { + #[cfg(target_env = "msvc")] + { + // todo(windows): This is to avoid stack overflow. Remove it when solved. + println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024); + } + } +} diff --git a/crates/assistant_eval/src/eval.rs b/crates/assistant_eval/src/eval.rs new file mode 100644 index 0000000000..98e72ec6b5 --- /dev/null +++ b/crates/assistant_eval/src/eval.rs @@ -0,0 +1,252 @@ +use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant}; +use anyhow::anyhow; +use assistant2::RequestKind; +use collections::HashMap; +use gpui::{App, Task}; +use language_model::{LanguageModel, TokenUsage}; +use serde::{Deserialize, Serialize}; +use std::{ + fs, + io::Write, + path::{Path, PathBuf}, + sync::Arc, + time::Duration, +}; +use util::command::new_smol_command; + +pub struct Eval { + pub name: String, + pub path: PathBuf, + pub repo_path: PathBuf, + pub eval_setup: EvalSetup, + pub user_prompt: String, +} + +#[derive(Debug, Serialize)] +pub struct EvalOutput { + pub diff: String, + pub last_message: String, + pub elapsed_time: Duration, + pub assistant_response_count: usize, + pub tool_use_counts: HashMap, u32>, + pub token_usage: TokenUsage, +} + +#[derive(Deserialize)] +pub struct EvalSetup { + pub url: String, + pub base_sha: String, +} + +impl Eval { + /// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo + /// if necessary. + pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result { + let prompt_path = path.join("prompt.txt"); + let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?; + let setup_path = path.join("setup.json"); + let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?; + let eval_setup = serde_json_lenient::from_str_lenient::(&setup_contents)?; + let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url)); + Ok(Eval { + name, + path, + repo_path, + eval_setup, + user_prompt, + }) + } + + pub fn run( + self, + app_state: Arc, + model: Arc, + cx: &mut App, + ) -> Task> { + cx.spawn(move |mut cx| async move { + checkout_repo(&self.eval_setup, &self.repo_path).await?; + + let (assistant, done_rx) = + cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??; + + let _worktree = assistant + .update(&mut cx, |assistant, cx| { + assistant.project.update(cx, |project, cx| { + project.create_worktree(&self.repo_path, true, cx) + }) + })? + .await?; + + let start_time = std::time::SystemTime::now(); + + assistant.update(&mut cx, |assistant, cx| { + assistant.thread.update(cx, |thread, cx| { + let context = vec![]; + thread.insert_user_message(self.user_prompt.clone(), context, cx); + thread.send_to_model(model, RequestKind::Chat, cx); + }); + })?; + + done_rx.recv().await??; + + let elapsed_time = start_time.elapsed()?; + + let diff = query_git(&self.repo_path, vec!["diff"]).await?; + + assistant.update(&mut cx, |assistant, cx| { + let thread = assistant.thread.read(cx); + let last_message = thread.messages().last().unwrap(); + if last_message.role != language_model::Role::Assistant { + return Err(anyhow!("Last message is not from assistant")); + } + let assistant_response_count = thread + .messages() + .filter(|message| message.role == language_model::Role::Assistant) + .count(); + Ok(EvalOutput { + diff, + last_message: last_message.text.clone(), + elapsed_time, + assistant_response_count, + tool_use_counts: assistant.tool_use_counts.clone(), + token_usage: thread.cumulative_token_usage(), + }) + })? + }) + } +} + +impl EvalOutput { + // Method to save the output to a directory + pub fn save_to_directory( + &self, + output_dir: &Path, + eval_output_value: String, + ) -> anyhow::Result<()> { + // Create the output directory if it doesn't exist + fs::create_dir_all(&output_dir)?; + + // Save the diff to a file + let diff_path = output_dir.join("diff.patch"); + let mut diff_file = fs::File::create(&diff_path)?; + diff_file.write_all(self.diff.as_bytes())?; + + // Save the last message to a file + let message_path = output_dir.join("assistant_response.txt"); + let mut message_file = fs::File::create(&message_path)?; + message_file.write_all(self.last_message.as_bytes())?; + + // Current metrics for this run + let current_metrics = serde_json::json!({ + "elapsed_time_ms": self.elapsed_time.as_millis(), + "assistant_response_count": self.assistant_response_count, + "tool_use_counts": self.tool_use_counts, + "token_usage": self.token_usage, + "eval_output_value": eval_output_value, + }); + + // Get current timestamp in milliseconds + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH)? + .as_millis() + .to_string(); + + // Path to metrics file + let metrics_path = output_dir.join("metrics.json"); + + // Load existing metrics if the file exists, or create a new object + let mut historical_metrics = if metrics_path.exists() { + let metrics_content = fs::read_to_string(&metrics_path)?; + serde_json::from_str::(&metrics_content) + .unwrap_or_else(|_| serde_json::json!({})) + } else { + serde_json::json!({}) + }; + + // Add new run with timestamp as key + if let serde_json::Value::Object(ref mut map) = historical_metrics { + map.insert(timestamp, current_metrics); + } + + // Write updated metrics back to file + let metrics_json = serde_json::to_string_pretty(&historical_metrics)?; + let mut metrics_file = fs::File::create(&metrics_path)?; + metrics_file.write_all(metrics_json.as_bytes())?; + + Ok(()) + } +} + +fn repo_dir_name(url: &str) -> String { + url.trim_start_matches("https://") + .replace(|c: char| !c.is_alphanumeric(), "_") +} + +async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> { + if !repo_path.exists() { + smol::unblock({ + let repo_path = repo_path.to_path_buf(); + || std::fs::create_dir_all(repo_path) + }) + .await?; + run_git(repo_path, vec!["init"]).await?; + run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?; + } else { + let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?; + if actual_origin != eval_setup.url { + return Err(anyhow!( + "remote origin {} does not match expected origin {}", + actual_origin, + eval_setup.url + )); + } + + // TODO: consider including "-x" to remove ignored files. The downside of this is that it will + // also remove build artifacts, and so prevent incremental reuse there. + run_git(repo_path, vec!["clean", "--force", "-d"]).await?; + run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?; + } + + run_git( + repo_path, + vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha], + ) + .await?; + run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?; + + Ok(()) +} + +async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> { + let exit_status = new_smol_command("git") + .current_dir(repo_path) + .args(args.clone()) + .status() + .await?; + if exit_status.success() { + Ok(()) + } else { + Err(anyhow!( + "`git {}` failed with {}", + args.join(" "), + exit_status, + )) + } +} + +async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result { + let output = new_smol_command("git") + .current_dir(repo_path) + .args(args.clone()) + .output() + .await?; + if output.status.success() { + Ok(String::from_utf8(output.stdout)?.trim().to_string()) + } else { + Err(anyhow!( + "`git {}` failed with {}", + args.join(" "), + output.status + )) + } +} diff --git a/crates/assistant_eval/src/headless_assistant.rs b/crates/assistant_eval/src/headless_assistant.rs new file mode 100644 index 0000000000..008b4f63da --- /dev/null +++ b/crates/assistant_eval/src/headless_assistant.rs @@ -0,0 +1,241 @@ +use anyhow::anyhow; +use assistant2::{Thread, ThreadEvent, ThreadStore}; +use assistant_tool::ToolWorkingSet; +use client::{Client, UserStore}; +use collections::HashMap; +use futures::StreamExt; +use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task}; +use language::LanguageRegistry; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, + LanguageModelRequest, +}; +use node_runtime::NodeRuntime; +use project::{Project, RealFs}; +use prompt_store::PromptBuilder; +use settings::SettingsStore; +use smol::channel; +use std::sync::Arc; + +/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields. +pub struct HeadlessAppState { + pub languages: Arc, + pub client: Arc, + pub user_store: Entity, + pub fs: Arc, + pub node_runtime: NodeRuntime, + + // Additional fields not present in `workspace::AppState`. + pub prompt_builder: Arc, +} + +pub struct HeadlessAssistant { + pub thread: Entity, + pub project: Entity, + #[allow(dead_code)] + pub thread_store: Entity, + pub tool_use_counts: HashMap, u32>, + pub done_tx: channel::Sender>, + _subscription: Subscription, +} + +impl HeadlessAssistant { + pub fn new( + app_state: Arc, + cx: &mut App, + ) -> anyhow::Result<(Entity, channel::Receiver>)> { + let env = None; + let project = Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + env, + cx, + ); + + let tools = Arc::new(ToolWorkingSet::default()); + let thread_store = + ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?; + + let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); + + let (done_tx, done_rx) = channel::unbounded::>(); + + let headless_thread = cx.new(move |cx| Self { + _subscription: cx.subscribe(&thread, Self::handle_thread_event), + thread, + project, + thread_store, + tool_use_counts: HashMap::default(), + done_tx, + }); + + Ok((headless_thread, done_rx)) + } + + fn handle_thread_event( + &mut self, + thread: Entity, + event: &ThreadEvent, + cx: &mut Context, + ) { + match event { + ThreadEvent::ShowError(err) => self + .done_tx + .send_blocking(Err(anyhow!("{:?}", err))) + .unwrap(), + ThreadEvent::DoneStreaming => { + let thread = thread.read(cx); + if let Some(message) = thread.messages().last() { + println!("Message: {}", message.text,); + } + if thread.all_tools_finished() { + self.done_tx.send_blocking(Ok(())).unwrap() + } + } + ThreadEvent::UsePendingTools => { + thread.update(cx, |thread, cx| { + thread.use_pending_tools(cx); + }); + } + ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use, + } => { + if let Some(pending_tool_use) = pending_tool_use { + println!( + "Used tool {} with input: {}", + pending_tool_use.name, pending_tool_use.input + ); + *self + .tool_use_counts + .entry(pending_tool_use.name.clone()) + .or_insert(0) += 1; + } + if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) { + println!("Tool result: {:?}", tool_result); + } + if thread.read(cx).all_tools_finished() { + let model_registry = LanguageModelRegistry::read_global(cx); + if let Some(model) = model_registry.active_model() { + thread.update(cx, |thread, cx| { + // Currently evals do not support specifying context. + let updated_context = vec![]; + thread.send_tool_results_to_model(model, updated_context, cx); + }); + } + } + } + ThreadEvent::StreamedCompletion + | ThreadEvent::SummaryChanged + | ThreadEvent::StreamedAssistantText(_, _) + | ThreadEvent::MessageAdded(_) + | ThreadEvent::MessageEdited(_) + | ThreadEvent::MessageDeleted(_) => {} + } + } +} + +pub fn init(cx: &mut App) -> Arc { + release_channel::init(SemanticVersion::default(), cx); + gpui_tokio::init(cx); + + let mut settings_store = SettingsStore::new(cx); + settings_store + .set_default_settings(settings::default_settings().as_ref(), cx) + .unwrap(); + cx.set_global(settings_store); + client::init_settings(cx); + Project::init_settings(cx); + + let client = Client::production(cx); + cx.set_http_client(client.http_client().clone()); + + let git_binary_path = None; + let fs = Arc::new(RealFs::new(git_binary_path)); + + let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + + language::init(cx); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); + assistant_tools::init(cx); + context_server::init(cx); + let stdout_is_a_pty = false; + let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx); + assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx); + + Arc::new(HeadlessAppState { + languages, + client, + user_store, + fs, + node_runtime: NodeRuntime::unavailable(), + prompt_builder, + }) +} + +pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result> { + let model_registry = LanguageModelRegistry::read_global(cx); + let model = model_registry + .available_models(cx) + .find(|model| model.id().0 == model_name); + + let Some(model) = model else { + return Err(anyhow!( + "No language model named {} was available. Available models: {}", + model_name, + model_registry + .available_models(cx) + .map(|model| model.id().0.clone()) + .collect::>() + .join(", ") + )); + }; + + Ok(model) +} + +pub fn authenticate_model_provider( + provider_id: LanguageModelProviderId, + cx: &mut App, +) -> Task> { + let model_registry = LanguageModelRegistry::read_global(cx); + let model_provider = model_registry.provider(&provider_id).unwrap(); + model_provider.authenticate(cx) +} + +pub async fn send_language_model_request( + model: Arc, + request: LanguageModelRequest, + cx: AsyncApp, +) -> anyhow::Result { + match model.stream_completion_text(request, &cx).await { + Ok(mut stream) => { + let mut full_response = String::new(); + + // Process the response stream + while let Some(chunk_result) = stream.stream.next().await { + match chunk_result { + Ok(chunk_str) => { + full_response.push_str(&chunk_str); + } + Err(err) => { + return Err(anyhow!( + "Error receiving response from language model: {err}" + )); + } + } + } + + Ok(full_response) + } + Err(err) => Err(anyhow!( + "Failed to get response from language model. Error was: {err}" + )), + } +} diff --git a/crates/assistant_eval/src/judge.rs b/crates/assistant_eval/src/judge.rs new file mode 100644 index 0000000000..4b3e6f26a4 --- /dev/null +++ b/crates/assistant_eval/src/judge.rs @@ -0,0 +1,121 @@ +use crate::eval::EvalOutput; +use crate::headless_assistant::send_language_model_request; +use anyhow::anyhow; +use gpui::{App, Task}; +use language_model::{ + LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, +}; +use std::{path::Path, sync::Arc}; + +pub struct Judge { + pub original_diff: Option, + #[allow(dead_code)] + pub original_message: Option, + pub model: Arc, +} + +impl Judge { + pub async fn load(eval_path: &Path, model: Arc) -> anyhow::Result { + let original_diff_path = eval_path.join("original.diff"); + let original_diff = smol::unblock(move || { + if std::fs::exists(&original_diff_path)? { + anyhow::Ok(Some(std::fs::read_to_string(&original_diff_path)?)) + } else { + anyhow::Ok(None) + } + }); + + let original_message_path = eval_path.join("original_message.txt"); + let original_message = smol::unblock(move || { + if std::fs::exists(&original_message_path)? { + anyhow::Ok(Some(std::fs::read_to_string(&original_message_path)?)) + } else { + anyhow::Ok(None) + } + }); + + Ok(Self { + original_diff: original_diff.await?, + original_message: original_message.await?, + model, + }) + } + + pub fn run(&self, eval_output: &EvalOutput, cx: &mut App) -> Task> { + let Some(original_diff) = self.original_diff.as_ref() else { + return Task::ready(Err(anyhow!("No original.diff found"))); + }; + + // TODO: check for empty diff? + let prompt = diff_comparison_prompt(&original_diff, &eval_output.diff); + + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(prompt)], + cache: false, + }], + temperature: Some(0.0), + tools: Vec::new(), + stop: Vec::new(), + }; + + let model = self.model.clone(); + cx.spawn(move |cx| send_language_model_request(model, request, cx)) + } +} + +pub fn diff_comparison_prompt(original_diff: &str, new_diff: &str) -> String { + format!( + r#"# Git Diff Similarity Evaluation Template + +## Instructions + +Compare the two diffs and score them between 0.0 and 1.0 based on their functional similarity. +- 1.0 = Perfect functional match (achieves identical results) +- 0.0 = No functional similarity whatsoever + +## Evaluation Criteria + +Please consider the following aspects in order of importance: + +1. **Functional Equivalence (60%)** + - Do both diffs achieve the same end result? + - Are the changes functionally equivalent despite possibly using different approaches? + - Do the modifications address the same issues or implement the same features? + +2. **Logical Structure (20%)** + - Are the logical flows similar? + - Do the modifications affect the same code paths? + - Are control structures (if/else, loops, etc.) modified in similar ways? + +3. **Code Content (15%)** + - Are similar lines added/removed? + - Are the same variables, functions, or methods being modified? + - Are the same APIs or libraries being used? + +4. **File Layout (5%)** + - Are the same files being modified? + - Are changes occurring in similar locations within files? + +## Input + +Original Diff: +```git +{} +``` + +New Diff: +```git +{} +``` + +## Output Format + +THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0. + +Example output: +0.85"#, + original_diff, new_diff + ) +} diff --git a/crates/assistant_eval/src/main.rs b/crates/assistant_eval/src/main.rs new file mode 100644 index 0000000000..316aaf04ec --- /dev/null +++ b/crates/assistant_eval/src/main.rs @@ -0,0 +1,234 @@ +mod eval; +mod headless_assistant; +mod judge; + +use clap::Parser; +use eval::{Eval, EvalOutput}; +use futures::{stream, StreamExt}; +use gpui::{Application, AsyncApp}; +use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState}; +use itertools::Itertools; +use judge::Judge; +use language_model::{LanguageModel, LanguageModelRegistry}; +use regex::Regex; +use reqwest_client::ReqwestClient; +use std::{cmp, path::PathBuf, sync::Arc}; + +#[derive(Parser, Debug)] +#[command( + name = "assistant_eval", + disable_version_flag = true, + before_help = "Tool eval runner" +)] +struct Args { + /// Regexes to match the names of evals to run. + eval_name_regexes: Vec, + /// Runs all evals in `evaluation_data`, causes the regex to be ignored. + #[arg(long)] + all: bool, + /// Name of the model (default: "claude-3-7-sonnet-latest") + #[arg(long, default_value = "claude-3-7-sonnet-latest")] + model_name: String, + /// Name of the editor model (default: value of `--model_name`). + #[arg(long)] + editor_model_name: Option, + /// Name of the judge model (default: value of `--model_name`). + #[arg(long)] + judge_model_name: Option, + /// Number of evaluations to run concurrently (default: 10) + #[arg(short, long, default_value = "10")] + concurrency: usize, +} + +fn main() { + env_logger::init(); + let args = Args::parse(); + let http_client = Arc::new(ReqwestClient::new()); + let app = Application::headless().with_http_client(http_client.clone()); + + let crate_dir = PathBuf::from("../zed-agent-bench"); + let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap(); + let repos_dir = crate_dir.join("repos").canonicalize().unwrap(); + + let all_evals = std::fs::read_dir(&evaluation_data_dir) + .unwrap() + .map(|path| path.unwrap().file_name().to_string_lossy().to_string()) + .collect::>(); + + let evals_to_run = if args.all { + all_evals + } else { + args.eval_name_regexes + .into_iter() + .map(|regex_string| Regex::new(®ex_string).unwrap()) + .flat_map(|regex| { + all_evals + .iter() + .filter(|eval_name| regex.is_match(eval_name)) + .cloned() + .collect::>() + }) + .collect::>() + }; + + if evals_to_run.is_empty() { + panic!("Names of evals to run must be provided or `--all` specified"); + } + + println!("Will run the following evals: {evals_to_run:?}"); + println!("Running up to {} evals concurrently", args.concurrency); + + let editor_model_name = if let Some(model_name) = args.editor_model_name { + model_name + } else { + args.model_name.clone() + }; + + let judge_model_name = if let Some(model_name) = args.judge_model_name { + model_name + } else { + args.model_name.clone() + }; + + app.run(move |cx| { + let app_state = headless_assistant::init(cx); + + let model = find_model(&args.model_name, cx).unwrap(); + let editor_model = find_model(&editor_model_name, cx).unwrap(); + let judge_model = find_model(&judge_model_name, cx).unwrap(); + + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_active_model(Some(model.clone()), cx); + registry.set_editor_model(Some(editor_model.clone()), cx); + }); + + let model_provider_id = model.provider_id(); + let editor_model_provider_id = editor_model.provider_id(); + let judge_model_provider_id = judge_model.provider_id(); + + cx.spawn(move |cx| async move { + // Authenticate all model providers first + cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx)) + .unwrap() + .await + .unwrap(); + cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx)) + .unwrap() + .await + .unwrap(); + cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx)) + .unwrap() + .await + .unwrap(); + + let loaded_evals = stream::iter(evals_to_run) + .map(|eval_name| { + let eval_path = evaluation_data_dir.join(&eval_name); + let repos_dir = repos_dir.clone(); + async move { + match Eval::load(eval_name.clone(), eval_path, &repos_dir).await { + Ok(eval) => Some(eval), + Err(err) => { + // TODO: Persist errors / surface errors at the end. + println!("Error loading {eval_name}: {err}"); + None + } + } + } + }) + .buffer_unordered(args.concurrency) + .collect::>() + .await + .into_iter() + .flatten() + .collect::>(); + + // The evals need to be loaded and grouped by URL before concurrently running, since + // evals that use the same remote URL will use the same working directory. + let mut evals_grouped_by_url: Vec> = loaded_evals + .into_iter() + .map(|eval| (eval.eval_setup.url.clone(), eval)) + .into_group_map() + .into_values() + .collect::>(); + + // Sort groups in descending order, so that bigger groups start first. + evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len())); + + let results = stream::iter(evals_grouped_by_url) + .map(|evals| { + let model = model.clone(); + let judge_model = judge_model.clone(); + let app_state = app_state.clone(); + let cx = cx.clone(); + + async move { + let mut results = Vec::new(); + for eval in evals { + let name = eval.name.clone(); + println!("Starting eval named {}", name); + let result = run_eval( + eval, + model.clone(), + judge_model.clone(), + app_state.clone(), + cx.clone(), + ) + .await; + results.push((name, result)); + } + results + } + }) + .buffer_unordered(args.concurrency) + .collect::>() + .await + .into_iter() + .flatten() + .collect::>(); + + // Process results in order of completion + for (eval_name, result) in results { + match result { + Ok((eval_output, judge_output)) => { + println!("Generated diff for {eval_name}:\n"); + println!("{}\n", eval_output.diff); + println!("Last message for {eval_name}:\n"); + println!("{}\n", eval_output.last_message); + println!("Elapsed time: {:?}", eval_output.elapsed_time); + println!( + "Assistant response count: {}", + eval_output.assistant_response_count + ); + println!("Tool use counts: {:?}", eval_output.tool_use_counts); + println!("Judge output for {eval_name}: {judge_output}"); + } + Err(err) => { + // TODO: Persist errors / surface errors at the end. + println!("Error running {eval_name}: {err}"); + } + } + } + + cx.update(|cx| cx.quit()).unwrap(); + }) + .detach(); + }); + + println!("Done running evals"); +} + +async fn run_eval( + eval: Eval, + model: Arc, + judge_model: Arc, + app_state: Arc, + cx: AsyncApp, +) -> anyhow::Result<(EvalOutput, String)> { + let path = eval.path.clone(); + let judge = Judge::load(&path, judge_model).await?; + let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?; + let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?; + eval_output.save_to_directory(&path, judge_output.to_string())?; + Ok((eval_output, judge_output)) +}