From 6b80eb556c3fcbcb52fe669be00a086cf9c9ff26 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Mon, 14 Apr 2025 14:18:47 -0600 Subject: [PATCH] Add judge to new eval + provide LSP diagnostics (#28713) Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra Co-authored-by: agus --- Cargo.lock | 11 + crates/agent/src/thread.rs | 16 +- crates/eval/.gitignore | 3 + crates/eval/Cargo.toml | 13 +- .../find_and_replace_diff_card/base.toml | 3 +- .../find_and_replace_diff_card/criteria.md | 2 + .../find_and_replace_diff_card/rubric.md | 0 crates/eval/src/eval.rs | 247 ++++++- crates/eval/src/example.rs | 600 ++++++++++++++++-- crates/eval/src/judge_prompt.hbs | 25 + crates/language_model/src/language_model.rs | 2 +- 11 files changed, 838 insertions(+), 84 deletions(-) create mode 100644 crates/eval/.gitignore create mode 100644 crates/eval/examples/find_and_replace_diff_card/criteria.md delete mode 100644 crates/eval/examples/find_and_replace_diff_card/rubric.md create mode 100644 crates/eval/src/judge_prompt.hbs diff --git a/Cargo.lock b/Cargo.lock index 911ea64417..73761469b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4878,25 +4878,36 @@ dependencies = [ "assistant_settings", "assistant_tool", "assistant_tools", + "async-watch", + "chrono", + "clap", "client", "context_server", "dap", "env_logger 0.11.8", + "extension", "fs", "futures 0.3.31", "gpui", "gpui_tokio", + "handlebars 4.5.0", "language", + "language_extension", "language_model", "language_models", + "languages", "node_runtime", + "paths", "project", "prompt_store", "release_channel", "reqwest_client", "serde", "settings", + "shellexpand 2.1.2", "toml 0.8.20", + "unindent", + "util", "workspace-hack", ] diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 1cab03a46f..fe4844bd86 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -827,7 +827,7 @@ impl Thread { }) .collect(), initial_project_snapshot, - cumulative_token_usage: this.cumulative_token_usage.clone(), + cumulative_token_usage: this.cumulative_token_usage, detailed_summary_state: this.detailed_summary_state.clone(), exceeded_window_error: this.exceeded_window_error.clone(), }) @@ -1016,7 +1016,7 @@ impl Thread { let task = cx.spawn(async move |thread, cx| { let stream = model.stream_completion(request, &cx); let initial_token_usage = - thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone()); + thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); let stream_completion = async { let mut events = stream.await?; let mut stop_reason = StopReason::EndTurn; @@ -1038,9 +1038,9 @@ impl Thread { stop_reason = reason; } LanguageModelCompletionEvent::UsageUpdate(token_usage) => { - thread.cumulative_token_usage = - thread.cumulative_token_usage.clone() + token_usage.clone() - - current_token_usage.clone(); + thread.cumulative_token_usage = thread.cumulative_token_usage + + token_usage + - current_token_usage; current_token_usage = token_usage; } LanguageModelCompletionEvent::Text(chunk) => { @@ -1183,7 +1183,7 @@ impl Thread { thread.auto_capture_telemetry(cx); if let Ok(initial_usage) = initial_token_usage { - let usage = thread.cumulative_token_usage.clone() - initial_usage; + let usage = thread.cumulative_token_usage - initial_usage; telemetry::event!( "Assistant Thread Completion", @@ -1862,6 +1862,10 @@ impl Thread { .detach(); } + pub fn cumulative_token_usage(&self) -> TokenUsage { + self.cumulative_token_usage + } + pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage { let model_registry = LanguageModelRegistry::read_global(cx); let Some(model) = model_registry.default_model() else { diff --git a/crates/eval/.gitignore b/crates/eval/.gitignore new file mode 100644 index 0000000000..89fb02c122 --- /dev/null +++ b/crates/eval/.gitignore @@ -0,0 +1,3 @@ +repos/ +worktrees/ +runs/ diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index 0249c24dcf..6828de36fc 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -7,28 +7,39 @@ edition.workspace = true [dependencies] agent.workspace = true anyhow.workspace = true +async-watch.workspace = true +assistant_settings.workspace = true assistant_tool.workspace = true assistant_tools.workspace = true -assistant_settings.workspace = true +chrono.workspace = true +clap.workspace = true client.workspace = true context_server.workspace = true dap.workspace = true env_logger.workspace = true +extension.workspace = true fs.workspace = true futures.workspace = true gpui.workspace = true gpui_tokio.workspace = true +handlebars.workspace = true language.workspace = true +language_extension.workspace = true language_model.workspace = true language_models.workspace = true +languages.workspace = true node_runtime.workspace = true +paths.workspace = true project.workspace = true prompt_store.workspace = true release_channel.workspace = true reqwest_client.workspace = true serde.workspace = true settings.workspace = true +shellexpand.workspace = true toml.workspace = true +unindent.workspace = true +util.workspace = true workspace-hack.workspace = true [[bin]] diff --git a/crates/eval/examples/find_and_replace_diff_card/base.toml b/crates/eval/examples/find_and_replace_diff_card/base.toml index 2b14a64530..c88298997d 100644 --- a/crates/eval/examples/find_and_replace_diff_card/base.toml +++ b/crates/eval/examples/find_and_replace_diff_card/base.toml @@ -1,2 +1,3 @@ -path = "../zed_worktree" +url = "https://github.com/zed-industries/zed.git" revision = "38fcadf9481d018543c65f36ac3bafeba190179b" +language_extension = "rs" diff --git a/crates/eval/examples/find_and_replace_diff_card/criteria.md b/crates/eval/examples/find_and_replace_diff_card/criteria.md new file mode 100644 index 0000000000..393056f134 --- /dev/null +++ b/crates/eval/examples/find_and_replace_diff_card/criteria.md @@ -0,0 +1,2 @@ +1. The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct. The struct should contain an `output` field that is the same as the string we were returning before, and a new `card` field that contains a view for the card +2. The card should be a view that displays a diff. Each line in the diff should be colored according to whether it was added, removed or unchanged. diff --git a/crates/eval/examples/find_and_replace_diff_card/rubric.md b/crates/eval/examples/find_and_replace_diff_card/rubric.md deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 88cca63852..cfdb00b655 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -1,32 +1,75 @@ mod example; use assistant_settings::AssistantSettings; -use client::{Client, UserStore}; +use client::{Client, ProxySettings, UserStore}; pub(crate) use example::*; use ::fs::RealFs; -use anyhow::anyhow; -use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task}; +use anyhow::{Result, anyhow}; +use clap::Parser; +use extension::ExtensionHostProxy; +use futures::future; +use gpui::http_client::{Uri, read_proxy_from_env}; +use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task}; +use gpui_tokio::Tokio; use language::LanguageRegistry; use language_model::{ AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry, }; -use node_runtime::NodeRuntime; +use node_runtime::{NodeBinaryOptions, NodeRuntime}; use project::Project; +use project::project_settings::ProjectSettings; use prompt_store::PromptBuilder; +use release_channel::AppVersion; use reqwest_client::ReqwestClient; use settings::{Settings, SettingsStore}; +use std::collections::HashSet; +use std::path::{Path, PathBuf}; use std::sync::Arc; +use util::ResultExt as _; + +pub const RUNS_DIR: &str = "./crates/eval/runs"; + +#[derive(Parser, Debug)] +#[command(name = "eval", disable_version_flag = true)] +struct Args { + /// Runs all examples that contain these substrings. If unspecified, all examples are run. + #[arg(value_name = "EXAMPLE_SUBSTRING")] + examples: Vec, + /// Model to use (default: "claude-3-7-sonnet-latest") + #[arg(long, default_value = "claude-3-7-sonnet-latest")] + model: String, +} fn main() { env_logger::init(); + + let args = Args::parse(); + let all_available_examples = list_all_examples().unwrap(); + let example_paths = all_available_examples + .iter() + .filter_map(|example_path| { + let name = example_path.file_name()?.to_string_lossy(); + if args.examples.is_empty() + || args + .examples + .iter() + .any(|name_substring| name.contains(name_substring)) + { + Some(example_path.clone()) + } else { + None + } + }) + .collect::>(); + let http_client = Arc::new(ReqwestClient::new()); let app = Application::headless().with_http_client(http_client.clone()); app.run(move |cx| { let app_state = init(cx); - let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap(); + let model = find_model("claude-3-7-sonnet-latest", cx).unwrap(); LanguageModelRegistry::global(cx).update(cx, |registry, cx| { registry.set_default_model(Some(model.clone()), cx); @@ -39,17 +82,142 @@ fn main() { cx.spawn(async move |cx| { authenticate.await.unwrap(); - let example = - Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?; - example.setup()?; - cx.update(|cx| example.run(model, app_state, cx))?.await?; + std::fs::create_dir_all(REPOS_DIR)?; + std::fs::create_dir_all(WORKTREES_DIR)?; - anyhow::Ok(()) + let run_dir = Path::new(RUNS_DIR).join(format!( + "{}", + chrono::Local::now().format("%Y-%m-%d_%H-%M-%S") + )); + std::fs::create_dir_all(&run_dir)?; + + let mut examples = Vec::new(); + for example_path in example_paths { + let example = Example::load_from_directory(&example_path, &run_dir)?; + examples.push((example_path, example)); + } + let mut repo_urls = HashSet::new(); + + let mut clone_tasks = Vec::new(); + + for (_, example) in examples.iter() { + let repo_url = example.base.url.clone(); + if repo_urls.insert(repo_url.clone()) { + let repo_path = repo_path_for_url(&repo_url); + + if !repo_path.join(".git").is_dir() { + println!("Cloning: {}", repo_url); + + let git_task = cx.spawn(async move |_cx| { + std::fs::create_dir_all(&repo_path)?; + run_git(&repo_path, &["init"]).await?; + run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await + }); + + clone_tasks.push(git_task); + } else { + println!("Already cloned: {}", repo_url); + + let actual_origin = + run_git(&repo_path, &["remote", "get-url", "origin"]).await?; + if actual_origin != repo_url { + return Err(anyhow!( + "remote origin {} does not match expected origin {}", + actual_origin, + repo_url, + )); + } + } + } + } + + future::join_all(clone_tasks).await; + + let tasks = examples + .into_iter() + .map(|(example_path, example)| { + let app_state = app_state.clone(); + let model = model.clone(); + cx.spawn(async move |cx| { + ( + example_path, + run_example(example, model, app_state, cx).await, + ) + }) + }) + .collect::>(); + + let results: Vec<(PathBuf, Result)> = future::join_all(tasks).await; + + println!("\n\n"); + println!("========================================"); + println!(" EVAL RESULTS "); + println!("========================================"); + println!(""); + + let mut judge_scores = Vec::new(); + + for (example_path, result) in results { + let example_name = example_path.file_name().unwrap().to_string_lossy(); + match result { + Err(err) => { + println!("💥 {:<30}: {:?}", example_name, err); + } + Ok(judge_output) => { + const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"]; + + println!( + "{} {:<30}: {}", + SCORES[judge_output.score.min(5) as usize], + example_name, + judge_output.score, + ); + judge_scores.push(judge_output.score); + } + } + } + + let score_count = judge_scores.len(); + let average_score = judge_scores + .into_iter() + .map(|score| score as f32) + .sum::() + / (score_count as f32); + println!("\nAverage score: {average_score}"); + + cx.update(|cx| cx.quit()) }) .detach_and_log_err(cx); }); } +async fn run_example( + mut example: Example, + model: Arc, + app_state: Arc, + cx: &mut AsyncApp, +) -> Result { + example.setup().await?; + cx.update(|cx| example.run(model.clone(), app_state, cx))? + .await?; + let diff = example.repository_diff().await?; + example.judge(model, diff, cx).await +} + +fn list_all_examples() -> Result> { + let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap(); + let entries = std::fs::read_dir(path).unwrap(); + let mut result_paths = Vec::new(); + for entry in entries { + let entry = entry?; + let path = entry.path(); + if path.is_dir() { + result_paths.push(path); + } + } + Ok(result_paths) +} + /// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields. pub struct AgentAppState { pub languages: Arc, @@ -72,6 +240,27 @@ pub fn init(cx: &mut App) -> Arc { .unwrap(); cx.set_global(settings_store); client::init_settings(cx); + + // Set User-Agent so we can download language servers from GitHub + let user_agent = format!( + "Zed/{} ({}; {})", + AppVersion::global(cx), + std::env::consts::OS, + std::env::consts::ARCH + ); + let proxy_str = ProxySettings::get_global(cx).proxy.to_owned(); + let proxy_url = proxy_str + .as_ref() + .and_then(|input| input.parse::().ok()) + .or_else(read_proxy_from_env); + let http = { + let _guard = Tokio::handle(cx).enter(); + + ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent) + .expect("could not start HTTP client") + }; + cx.set_http_client(Arc::new(http)); + Project::init_settings(cx); let client = Client::production(cx); @@ -83,13 +272,47 @@ pub fn init(cx: &mut App) -> Arc { cx.background_executor().clone(), )); - let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone())); + let mut languages = LanguageRegistry::new(cx.background_executor().clone()); + languages.set_language_server_download_dir(paths::languages_dir().clone()); + let languages = Arc::new(languages); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + extension::init(cx); + + let (tx, rx) = async_watch::channel(None); + cx.observe_global::(move |cx| { + let settings = &ProjectSettings::get_global(cx).node; + let options = NodeBinaryOptions { + allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(), + allow_binary_download: true, + use_paths: settings.path.as_ref().map(|node_path| { + let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref()); + let npm_path = settings + .npm_path + .as_ref() + .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref())); + ( + node_path.clone(), + npm_path.unwrap_or_else(|| { + let base_path = PathBuf::new(); + node_path.parent().unwrap_or(&base_path).join("npm") + }), + ) + }), + }; + tx.send(Some(options)).log_err(); + }) + .detach(); + let node_runtime = NodeRuntime::new(client.http_client().clone(), rx); + + let extension_host_proxy = ExtensionHostProxy::global(cx); + language::init(cx); + language_extension::init(extension_host_proxy.clone(), languages.clone()); language_model::init(client.clone(), cx); language_models::init(user_store.clone(), client.clone(), fs.clone(), cx); + languages::init(languages.clone(), node_runtime.clone(), cx); assistant_tools::init(client.http_client().clone(), cx); context_server::init(cx); let stdout_is_a_pty = false; @@ -109,7 +332,7 @@ pub fn init(cx: &mut App) -> Arc { client, user_store, fs, - node_runtime: NodeRuntime::unavailable(), + node_runtime, prompt_builder, }) } diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 6b45eddb60..c1ffaa51fe 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -1,83 +1,161 @@ use agent::{RequestKind, ThreadEvent, ThreadStore}; -use anyhow::{Result, anyhow}; +use anyhow::{Context as _, Result, anyhow}; use assistant_tool::ToolWorkingSet; +use client::proto::LspWorkProgress; use dap::DapRegistry; -use futures::channel::oneshot; -use gpui::{App, Task}; -use language_model::{LanguageModel, StopReason}; -use project::Project; -use serde::Deserialize; -use std::process::Command; -use std::sync::Arc; +use futures::channel::{mpsc, oneshot}; +use futures::{FutureExt, StreamExt as _}; +use gpui::{App, AsyncApp, Entity, Task}; +use handlebars::Handlebars; +use language::{DiagnosticSeverity, OffsetRangeExt}; +use language_model::{ + LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, + StopReason, TokenUsage, +}; +use project::{LspStore, Project, ProjectPath}; +use serde::{Deserialize, Serialize}; +use std::fmt::Write as _; +use std::fs::File; +use std::io::Write as _; +use std::sync::{Arc, Mutex}; +use std::time::Duration; use std::{ fs, path::{Path, PathBuf}, }; +use unindent::Unindent as _; +use util::ResultExt as _; +use util::command::new_smol_command; +use util::serde::default_true; use crate::AgentAppState; -#[derive(Debug, Deserialize)] +pub const EXAMPLES_DIR: &str = "./crates/eval/examples"; +pub const REPOS_DIR: &str = "./crates/eval/repos"; +pub const WORKTREES_DIR: &str = "./crates/eval/worktrees"; + +#[derive(Clone, Debug, Deserialize)] pub struct ExampleBase { - pub path: PathBuf, + pub url: String, pub revision: String, + pub language_extension: Option, + pub insert_id: Option, + #[serde(default = "default_true")] + pub require_lsp: bool, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Example { + pub name: String, + /// Content of `base.toml` pub base: ExampleBase, - - /// Content of the prompt.md file + /// Content of `prompt.md` pub prompt: String, + /// Content of `criteria.md` + pub criteria: String, + /// Markdown log file to append to + pub log_file: Arc>, +} - /// Content of the rubric.md file - pub _rubric: String, +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct RunOutput { + pub repository_diff: String, + pub diagnostics: String, + pub response_count: usize, + pub token_usage: TokenUsage, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeInput { + pub repository_diff: String, + pub criteria: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeOutput { + pub analysis: String, + pub score: u32, } impl Example { - /// Load an example from a directory containing base.toml, prompt.md, and rubric.md - pub fn load_from_directory>(dir_path: P) -> Result { - let base_path = dir_path.as_ref().join("base.toml"); - let prompt_path = dir_path.as_ref().join("prompt.md"); - let rubric_path = dir_path.as_ref().join("rubric.md"); + /// Load an example from a directory containing base.toml, prompt.md, and criteria.md + pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result { + let name = dir_path.file_name().unwrap().to_string_lossy().to_string(); + let base_path = dir_path.join("base.toml"); + let prompt_path = dir_path.join("prompt.md"); + let criteria_path = dir_path.join("criteria.md"); - let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?; - base.path = base.path.canonicalize()?; + let log_file_path = run_dir.join(format!( + "{}.md", + dir_path.file_name().unwrap().to_str().unwrap() + )); + let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap())); + println!("{}> Logging to {:?}", name, log_file_path); Ok(Example { - base, - prompt: fs::read_to_string(prompt_path)?, - _rubric: fs::read_to_string(rubric_path)?, + name, + base: toml::from_str(&fs::read_to_string(&base_path)?)?, + prompt: fs::read_to_string(prompt_path.clone())?, + criteria: fs::read_to_string(criteria_path.clone())?, + log_file, }) } - /// Set up the example by checking out the specified Git revision - pub fn setup(&self) -> Result<()> { - // Check if the directory exists - let path = Path::new(&self.base.path); - anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path); + pub fn worktree_path(&self) -> PathBuf { + Path::new(WORKTREES_DIR) + .canonicalize() + .context(format!("No such directory {WORKTREES_DIR}")) + .unwrap() + .join(&self.name) + } - // Change to the project directory and checkout the specified revision - let output = Command::new("git") - .current_dir(&self.base.path) - .arg("checkout") - .arg(&self.base.revision) - .output()?; - anyhow::ensure!( - output.status.success(), - "Failed to checkout revision {}: {}", - self.base.revision, - String::from_utf8_lossy(&output.stderr), - ); + /// Set up the example by checking out the specified Git revision + pub async fn setup(&self) -> Result<()> { + let repo_path = repo_path_for_url(&self.base.url); + + run_git( + &repo_path, + &["fetch", "--depth", "1", "origin", &self.base.revision], + ) + .await?; + + let worktree_path = self.worktree_path(); + + if worktree_path.is_dir() { + println!("{}> Resetting existing worktree", self.name); + + // TODO: consider including "-x" to remove ignored files. The downside of this is that + // it will also remove build artifacts, and so prevent incremental reuse there. + run_git(&worktree_path, &["clean", "--force", "-d"]).await?; + run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; + run_git(&worktree_path, &["checkout", &self.base.revision]).await?; + } else { + println!("{}> Creating worktree", self.name); + + let worktree_path_string = worktree_path.to_string_lossy().to_string(); + + run_git( + &repo_path, + &[ + "worktree", + "add", + "-f", + &worktree_path_string, + &self.base.revision, + ], + ) + .await?; + } Ok(()) } pub fn run( - self, + &self, model: Arc, app_state: Arc, cx: &mut App, - ) -> Task> { + ) -> Task> { let project = Project::local( app_state.client.clone(), app_state.node_runtime.clone(), @@ -89,30 +167,119 @@ impl Example { cx, ); + let worktree_path = self.worktree_path(); let worktree = project.update(cx, |project, cx| { - project.create_worktree(self.base.path, true, cx) + project.create_worktree(&worktree_path, true, cx) }); let tools = Arc::new(ToolWorkingSet::default()); let thread_store = ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx); + let this = self.clone(); - println!("USER:"); - println!("{}", self.prompt); - println!("ASSISTANT:"); cx.spawn(async move |cx| { - worktree.await?; + let worktree = worktree.await?; + + // Wait for worktree scan to finish before choosing a file to open. + worktree + .update(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + })? + .await; + + let lsp_open_handle_and_store = if this.base.require_lsp { + let language_extension = this.base.language_extension.as_deref().context( + "language_extension field is required in base.toml when `require_lsp == true`", + )?; + + // Open a file that matches the language to cause LSP to start. + let language_file = worktree.read_with(cx, |worktree, _cx| { + worktree + .files(false, 0) + .find_map(|e| { + if e.path.clone().extension().and_then(|ext| ext.to_str()) + == Some(language_extension) + { + Some(ProjectPath { + worktree_id: worktree.id(), + path: e.path.clone(), + }) + } else { + None + } + }) + .context("Failed to find a file for example language") + })??; + + let open_language_file_buffer_task = project.update(cx, |project, cx| { + project.open_buffer(language_file.clone(), cx) + })?; + + let language_file_buffer = open_language_file_buffer_task.await?; + + let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| { + ( + project.register_buffer_with_language_servers(&language_file_buffer, cx), + project.lsp_store().clone(), + ) + })?; + + // TODO: remove this once the diagnostics tool waits for new diagnostics + cx.background_executor().timer(Duration::new(5, 0)).await; + wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?; + + lsp_store.update(cx, |lsp_store, cx| { + lsp_open_handle.update(cx, |buffer, cx| { + buffer.update(cx, |buffer, cx| { + let has_language_server = lsp_store + .language_servers_for_local_buffer(buffer, cx) + .next() + .is_some(); + if has_language_server { + Ok(()) + } else { + Err(anyhow!( + "`{:?}` was opened to cause the language server to start, \ + but no language servers are registered for its buffer. \ + Set `require_lsp = false` in `base.toml` to skip this.", + language_file + )) + } + }) + }) + })??; + + Some((lsp_open_handle, lsp_store)) + } else { + None + }; + + if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() { + return Err(anyhow!("Setup only mode")); + } + let thread_store = thread_store.await; let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; + { + let mut log_file = this.log_file.lock().unwrap(); + writeln!(&mut log_file, "👤 USER:").log_err(); + writeln!(&mut log_file, "{}", this.prompt).log_err(); + writeln!(&mut log_file, "🤖 ASSISTANT:").log_err(); + log_file.flush().log_err(); + } + let (tx, rx) = oneshot::channel(); let mut tx = Some(tx); - let _subscription = - cx.subscribe( - &thread, - move |thread, event: &ThreadEvent, cx| match event { + let _subscription = cx.subscribe(&thread, { + let log_file = this.log_file.clone(); + let name = this.name.clone(); + move |thread, event: &ThreadEvent, cx| { + let mut log_file = log_file.lock().unwrap(); + + match event { ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { if let Some(tx) = tx.take() { @@ -137,15 +304,16 @@ impl Example { } } ThreadEvent::StreamedAssistantText(_, chunk) => { - print!("{}", chunk); + write!(&mut log_file, "{}", chunk).log_err(); } ThreadEvent::StreamedAssistantThinking(_, chunk) => { - print!("{}", chunk); + write!(&mut log_file, "{}", chunk).log_err(); } ThreadEvent::UsePendingTools { tool_uses } => { - println!("\n\nUSING TOOLS:"); + writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err(); for tool_use in tool_uses { - println!("{}: {}", tool_use.name, tool_use.input); + writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input) + .log_err(); } } ThreadEvent::ToolFinished { @@ -154,25 +322,331 @@ impl Example { .. } => { if let Some(tool_use) = pending_tool_use { - println!("\nTOOL FINISHED: {}", tool_use.name); + let message = format!("TOOL FINISHED: {}", tool_use.name); + println!("{name}> {message}"); + writeln!(&mut log_file, "\n{}", message).log_err(); } if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) { - println!("\n{}\n", tool_result.content); + let message = format!("\n{}\n", tool_result.content); + writeln!(&mut log_file, "{}", message).log_err(); } } _ => {} - }, - )?; + } + + log_file.flush().log_err(); + } + })?; thread.update(cx, |thread, cx| { let context = vec![]; - thread.insert_user_message(self.prompt.clone(), context, None, cx); + thread.insert_user_message(this.prompt.clone(), context, None, cx); thread.send_to_model(model, RequestKind::Chat, cx); })?; rx.await??; - Ok(()) + if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() { + wait_for_lang_server(lsp_store, this.name.clone(), cx).await?; + } + + let repository_diff = this.repository_diff().await?; + let diagnostics = cx + .update(move |cx| { + cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await) + })? + .await?; + + drop(lsp_open_handle_and_store); + + thread.update(cx, |thread, _cx| { + let response_count = thread + .messages() + .filter(|message| message.role == language_model::Role::Assistant) + .count(); + RunOutput { + repository_diff, + diagnostics, + response_count, + token_usage: thread.cumulative_token_usage(), + } + }) }) } + + pub async fn judge( + &mut self, + model: Arc, + repository_diff: String, + cx: &AsyncApp, + ) -> Result { + let judge_prompt = include_str!("judge_prompt.hbs"); + let judge_prompt_name = "judge_prompt"; + let mut handlebars = Handlebars::new(); + handlebars.register_template_string(judge_prompt_name, judge_prompt)?; + let prompt = handlebars.render( + judge_prompt_name, + &JudgeInput { + repository_diff, + criteria: self.criteria.clone(), + }, + )?; + + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(prompt)], + cache: false, + }], + temperature: None, + tools: Vec::new(), + stop: Vec::new(), + }; + + let response = send_language_model_request(model, request, cx).await?; + + let mut log_file = self.log_file.lock().unwrap(); + + writeln!(&mut log_file, "\n\n").log_err(); + writeln!(&mut log_file, "========================================").log_err(); + writeln!(&mut log_file, " JUDGE OUTPUT ").log_err(); + writeln!(&mut log_file, "========================================").log_err(); + writeln!(&mut log_file, "\n{}", &response).log_err(); + + parse_judge_output(&response) + } + + pub async fn repository_diff(&self) -> Result { + let worktree_path = self.worktree_path(); + run_git(&worktree_path, &["add", "-N"]).await?; + run_git(&worktree_path, &["diff"]).await + } +} + +fn wait_for_lang_server( + lsp_store: &Entity, + name: String, + cx: &mut AsyncApp, +) -> Task> { + if cx + .update(|cx| !has_pending_lang_server_work(lsp_store, cx)) + .unwrap() + || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok() + { + return Task::ready(anyhow::Ok(())); + } + + println!("{}> ⏵ Waiting for language server", name); + + let (mut tx, mut rx) = mpsc::channel(1); + + let subscription = + cx.subscribe(&lsp_store, { + let name = name.clone(); + move |lsp_store, event, cx| { + match event { + project::LspStoreEvent::LanguageServerUpdate { + message: + client::proto::update_language_server::Variant::WorkProgress( + LspWorkProgress { + message: Some(message), + .. + }, + ), + .. + } => println!("{name}> ⟲ {message}"), + _ => {} + } + + if !has_pending_lang_server_work(&lsp_store, cx) { + tx.try_send(()).ok(); + } + } + }); + + cx.spawn(async move |cx| { + let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0)); + let result = futures::select! { + _ = rx.next() => { + println!("{}> ⚑ Language server idle", name); + anyhow::Ok(()) + }, + _ = timeout.fuse() => { + Err(anyhow!("LSP wait timed out after 5 minutes")) + } + }; + drop(subscription); + result + }) +} + +fn has_pending_lang_server_work(lsp_store: &Entity, cx: &App) -> bool { + lsp_store + .read(cx) + .language_server_statuses() + .any(|(_, status)| !status.pending_work.is_empty()) +} + +async fn query_lsp_diagnostics(project: Entity, cx: &mut AsyncApp) -> Result { + let paths_with_diagnostics = project.update(cx, |project, cx| { + project + .diagnostic_summaries(true, cx) + .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0) + .map(|(project_path, _, _)| project_path) + .collect::>() + })?; + + let mut output = String::new(); + for project_path in paths_with_diagnostics { + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + for (_, group) in snapshot.diagnostic_groups(None) { + let entry = &group.entries[group.primary_ix]; + let range = entry.range.to_point(&snapshot); + let severity = match entry.diagnostic.severity { + DiagnosticSeverity::ERROR => "error", + DiagnosticSeverity::WARNING => "warning", + _ => continue, + }; + + writeln!( + output, + "{} at line {}: {}", + severity, + range.start.row + 1, + entry.diagnostic.message + )?; + } + } + anyhow::Ok(output) +} + +fn parse_judge_output(response: &str) -> Result { + let analysis = get_tag("analysis", response)?.to_string(); + let score = get_tag("score", response)? + .parse() + .context("error parsing score")?; + + Ok(JudgeOutput { analysis, score }) +} + +fn get_tag(name: &'static str, response: &str) -> Result { + let start_tag = format!("<{}>", name); + let end_tag = format!("", name); + + let start_ix = response + .find(&start_tag) + .context(format!("{} start tag not found", name))?; + let content_start_ix = start_ix + start_tag.len(); + + let end_ix = content_start_ix + + response[content_start_ix..] + .find(&end_tag) + .context(format!("{} end tag not found", name))?; + + let content = response[content_start_ix..end_ix].trim().unindent(); + + anyhow::Ok(content) +} + +pub fn repo_path_for_url(repo_url: &str) -> PathBuf { + let repo_name = repo_url + .trim_start_matches("https://") + .replace(|c: char| !c.is_alphanumeric(), "-"); + Path::new(REPOS_DIR) + .canonicalize() + .context(format!("No such directory {REPOS_DIR}")) + .unwrap() + .join(repo_name) +} + +pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result { + let output = new_smol_command("git") + .current_dir(repo_path) + .args(args) + .output() + .await?; + + if output.status.success() { + Ok(String::from_utf8(output.stdout)?.trim().to_string()) + } else { + Err(anyhow!( + "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}", + args.join(" "), + repo_path.display(), + output.status, + String::from_utf8_lossy(&output.stderr), + String::from_utf8_lossy(&output.stdout), + )) + } +} + +pub async fn send_language_model_request( + model: Arc, + request: LanguageModelRequest, + cx: &AsyncApp, +) -> anyhow::Result { + match model.stream_completion_text(request, &cx).await { + Ok(mut stream) => { + let mut full_response = String::new(); + while let Some(chunk_result) = stream.stream.next().await { + match chunk_result { + Ok(chunk_str) => { + print!("{}", &chunk_str); + full_response.push_str(&chunk_str); + } + Err(err) => { + return Err(anyhow!( + "Error receiving response from language model: {err}" + )); + } + } + } + Ok(full_response) + } + Err(err) => Err(anyhow!( + "Failed to get response from language model. Error was: {err}" + )), + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_parse_judge_output() { + let response = r#" + The model did a good job but there were still compilations errors. + 3 + "# + .unindent(); + + let output = parse_judge_output(&response).unwrap(); + assert_eq!( + output.analysis, + "The model did a good job but there were still compilations errors." + ); + assert_eq!(output.score, 3); + + let response = r#" + Text around ignored + + + Failed to compile: + - Error 1 + - Error 2 + + + 1 + "# + .unindent(); + + let output = parse_judge_output(&response).unwrap(); + assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2"); + assert_eq!(output.score, 1); + } } diff --git a/crates/eval/src/judge_prompt.hbs b/crates/eval/src/judge_prompt.hbs new file mode 100644 index 0000000000..cce120d52a --- /dev/null +++ b/crates/eval/src/judge_prompt.hbs @@ -0,0 +1,25 @@ +You are an expert software developer tasked with evaluating the following changes to a codebase: + + +{{repository_diff}} + + +Use the following criteria to score the above changes. + + +{{criteria}} + + +Based on these criteria, give the test output a score between 0 and 5. + +- 5 means: changes meet all criteria +- 0 means: changes don't meet any criteria + +Be suspicious of the changes because they were generated by an LLM. +Sometimes the LLM decides to change random code, so if the changes are not mentioned in the criteria, penalize the score. +Analyze the diff hunk by hunk and describe how each change meets or fails to meet the criteria. + +``` +{YOUR ANALYSIS HERE} +{YOUR SCORE HERE} +``` diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index e1ec23410e..a0e38c629e 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -83,7 +83,7 @@ pub enum StopReason { ToolUse, } -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)] pub struct TokenUsage { #[serde(default, skip_serializing_if = "is_default")] pub input_tokens: u32,