Add judge to new eval + provide LSP diagnostics (#28713)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <antonio@zed.dev>
Co-authored-by: agus <agus@zed.dev>
This commit is contained in:
Michael Sloan 2025-04-14 14:18:47 -06:00 committed by GitHub
parent 2603f36737
commit 6b80eb556c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 838 additions and 84 deletions

11
Cargo.lock generated
View file

@ -4878,25 +4878,36 @@ dependencies = [
"assistant_settings",
"assistant_tool",
"assistant_tools",
"async-watch",
"chrono",
"clap",
"client",
"context_server",
"dap",
"env_logger 0.11.8",
"extension",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"language",
"language_extension",
"language_model",
"language_models",
"languages",
"node_runtime",
"paths",
"project",
"prompt_store",
"release_channel",
"reqwest_client",
"serde",
"settings",
"shellexpand 2.1.2",
"toml 0.8.20",
"unindent",
"util",
"workspace-hack",
]

View file

@ -827,7 +827,7 @@ impl Thread {
})
.collect(),
initial_project_snapshot,
cumulative_token_usage: this.cumulative_token_usage.clone(),
cumulative_token_usage: this.cumulative_token_usage,
detailed_summary_state: this.detailed_summary_state.clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
})
@ -1016,7 +1016,7 @@ impl Thread {
let task = cx.spawn(async move |thread, cx| {
let stream = model.stream_completion(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
let mut events = stream.await?;
let mut stop_reason = StopReason::EndTurn;
@ -1038,9 +1038,9 @@ impl Thread {
stop_reason = reason;
}
LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
thread.cumulative_token_usage =
thread.cumulative_token_usage.clone() + token_usage.clone()
- current_token_usage.clone();
thread.cumulative_token_usage = thread.cumulative_token_usage
+ token_usage
- current_token_usage;
current_token_usage = token_usage;
}
LanguageModelCompletionEvent::Text(chunk) => {
@ -1183,7 +1183,7 @@ impl Thread {
thread.auto_capture_telemetry(cx);
if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage.clone() - initial_usage;
let usage = thread.cumulative_token_usage - initial_usage;
telemetry::event!(
"Assistant Thread Completion",
@ -1862,6 +1862,10 @@ impl Thread {
.detach();
}
/// Returns the thread's `cumulative_token_usage` field by value.
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage
}
pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
let model_registry = LanguageModelRegistry::read_global(cx);
let Some(model) = model_registry.default_model() else {

3
crates/eval/.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
repos/
worktrees/
runs/

View file

@ -7,28 +7,39 @@ edition.workspace = true
[dependencies]
agent.workspace = true
anyhow.workspace = true
async-watch.workspace = true
assistant_settings.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
assistant_settings.workspace = true
chrono.workspace = true
clap.workspace = true
client.workspace = true
context_server.workspace = true
dap.workspace = true
env_logger.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
handlebars.workspace = true
language.workspace = true
language_extension.workspace = true
language_model.workspace = true
language_models.workspace = true
languages.workspace = true
node_runtime.workspace = true
paths.workspace = true
project.workspace = true
prompt_store.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
settings.workspace = true
shellexpand.workspace = true
toml.workspace = true
unindent.workspace = true
util.workspace = true
workspace-hack.workspace = true
[[bin]]

View file

@ -1,2 +1,3 @@
path = "../zed_worktree"
url = "https://github.com/zed-industries/zed.git"
revision = "38fcadf9481d018543c65f36ac3bafeba190179b"
language_extension = "rs"

View file

@ -0,0 +1,2 @@
1. The changes must replace the previous output returned by `FindReplaceFileTool` with the new `ToolResult` struct. The struct should contain an `output` field holding the same string that was returned before, and a new `card` field that contains a view for the card.
2. The card should be a view that displays a diff. Each line in the diff should be colored according to whether it was added, removed or unchanged.

View file

@ -1,32 +1,75 @@
mod example;
use assistant_settings::AssistantSettings;
use client::{Client, UserStore};
use client::{Client, ProxySettings, UserStore};
pub(crate) use example::*;
use ::fs::RealFs;
use anyhow::anyhow;
use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
use anyhow::{Result, anyhow};
use clap::Parser;
use extension::ExtensionHostProxy;
use futures::future;
use gpui::http_client::{Uri, read_proxy_from_env};
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use node_runtime::NodeRuntime;
use node_runtime::{NodeBinaryOptions, NodeRuntime};
use project::Project;
use project::project_settings::ProjectSettings;
use prompt_store::PromptBuilder;
use release_channel::AppVersion;
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use util::ResultExt as _;
pub const RUNS_DIR: &str = "./crates/eval/runs";
// Command-line arguments of the `eval` binary, parsed via clap's derive API.
// The `///` comments on the fields below are picked up by clap as help text, so
// reviewer notes in this struct deliberately use `//` instead.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
/// Runs all examples that contain these substrings. If unspecified, all examples are run.
#[arg(value_name = "EXAMPLE_SUBSTRING")]
examples: Vec<String>,
/// Model to use (default: "claude-3-7-sonnet-latest")
// NOTE(review): `main` as shown passes a hard-coded model name to `find_model` and does
// not appear to read this field — confirm whether `--model` should be wired through.
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model: String,
}
fn main() {
env_logger::init();
let args = Args::parse();
let all_available_examples = list_all_examples().unwrap();
let example_paths = all_available_examples
.iter()
.filter_map(|example_path| {
let name = example_path.file_name()?.to_string_lossy();
if args.examples.is_empty()
|| args
.examples
.iter()
.any(|name_substring| name.contains(name_substring))
{
Some(example_path.clone())
} else {
None
}
})
.collect::<Vec<_>>();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
app.run(move |cx| {
let app_state = init(cx);
let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
@ -39,17 +82,142 @@ fn main() {
cx.spawn(async move |cx| {
authenticate.await.unwrap();
let example =
Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
example.setup()?;
cx.update(|cx| example.run(model, app_state, cx))?.await?;
std::fs::create_dir_all(REPOS_DIR)?;
std::fs::create_dir_all(WORKTREES_DIR)?;
anyhow::Ok(())
let run_dir = Path::new(RUNS_DIR).join(format!(
"{}",
chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
));
std::fs::create_dir_all(&run_dir)?;
let mut examples = Vec::new();
for example_path in example_paths {
let example = Example::load_from_directory(&example_path, &run_dir)?;
examples.push((example_path, example));
}
let mut repo_urls = HashSet::new();
let mut clone_tasks = Vec::new();
for (_, example) in examples.iter() {
let repo_url = example.base.url.clone();
if repo_urls.insert(repo_url.clone()) {
let repo_path = repo_path_for_url(&repo_url);
if !repo_path.join(".git").is_dir() {
println!("Cloning: {}", repo_url);
let git_task = cx.spawn(async move |_cx| {
std::fs::create_dir_all(&repo_path)?;
run_git(&repo_path, &["init"]).await?;
run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
});
clone_tasks.push(git_task);
} else {
println!("Already cloned: {}", repo_url);
let actual_origin =
run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
if actual_origin != repo_url {
return Err(anyhow!(
"remote origin {} does not match expected origin {}",
actual_origin,
repo_url,
));
}
}
}
}
future::join_all(clone_tasks).await;
let tasks = examples
.into_iter()
.map(|(example_path, example)| {
let app_state = app_state.clone();
let model = model.clone();
cx.spawn(async move |cx| {
(
example_path,
run_example(example, model, app_state, cx).await,
)
})
})
.collect::<Vec<_>>();
let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
println!("\n\n");
println!("========================================");
println!(" EVAL RESULTS ");
println!("========================================");
println!("");
let mut judge_scores = Vec::new();
for (example_path, result) in results {
let example_name = example_path.file_name().unwrap().to_string_lossy();
match result {
Err(err) => {
println!("💥 {:<30}: {:?}", example_name, err);
}
Ok(judge_output) => {
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
println!(
"{} {:<30}: {}",
SCORES[judge_output.score.min(5) as usize],
example_name,
judge_output.score,
);
judge_scores.push(judge_output.score);
}
}
}
let score_count = judge_scores.len();
let average_score = judge_scores
.into_iter()
.map(|score| score as f32)
.sum::<f32>()
/ (score_count as f32);
println!("\nAverage score: {average_score}");
cx.update(|cx| cx.quit())
})
.detach_and_log_err(cx);
});
}
/// Runs one example end to end: prepares its worktree, lets the agent work on the prompt,
/// and then asks the judge to score the resulting changes.
///
/// The diff handed to the judge is the one `Example::run` already captured in its
/// `RunOutput`, so `git` is not invoked a second time just to recompute it.
///
/// # Errors
/// Returns the first error produced by setup, the agent run, or the judge.
async fn run_example(
    mut example: Example,
    model: Arc<dyn LanguageModel>,
    app_state: Arc<AgentAppState>,
    cx: &mut AsyncApp,
) -> Result<JudgeOutput> {
    example.setup().await?;
    let run_output = cx
        .update(|cx| example.run(model.clone(), app_state, cx))?
        .await?;
    example.judge(model, run_output.repository_diff, cx).await
}
fn list_all_examples() -> Result<Vec<PathBuf>> {
let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
let entries = std::fs::read_dir(path).unwrap();
let mut result_paths = Vec::new();
for entry in entries {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
result_paths.push(path);
}
}
Ok(result_paths)
}
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>,
@ -72,6 +240,27 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
// Set User-Agent so we can download language servers from GitHub
let user_agent = format!(
"Zed/{} ({}; {})",
AppVersion::global(cx),
std::env::consts::OS,
std::env::consts::ARCH
);
let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
let proxy_url = proxy_str
.as_ref()
.and_then(|input| input.parse::<Uri>().ok())
.or_else(read_proxy_from_env);
let http = {
let _guard = Tokio::handle(cx).enter();
ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
.expect("could not start HTTP client")
};
cx.set_http_client(Arc::new(http));
Project::init_settings(cx);
let client = Client::production(cx);
@ -83,13 +272,47 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
cx.background_executor().clone(),
));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let mut languages = LanguageRegistry::new(cx.background_executor().clone());
languages.set_language_server_download_dir(paths::languages_dir().clone());
let languages = Arc::new(languages);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
extension::init(cx);
let (tx, rx) = async_watch::channel(None);
cx.observe_global::<SettingsStore>(move |cx| {
let settings = &ProjectSettings::get_global(cx).node;
let options = NodeBinaryOptions {
allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
allow_binary_download: true,
use_paths: settings.path.as_ref().map(|node_path| {
let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
let npm_path = settings
.npm_path
.as_ref()
.map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
(
node_path.clone(),
npm_path.unwrap_or_else(|| {
let base_path = PathBuf::new();
node_path.parent().unwrap_or(&base_path).join("npm")
}),
)
}),
};
tx.send(Some(options)).log_err();
})
.detach();
let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
let extension_host_proxy = ExtensionHostProxy::global(cx);
language::init(cx);
language_extension::init(extension_host_proxy.clone(), languages.clone());
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
languages::init(languages.clone(), node_runtime.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
@ -109,7 +332,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
node_runtime,
prompt_builder,
})
}

View file

@ -1,83 +1,161 @@
use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Result, anyhow};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::ToolWorkingSet;
use client::proto::LspWorkProgress;
use dap::DapRegistry;
use futures::channel::oneshot;
use gpui::{App, Task};
use language_model::{LanguageModel, StopReason};
use project::Project;
use serde::Deserialize;
use std::process::Command;
use std::sync::Arc;
use futures::channel::{mpsc, oneshot};
use futures::{FutureExt, StreamExt as _};
use gpui::{App, AsyncApp, Entity, Task};
use handlebars::Handlebars;
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
StopReason, TokenUsage,
};
use project::{LspStore, Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::fmt::Write as _;
use std::fs::File;
use std::io::Write as _;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{
fs,
path::{Path, PathBuf},
};
use unindent::Unindent as _;
use util::ResultExt as _;
use util::command::new_smol_command;
use util::serde::default_true;
use crate::AgentAppState;
#[derive(Debug, Deserialize)]
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
pub const REPOS_DIR: &str = "./crates/eval/repos";
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
pub path: PathBuf,
pub url: String,
pub revision: String,
pub language_extension: Option<String>,
pub insert_id: Option<String>,
#[serde(default = "default_true")]
pub require_lsp: bool,
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct Example {
pub name: String,
/// Content of `base.toml`
pub base: ExampleBase,
/// Content of the prompt.md file
/// Content of `prompt.md`
pub prompt: String,
/// Content of `criteria.md`
pub criteria: String,
/// Markdown log file to append to
pub log_file: Arc<Mutex<File>>,
}
/// Content of the rubric.md file
pub _rubric: String,
/// What `Example::run` reports once the agent has finished working on an example.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RunOutput {
/// Output of `Example::repository_diff` taken after the run.
pub repository_diff: String,
/// Text produced by `query_lsp_diagnostics` for the project after the run.
pub diagnostics: String,
/// Number of thread messages whose role is `Assistant`.
pub response_count: usize,
/// The thread's cumulative token usage at the end of the run.
pub token_usage: TokenUsage,
}
/// Data rendered into the `judge_prompt.hbs` template by `Example::judge`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeInput {
/// The diff the judge is asked to evaluate.
pub repository_diff: String,
/// Content of the example's `criteria.md`.
pub criteria: String,
}
/// The judge's verdict as extracted by `parse_judge_output`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JudgeOutput {
/// Content of the `<analysis>` tag in the judge's response.
pub analysis: String,
/// Content of the `<score>` tag parsed as an integer. The judge prompt asks for a
/// value between 0 and 5, but no range check is applied when parsing.
pub score: u32,
}
impl Example {
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
let base_path = dir_path.as_ref().join("base.toml");
let prompt_path = dir_path.as_ref().join("prompt.md");
let rubric_path = dir_path.as_ref().join("rubric.md");
/// Load an example from a directory containing base.toml, prompt.md, and criteria.md
pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
let name = dir_path.file_name().unwrap().to_string_lossy().to_string();
let base_path = dir_path.join("base.toml");
let prompt_path = dir_path.join("prompt.md");
let criteria_path = dir_path.join("criteria.md");
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
base.path = base.path.canonicalize()?;
let log_file_path = run_dir.join(format!(
"{}.md",
dir_path.file_name().unwrap().to_str().unwrap()
));
let log_file = Arc::new(Mutex::new(File::create(&log_file_path).unwrap()));
println!("{}> Logging to {:?}", name, log_file_path);
Ok(Example {
base,
prompt: fs::read_to_string(prompt_path)?,
_rubric: fs::read_to_string(rubric_path)?,
name,
base: toml::from_str(&fs::read_to_string(&base_path)?)?,
prompt: fs::read_to_string(prompt_path.clone())?,
criteria: fs::read_to_string(criteria_path.clone())?,
log_file,
})
}
/// Set up the example by checking out the specified Git revision
pub fn setup(&self) -> Result<()> {
// Check if the directory exists
let path = Path::new(&self.base.path);
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
pub fn worktree_path(&self) -> PathBuf {
Path::new(WORKTREES_DIR)
.canonicalize()
.context(format!("No such directory {WORKTREES_DIR}"))
.unwrap()
.join(&self.name)
}
// Change to the project directory and checkout the specified revision
let output = Command::new("git")
.current_dir(&self.base.path)
.arg("checkout")
.arg(&self.base.revision)
.output()?;
anyhow::ensure!(
output.status.success(),
"Failed to checkout revision {}: {}",
self.base.revision,
String::from_utf8_lossy(&output.stderr),
);
/// Prepares the example's worktree at the revision named in `base.toml`.
///
/// The revision is first fetched (depth 1) into the repository directory derived from the
/// example's URL. An existing worktree is then cleaned, hard-reset and switched to that
/// revision; otherwise a new worktree is created for it.
pub async fn setup(&self) -> Result<()> {
    let revision = self.base.revision.as_str();
    let repo_path = repo_path_for_url(&self.base.url);
    run_git(&repo_path, &["fetch", "--depth", "1", "origin", revision]).await?;

    let worktree_path = self.worktree_path();
    if !worktree_path.is_dir() {
        println!("{}> Creating worktree", self.name);
        let worktree_arg = worktree_path.to_string_lossy().to_string();
        run_git(
            &repo_path,
            &["worktree", "add", "-f", &worktree_arg, revision],
        )
        .await?;
        return Ok(());
    }

    println!("{}> Resetting existing worktree", self.name);
    // TODO: consider including "-x" to remove ignored files. The downside of this is that
    // it will also remove build artifacts, and so prevent incremental reuse there.
    run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
    run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
    run_git(&worktree_path, &["checkout", revision]).await?;
    Ok(())
}
pub fn run(
self,
&self,
model: Arc<dyn LanguageModel>,
app_state: Arc<AgentAppState>,
cx: &mut App,
) -> Task<Result<()>> {
) -> Task<Result<RunOutput>> {
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
@ -89,30 +167,119 @@ impl Example {
cx,
);
let worktree_path = self.worktree_path();
let worktree = project.update(cx, |project, cx| {
project.create_worktree(self.base.path, true, cx)
project.create_worktree(&worktree_path, true, cx)
});
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
let this = self.clone();
println!("USER:");
println!("{}", self.prompt);
println!("ASSISTANT:");
cx.spawn(async move |cx| {
worktree.await?;
let worktree = worktree.await?;
// Wait for worktree scan to finish before choosing a file to open.
worktree
.update(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})?
.await;
let lsp_open_handle_and_store = if this.base.require_lsp {
let language_extension = this.base.language_extension.as_deref().context(
"language_extension field is required in base.toml when `require_lsp == true`",
)?;
// Open a file that matches the language to cause LSP to start.
let language_file = worktree.read_with(cx, |worktree, _cx| {
worktree
.files(false, 0)
.find_map(|e| {
if e.path.clone().extension().and_then(|ext| ext.to_str())
== Some(language_extension)
{
Some(ProjectPath {
worktree_id: worktree.id(),
path: e.path.clone(),
})
} else {
None
}
})
.context("Failed to find a file for example language")
})??;
let open_language_file_buffer_task = project.update(cx, |project, cx| {
project.open_buffer(language_file.clone(), cx)
})?;
let language_file_buffer = open_language_file_buffer_task.await?;
let (lsp_open_handle, lsp_store) = project.update(cx, |project, cx| {
(
project.register_buffer_with_language_servers(&language_file_buffer, cx),
project.lsp_store().clone(),
)
})?;
// TODO: remove this once the diagnostics tool waits for new diagnostics
cx.background_executor().timer(Duration::new(5, 0)).await;
wait_for_lang_server(&lsp_store, this.name.clone(), cx).await?;
lsp_store.update(cx, |lsp_store, cx| {
lsp_open_handle.update(cx, |buffer, cx| {
buffer.update(cx, |buffer, cx| {
let has_language_server = lsp_store
.language_servers_for_local_buffer(buffer, cx)
.next()
.is_some();
if has_language_server {
Ok(())
} else {
Err(anyhow!(
"`{:?}` was opened to cause the language server to start, \
but no language servers are registered for its buffer. \
Set `require_lsp = false` in `base.toml` to skip this.",
language_file
))
}
})
})
})??;
Some((lsp_open_handle, lsp_store))
} else {
None
};
if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
return Err(anyhow!("Setup only mode"));
}
let thread_store = thread_store.await;
let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
{
let mut log_file = this.log_file.lock().unwrap();
writeln!(&mut log_file, "👤 USER:").log_err();
writeln!(&mut log_file, "{}", this.prompt).log_err();
writeln!(&mut log_file, "🤖 ASSISTANT:").log_err();
log_file.flush().log_err();
}
let (tx, rx) = oneshot::channel();
let mut tx = Some(tx);
let _subscription =
cx.subscribe(
&thread,
move |thread, event: &ThreadEvent, cx| match event {
let _subscription = cx.subscribe(&thread, {
let log_file = this.log_file.clone();
let name = this.name.clone();
move |thread, event: &ThreadEvent, cx| {
let mut log_file = log_file.lock().unwrap();
match event {
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn) => {
if let Some(tx) = tx.take() {
@ -137,15 +304,16 @@ impl Example {
}
}
ThreadEvent::StreamedAssistantText(_, chunk) => {
print!("{}", chunk);
write!(&mut log_file, "{}", chunk).log_err();
}
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
print!("{}", chunk);
write!(&mut log_file, "{}", chunk).log_err();
}
ThreadEvent::UsePendingTools { tool_uses } => {
println!("\n\nUSING TOOLS:");
writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
for tool_use in tool_uses {
println!("{}: {}", tool_use.name, tool_use.input);
writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
.log_err();
}
}
ThreadEvent::ToolFinished {
@ -154,25 +322,331 @@ impl Example {
..
} => {
if let Some(tool_use) = pending_tool_use {
println!("\nTOOL FINISHED: {}", tool_use.name);
let message = format!("TOOL FINISHED: {}", tool_use.name);
println!("{name}> {message}");
writeln!(&mut log_file, "\n{}", message).log_err();
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("\n{}\n", tool_result.content);
let message = format!("\n{}\n", tool_result.content);
writeln!(&mut log_file, "{}", message).log_err();
}
}
_ => {}
},
)?;
}
log_file.flush().log_err();
}
})?;
thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(self.prompt.clone(), context, None, cx);
thread.insert_user_message(this.prompt.clone(), context, None, cx);
thread.send_to_model(model, RequestKind::Chat, cx);
})?;
rx.await??;
Ok(())
if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;
}
let repository_diff = this.repository_diff().await?;
let diagnostics = cx
.update(move |cx| {
cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
})?
.await?;
drop(lsp_open_handle_and_store);
thread.update(cx, |thread, _cx| {
let response_count = thread
.messages()
.filter(|message| message.role == language_model::Role::Assistant)
.count();
RunOutput {
repository_diff,
diagnostics,
response_count,
token_usage: thread.cumulative_token_usage(),
}
})
})
}
/// Asks `model` to grade `repository_diff` against this example's criteria.
///
/// The `judge_prompt.hbs` template is rendered with the diff and the criteria and sent as a
/// single user message. The raw response is appended to the example's log file, and the
/// `<analysis>` / `<score>` tags are then parsed out of it.
///
/// # Errors
/// Fails if the template cannot be registered or rendered, if the model request fails, or
/// if the response cannot be parsed by `parse_judge_output`.
pub async fn judge(
    &mut self,
    model: Arc<dyn LanguageModel>,
    repository_diff: String,
    cx: &AsyncApp,
) -> Result<JudgeOutput> {
    let template_name = "judge_prompt";
    let mut templates = Handlebars::new();
    templates.register_template_string(template_name, include_str!("judge_prompt.hbs"))?;

    let judge_input = JudgeInput {
        repository_diff,
        criteria: self.criteria.clone(),
    };
    let rendered_prompt = templates.render(template_name, &judge_input)?;

    let user_message = LanguageModelRequestMessage {
        role: Role::User,
        content: vec![MessageContent::Text(rendered_prompt)],
        cache: false,
    };
    let request = LanguageModelRequest {
        messages: vec![user_message],
        temperature: None,
        tools: Vec::new(),
        stop: Vec::new(),
    };
    let response = send_language_model_request(model, request, cx).await?;

    // The log file is only locked after the request has completed, so the guard is never
    // held across an await point. Write failures are logged rather than propagated.
    let mut log_file = self.log_file.lock().unwrap();
    writeln!(&mut log_file, "\n\n").log_err();
    writeln!(&mut log_file, "========================================").log_err();
    writeln!(&mut log_file, " JUDGE OUTPUT ").log_err();
    writeln!(&mut log_file, "========================================").log_err();
    writeln!(&mut log_file, "\n{}", &response).log_err();

    parse_judge_output(&response)
}
/// Returns the worktree's uncommitted changes as the output of `git diff`.
///
/// Untracked files are first registered with `git add -N .` (intent-to-add) so that files
/// created during the run show up in the diff as additions.
///
/// # Errors
/// Fails if either git command cannot be run or exits unsuccessfully.
pub async fn repository_diff(&self) -> Result<String> {
    let worktree_path = self.worktree_path();
    // `git add -N` requires a pathspec: without one git reports "Nothing specified,
    // nothing added" and still exits successfully, which silently left new files out of
    // the diff.
    run_git(&worktree_path, &["add", "-N", "."]).await?;
    run_git(&worktree_path, &["diff"]).await
}
}
/// Returns a task that completes once no language server in `lsp_store` has pending work.
///
/// The task is ready immediately when nothing is pending or when the
/// `ZED_EVAL_SKIP_LS_WAIT` environment variable is set. Otherwise progress messages from
/// the language servers are echoed to stdout (prefixed with `name`), and the task fails
/// if the idle signal has not arrived after five minutes.
fn wait_for_lang_server(
    lsp_store: &Entity<LspStore>,
    name: String,
    cx: &mut AsyncApp,
) -> Task<Result<()>> {
    let already_idle = cx
        .update(|cx| !has_pending_lang_server_work(lsp_store, cx))
        .unwrap();
    if already_idle || std::env::var("ZED_EVAL_SKIP_LS_WAIT").is_ok() {
        return Task::ready(anyhow::Ok(()));
    }

    println!("{}> ⏵ Waiting for language server", name);

    let (mut idle_tx, mut idle_rx) = mpsc::channel(1);
    let subscription = cx.subscribe(lsp_store, {
        let name = name.clone();
        move |lsp_store, event, cx| {
            // Echo work-progress messages so long waits are visible on the console.
            if let project::LspStoreEvent::LanguageServerUpdate {
                message:
                    client::proto::update_language_server::Variant::WorkProgress(LspWorkProgress {
                        message: Some(message),
                        ..
                    }),
                ..
            } = event
            {
                println!("{name}> ⟲ {message}");
            }

            // A failed send is deliberately ignored via `.ok()`.
            if !has_pending_lang_server_work(&lsp_store, cx) {
                idle_tx.try_send(()).ok();
            }
        }
    });

    cx.spawn(async move |cx| {
        let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
        let outcome = futures::select! {
            _ = idle_rx.next() => {
                println!("{}> ⚑ Language server idle", name);
                anyhow::Ok(())
            },
            _ = timeout.fuse() => {
                Err(anyhow!("LSP wait timed out after 5 minutes"))
            }
        };
        // Keep the subscription alive until the wait has been decided.
        drop(subscription);
        outcome
    })
}
/// Reports whether any language server status in `lsp_store` has a non-empty `pending_work`.
fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool {
    let store = lsp_store.read(cx);
    store
        .language_server_statuses()
        .any(|(_, status)| !status.pending_work.is_empty())
}
/// Collects the project's error and warning diagnostics as text, one line per diagnostic group.
///
/// Only paths whose diagnostic summary contains at least one error or warning are opened.
/// For each diagnostic group the primary entry is reported as
/// `<severity> at line <1-based row>: <message>`; groups with any other severity are skipped.
// NOTE(review): the emitted lines do not name the file a diagnostic belongs to.
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
    let paths_with_diagnostics = project.update(cx, |project, cx| {
        project
            .diagnostic_summaries(true, cx)
            .filter_map(|(project_path, _, summary)| {
                (summary.error_count > 0 || summary.warning_count > 0).then_some(project_path)
            })
            .collect::<Vec<_>>()
    })?;

    let mut report = String::new();
    for project_path in paths_with_diagnostics {
        let open_buffer =
            project.update(cx, |project, cx| project.open_buffer(project_path, cx))?;
        let buffer = open_buffer.await?;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for (_, group) in snapshot.diagnostic_groups(None) {
            let primary = &group.entries[group.primary_ix];
            let severity = match primary.diagnostic.severity {
                DiagnosticSeverity::ERROR => "error",
                DiagnosticSeverity::WARNING => "warning",
                _ => continue,
            };
            let start = primary.range.to_point(&snapshot).start;
            writeln!(
                report,
                "{} at line {}: {}",
                severity,
                start.row + 1,
                primary.diagnostic.message
            )?;
        }
    }
    anyhow::Ok(report)
}
/// Extracts the judge's `<analysis>` text and integer `<score>` from `response`.
///
/// # Errors
/// Fails if either tag is missing or if the score text does not parse as an integer.
fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
    let analysis = get_tag("analysis", response)?;
    let score_text = get_tag("score", response)?;
    let score = score_text.parse().context("error parsing score")?;
    Ok(JudgeOutput { analysis, score })
}
fn get_tag(name: &'static str, response: &str) -> Result<String> {
let start_tag = format!("<{}>", name);
let end_tag = format!("</{}>", name);
let start_ix = response
.find(&start_tag)
.context(format!("{} start tag not found", name))?;
let content_start_ix = start_ix + start_tag.len();
let end_ix = content_start_ix
+ response[content_start_ix..]
.find(&end_tag)
.context(format!("{} end tag not found", name))?;
let content = response[content_start_ix..end_ix].trim().unindent();
anyhow::Ok(content)
}
pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
let repo_name = repo_url
.trim_start_matches("https://")
.replace(|c: char| !c.is_alphanumeric(), "-");
Path::new(REPOS_DIR)
.canonicalize()
.context(format!("No such directory {REPOS_DIR}"))
.unwrap()
.join(repo_name)
}
/// Runs `git` with `args` inside `repo_path` and returns its trimmed standard output.
///
/// # Errors
/// Fails if the process cannot be spawned, if stdout is not valid UTF-8, or if git exits
/// unsuccessfully. In the last case the error carries the command, directory, exit status
/// and both output streams.
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
    let output = new_smol_command("git")
        .current_dir(repo_path)
        .args(args)
        .output()
        .await?;

    if !output.status.success() {
        return Err(anyhow!(
            "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
            args.join(" "),
            repo_path.display(),
            output.status,
            String::from_utf8_lossy(&output.stderr),
            String::from_utf8_lossy(&output.stdout),
        ));
    }

    let stdout = String::from_utf8(output.stdout)?;
    Ok(stdout.trim().to_string())
}
pub async fn send_language_model_request(
model: Arc<dyn LanguageModel>,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> anyhow::Result<String> {
match model.stream_completion_text(request, &cx).await {
Ok(mut stream) => {
let mut full_response = String::new();
while let Some(chunk_result) = stream.stream.next().await {
match chunk_result {
Ok(chunk_str) => {
print!("{}", &chunk_str);
full_response.push_str(&chunk_str);
}
Err(err) => {
return Err(anyhow!(
"Error receiving response from language model: {err}"
));
}
}
}
Ok(full_response)
}
Err(err) => Err(anyhow!(
"Failed to get response from language model. Error was: {err}"
)),
}
}
// Unit tests for parsing the judge's tagged response.
#[cfg(test)]
mod test {
use super::*;
// Checks that `parse_judge_output` extracts the analysis text and the numeric score.
#[test]
fn test_parse_judge_output() {
// Case 1: both tags on single lines with nothing around them.
let response = r#"
<analysis>The model did a good job but there were still compilations errors.</analysis>
<score>3</score>
"#
.unindent();
let output = parse_judge_output(&response).unwrap();
assert_eq!(
output.analysis,
"The model did a good job but there were still compilations errors."
);
assert_eq!(output.score, 3);
// Case 2: text outside the tags is ignored and a multi-line analysis is trimmed.
let response = r#"
Text around ignored
<analysis>
Failed to compile:
- Error 1
- Error 2
</analysis>
<score>1</score>
"#
.unindent();
let output = parse_judge_output(&response).unwrap();
assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
assert_eq!(output.score, 1);
}
}

View file

@ -0,0 +1,25 @@
You are an expert software developer tasked with evaluating the following changes to a codebase:
<changes>
{{repository_diff}}
</changes>
Use the following criteria to score the above changes.
<criteria>
{{criteria}}
</criteria>
Based on these criteria, give the test output a score between 0 and 5.
- 5 means: changes meet all criteria
- 0 means: changes don't meet any criteria
Be suspicious of the changes because they were generated by an LLM.
Sometimes the LLM decides to change random code, so if the changes are not mentioned in the criteria, penalize the score.
Analyze the diff hunk by hunk and describe how each change meets or fails to meet the criteria.
```
<analysis>{YOUR ANALYSIS HERE}</analysis>
<score>{YOUR SCORE HERE}</score>
```

View file

@ -83,7 +83,7 @@ pub enum StopReason {
ToolUse,
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
#[serde(default, skip_serializing_if = "is_default")]
pub input_tokens: u32,