Add initial implementation of evaluating changes generated by the assistant (#26799)
Release Notes: - N/A --------- Co-authored-by: Richard Feldman <oss@rtfeldman.com> Co-authored-by: Thomas <thomas@zed.dev>
This commit is contained in:
parent
e9b4fa1465
commit
7a888de9f5
14 changed files with 1113 additions and 24 deletions
35
Cargo.lock
generated
35
Cargo.lock
generated
|
@ -566,6 +566,41 @@ dependencies = [
|
|||
"workspace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assistant_eval"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant2",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"clap",
|
||||
"client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"env_logger 0.11.7",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"node_runtime",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"smol",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assistant_settings"
|
||||
version = "0.1.0"
|
||||
|
|
|
@ -8,6 +8,7 @@ members = [
|
|||
"crates/assistant",
|
||||
"crates/assistant2",
|
||||
"crates/assistant_context_editor",
|
||||
"crates/assistant_eval",
|
||||
"crates/assistant_settings",
|
||||
"crates/assistant_slash_command",
|
||||
"crates/assistant_slash_commands",
|
||||
|
@ -206,6 +207,7 @@ assets = { path = "crates/assets" }
|
|||
assistant = { path = "crates/assistant" }
|
||||
assistant2 = { path = "crates/assistant2" }
|
||||
assistant_context_editor = { path = "crates/assistant_context_editor" }
|
||||
assistant_eval = { path = "crates/assistant_eval" }
|
||||
assistant_settings = { path = "crates/assistant_settings" }
|
||||
assistant_slash_command = { path = "crates/assistant_slash_command" }
|
||||
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
|
||||
|
|
|
@ -298,6 +298,7 @@ impl ActiveThread {
|
|||
ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
|
||||
self.save_thread(cx);
|
||||
}
|
||||
ThreadEvent::DoneStreaming => {}
|
||||
ThreadEvent::StreamedAssistantText(message_id, text) => {
|
||||
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
|
||||
markdown.update(cx, |markdown, cx| {
|
||||
|
|
|
@ -31,8 +31,11 @@ use gpui::{actions, App};
|
|||
use prompt_store::PromptBuilder;
|
||||
use settings::Settings as _;
|
||||
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
|
||||
pub use crate::inline_assistant::InlineAssistant;
|
||||
pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::ThreadStore;
|
||||
|
||||
actions!(
|
||||
assistant2,
|
||||
|
|
|
@ -284,6 +284,10 @@ impl Thread {
|
|||
self.tool_use.tool_results_for_message(id)
|
||||
}
|
||||
|
||||
/// Looks up the result for the tool use with the given `id` by delegating to this thread's
/// tool-use state. Returns `None` when that state holds no result for `id`.
pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
    let tool_use_state = &self.tool_use;
    tool_use_state.tool_result(id)
}
|
||||
|
||||
pub fn scripting_tool_results_for_message(
|
||||
&self,
|
||||
id: MessageId,
|
||||
|
@ -652,32 +656,37 @@ impl Thread {
|
|||
let result = stream_completion.await;
|
||||
|
||||
thread
|
||||
.update(&mut cx, |thread, cx| match result.as_ref() {
|
||||
Ok(stop_reason) => match stop_reason {
|
||||
StopReason::ToolUse => {
|
||||
cx.emit(ThreadEvent::UsePendingTools);
|
||||
}
|
||||
StopReason::EndTurn => {}
|
||||
StopReason::MaxTokens => {}
|
||||
},
|
||||
Err(error) => {
|
||||
if error.is::<PaymentRequiredError>() {
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
|
||||
} else if error.is::<MaxMonthlySpendReachedError>() {
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
|
||||
} else {
|
||||
let error_message = error
|
||||
.chain()
|
||||
.map(|err| err.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::Message(
|
||||
SharedString::from(error_message.clone()),
|
||||
)));
|
||||
}
|
||||
.update(&mut cx, |thread, cx| {
|
||||
match result.as_ref() {
|
||||
Ok(stop_reason) => match stop_reason {
|
||||
StopReason::ToolUse => {
|
||||
cx.emit(ThreadEvent::UsePendingTools);
|
||||
}
|
||||
StopReason::EndTurn => {}
|
||||
StopReason::MaxTokens => {}
|
||||
},
|
||||
Err(error) => {
|
||||
if error.is::<PaymentRequiredError>() {
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
|
||||
} else if error.is::<MaxMonthlySpendReachedError>() {
|
||||
cx.emit(ThreadEvent::ShowError(
|
||||
ThreadError::MaxMonthlySpendReached,
|
||||
));
|
||||
} else {
|
||||
let error_message = error
|
||||
.chain()
|
||||
.map(|err| err.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::Message(
|
||||
SharedString::from(error_message.clone()),
|
||||
)));
|
||||
}
|
||||
|
||||
thread.cancel_last_completion();
|
||||
thread.cancel_last_completion();
|
||||
}
|
||||
}
|
||||
cx.emit(ThreadEvent::DoneStreaming);
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
@ -1094,6 +1103,7 @@ pub enum ThreadEvent {
|
|||
ShowError(ThreadError),
|
||||
StreamedCompletion,
|
||||
StreamedAssistantText(MessageId, String),
|
||||
DoneStreaming,
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
MessageDeleted(MessageId),
|
||||
|
|
|
@ -182,6 +182,13 @@ impl ToolUseState {
|
|||
.map_or(false, |results| !results.is_empty())
|
||||
}
|
||||
|
||||
pub fn tool_result(
|
||||
&self,
|
||||
tool_use_id: &LanguageModelToolUseId,
|
||||
) -> Option<&LanguageModelToolResult> {
|
||||
self.tool_results.get(tool_use_id)
|
||||
}
|
||||
|
||||
pub fn request_tool_use(
|
||||
&mut self,
|
||||
assistant_message_id: MessageId,
|
||||
|
|
44
crates/assistant_eval/Cargo.toml
Normal file
44
crates/assistant_eval/Cargo.toml
Normal file
|
@ -0,0 +1,44 @@
|
|||
[package]
|
||||
name = "assistant_eval"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "assistant_eval"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant2.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
env_logger.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
node_runtime.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
regex.workspace = true
|
||||
release_channel.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
util.workspace = true
|
1
crates/assistant_eval/LICENSE-GPL
Symbolic link
1
crates/assistant_eval/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
77
crates/assistant_eval/README.md
Normal file
77
crates/assistant_eval/README.md
Normal file
|
@ -0,0 +1,77 @@
|
|||
# Tool Evals
|
||||
|
||||
A framework for evaluating and benchmarking AI assistant performance in the Zed editor.
|
||||
|
||||
## Overview
|
||||
|
||||
Tool Evals provides a headless environment for running assistant evaluations on code repositories. It automates the process of:
|
||||
|
||||
1. Cloning and setting up test repositories
|
||||
2. Sending prompts to language models
|
||||
3. Allowing the assistant to use tools to modify code
|
||||
4. Collecting metrics on performance
|
||||
5. Evaluating results against known good solutions
|
||||
|
||||
## How It Works
|
||||
|
||||
The system consists of several key components:
|
||||
|
||||
- **Eval**: Loads test cases from the evaluation_data directory, clones repos, and executes evaluations
|
||||
- **HeadlessAssistant**: Provides a headless environment for running the AI assistant
|
||||
- **Judge**: Compares AI-generated diffs with reference solutions and scores their functional similarity
|
||||
|
||||
The evaluation flow:
|
||||
1. An evaluation is loaded from the evaluation_data directory
|
||||
2. The target repository is cloned and checked out at a specific commit
|
||||
3. A HeadlessAssistant instance is created with the specified language model
|
||||
4. The user prompt is sent to the assistant
|
||||
5. The assistant responds and uses tools to modify code
|
||||
6. Upon completion, a diff is generated from the changes
|
||||
7. Results are saved including the diff, assistant's response, and performance metrics
|
||||
8. If a reference solution exists, a Judge evaluates the similarity of the solution
|
||||
|
||||
## Setup Requirements
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust and Cargo
|
||||
- Git
|
||||
- Network access to clone repositories
|
||||
- Appropriate API keys for language models and git services (Anthropic, GitHub, etc.)
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Ensure you have the required API keys set, either from a dev run of Zed or via these environment variables:
|
||||
- `ZED_ANTHROPIC_API_KEY` for Claude models
|
||||
- `ZED_OPENAI_API_KEY` for OpenAI models
|
||||
- `ZED_GITHUB_API_KEY` for GitHub API (or similar)
|
||||
|
||||
## Usage
|
||||
|
||||
### Running a Single Evaluation
|
||||
|
||||
To run a specific evaluation:
|
||||
|
||||
```bash
|
||||
cargo run -p assistant_eval -- bubbletea-add-set-window-title
|
||||
```
|
||||
|
||||
The arguments are regex patterns for the evaluation names to run, so to run all evaluations that contain `bubbletea`, run:
|
||||
|
||||
```bash
|
||||
cargo run -p assistant_eval -- bubbletea
|
||||
```
|
||||
|
||||
To run all evaluations:
|
||||
|
||||
```bash
|
||||
cargo run -p assistant_eval -- --all
|
||||
```
|
||||
|
||||
## Evaluation Data Structure
|
||||
|
||||
Each evaluation should be placed in the `evaluation_data` directory with the following structure:
|
||||
|
||||
* `prompt.txt`: The user's prompt.
|
||||
* `original.diff`: The `git diff` of the change anticipated for this prompt.
|
||||
* `setup.json`: Information about the repo used for the evaluation.
|
61
crates/assistant_eval/build.rs
Normal file
61
crates/assistant_eval/build.rs
Normal file
|
@ -0,0 +1,61 @@
|
|||
// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
|
||||
|
||||
use std::process::Command;
|
||||
|
||||
/// Build script entry point.
///
/// Prints `cargo:` directives that configure linking on macOS and on Windows/MSVC, export the
/// `TARGET` triple, and (when `git rev-parse HEAD` succeeds) export `ZED_COMMIT_SHA`.
fn main() {
    if cfg!(target_os = "macos") {
        println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");

        println!("cargo:rerun-if-env-changed=ZED_BUNDLE");
        let is_bundle = matches!(std::env::var("ZED_BUNDLE").as_deref(), Ok("true"));
        // Find WebRTC.framework in the Frameworks folder when running as part of an application
        // bundle, and as a sibling of the executable when running outside of one.
        let rpath = if is_bundle {
            "@executable_path/../Frameworks"
        } else {
            "@executable_path"
        };
        println!("cargo:rustc-link-arg=-Wl,-rpath,{rpath}");

        // Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
        println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");

        // Seems to be required to enable Swift concurrency
        println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");

        // Register exported Objective-C selectors, protocols, etc
        println!("cargo:rustc-link-arg=-Wl,-ObjC");
    }

    // Populate git sha environment variable if git is available
    println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
    let target = std::env::var("TARGET").unwrap();
    println!("cargo:rustc-env=TARGET={}", target);

    match Command::new("git").args(["rev-parse", "HEAD"]).output() {
        Ok(output) if output.status.success() => {
            let stdout = String::from_utf8_lossy(&output.stdout);
            let git_sha = stdout.trim();

            println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");

            if matches!(std::env::var("PROFILE").as_deref(), Ok("release")) {
                // This is currently the best way to make `cargo build ...`'s build script
                // to print something to stdout without extra verbosity.
                println!("cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var");
            }
        }
        // git is missing or the command failed: leave ZED_COMMIT_SHA unset.
        _ => {}
    }

    #[cfg(all(target_os = "windows", target_env = "msvc"))]
    {
        // todo(windows): This is to avoid stack overflow. Remove it when solved.
        println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
    }
}
|
252
crates/assistant_eval/src/eval.rs
Normal file
252
crates/assistant_eval/src/eval.rs
Normal file
|
@ -0,0 +1,252 @@
|
|||
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
|
||||
use anyhow::anyhow;
|
||||
use assistant2::RequestKind;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Task};
|
||||
use language_model::{LanguageModel, TokenUsage};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fs,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use util::command::new_smol_command;
|
||||
|
||||
pub struct Eval {
|
||||
pub name: String,
|
||||
pub path: PathBuf,
|
||||
pub repo_path: PathBuf,
|
||||
pub eval_setup: EvalSetup,
|
||||
pub user_prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EvalOutput {
|
||||
pub diff: String,
|
||||
pub last_message: String,
|
||||
pub elapsed_time: Duration,
|
||||
pub assistant_response_count: usize,
|
||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||
pub token_usage: TokenUsage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EvalSetup {
|
||||
pub url: String,
|
||||
pub base_sha: String,
|
||||
}
|
||||
|
||||
impl Eval {
|
||||
/// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
|
||||
/// if necessary.
|
||||
pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
|
||||
let prompt_path = path.join("prompt.txt");
|
||||
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
|
||||
let setup_path = path.join("setup.json");
|
||||
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
|
||||
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
|
||||
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
|
||||
Ok(Eval {
|
||||
name,
|
||||
path,
|
||||
repo_path,
|
||||
eval_setup,
|
||||
user_prompt,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
self,
|
||||
app_state: Arc<HeadlessAppState>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut App,
|
||||
) -> Task<anyhow::Result<EvalOutput>> {
|
||||
cx.spawn(move |mut cx| async move {
|
||||
checkout_repo(&self.eval_setup, &self.repo_path).await?;
|
||||
|
||||
let (assistant, done_rx) =
|
||||
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
|
||||
|
||||
let _worktree = assistant
|
||||
.update(&mut cx, |assistant, cx| {
|
||||
assistant.project.update(cx, |project, cx| {
|
||||
project.create_worktree(&self.repo_path, true, cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let start_time = std::time::SystemTime::now();
|
||||
|
||||
assistant.update(&mut cx, |assistant, cx| {
|
||||
assistant.thread.update(cx, |thread, cx| {
|
||||
let context = vec![];
|
||||
thread.insert_user_message(self.user_prompt.clone(), context, cx);
|
||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
||||
});
|
||||
})?;
|
||||
|
||||
done_rx.recv().await??;
|
||||
|
||||
let elapsed_time = start_time.elapsed()?;
|
||||
|
||||
let diff = query_git(&self.repo_path, vec!["diff"]).await?;
|
||||
|
||||
assistant.update(&mut cx, |assistant, cx| {
|
||||
let thread = assistant.thread.read(cx);
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
if last_message.role != language_model::Role::Assistant {
|
||||
return Err(anyhow!("Last message is not from assistant"));
|
||||
}
|
||||
let assistant_response_count = thread
|
||||
.messages()
|
||||
.filter(|message| message.role == language_model::Role::Assistant)
|
||||
.count();
|
||||
Ok(EvalOutput {
|
||||
diff,
|
||||
last_message: last_message.text.clone(),
|
||||
elapsed_time,
|
||||
assistant_response_count,
|
||||
tool_use_counts: assistant.tool_use_counts.clone(),
|
||||
token_usage: thread.cumulative_token_usage(),
|
||||
})
|
||||
})?
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EvalOutput {
|
||||
// Method to save the output to a directory
|
||||
pub fn save_to_directory(
|
||||
&self,
|
||||
output_dir: &Path,
|
||||
eval_output_value: String,
|
||||
) -> anyhow::Result<()> {
|
||||
// Create the output directory if it doesn't exist
|
||||
fs::create_dir_all(&output_dir)?;
|
||||
|
||||
// Save the diff to a file
|
||||
let diff_path = output_dir.join("diff.patch");
|
||||
let mut diff_file = fs::File::create(&diff_path)?;
|
||||
diff_file.write_all(self.diff.as_bytes())?;
|
||||
|
||||
// Save the last message to a file
|
||||
let message_path = output_dir.join("assistant_response.txt");
|
||||
let mut message_file = fs::File::create(&message_path)?;
|
||||
message_file.write_all(self.last_message.as_bytes())?;
|
||||
|
||||
// Current metrics for this run
|
||||
let current_metrics = serde_json::json!({
|
||||
"elapsed_time_ms": self.elapsed_time.as_millis(),
|
||||
"assistant_response_count": self.assistant_response_count,
|
||||
"tool_use_counts": self.tool_use_counts,
|
||||
"token_usage": self.token_usage,
|
||||
"eval_output_value": eval_output_value,
|
||||
});
|
||||
|
||||
// Get current timestamp in milliseconds
|
||||
let timestamp = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)?
|
||||
.as_millis()
|
||||
.to_string();
|
||||
|
||||
// Path to metrics file
|
||||
let metrics_path = output_dir.join("metrics.json");
|
||||
|
||||
// Load existing metrics if the file exists, or create a new object
|
||||
let mut historical_metrics = if metrics_path.exists() {
|
||||
let metrics_content = fs::read_to_string(&metrics_path)?;
|
||||
serde_json::from_str::<serde_json::Value>(&metrics_content)
|
||||
.unwrap_or_else(|_| serde_json::json!({}))
|
||||
} else {
|
||||
serde_json::json!({})
|
||||
};
|
||||
|
||||
// Add new run with timestamp as key
|
||||
if let serde_json::Value::Object(ref mut map) = historical_metrics {
|
||||
map.insert(timestamp, current_metrics);
|
||||
}
|
||||
|
||||
// Write updated metrics back to file
|
||||
let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
|
||||
let mut metrics_file = fs::File::create(&metrics_path)?;
|
||||
metrics_file.write_all(metrics_json.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Turns a repository URL into a directory name: any leading `https://` is removed and every
/// character that is not alphanumeric becomes `_`.
fn repo_dir_name(url: &str) -> String {
    let without_scheme = url.trim_start_matches("https://");
    without_scheme
        .chars()
        .map(|ch| if ch.is_alphanumeric() { ch } else { '_' })
        .collect()
}
|
||||
|
||||
async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
|
||||
if !repo_path.exists() {
|
||||
smol::unblock({
|
||||
let repo_path = repo_path.to_path_buf();
|
||||
|| std::fs::create_dir_all(repo_path)
|
||||
})
|
||||
.await?;
|
||||
run_git(repo_path, vec!["init"]).await?;
|
||||
run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
|
||||
} else {
|
||||
let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
|
||||
if actual_origin != eval_setup.url {
|
||||
return Err(anyhow!(
|
||||
"remote origin {} does not match expected origin {}",
|
||||
actual_origin,
|
||||
eval_setup.url
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: consider including "-x" to remove ignored files. The downside of this is that it will
|
||||
// also remove build artifacts, and so prevent incremental reuse there.
|
||||
run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
|
||||
run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
|
||||
}
|
||||
|
||||
run_git(
|
||||
repo_path,
|
||||
vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
|
||||
)
|
||||
.await?;
|
||||
run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
|
||||
let exit_status = new_smol_command("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args.clone())
|
||||
.status()
|
||||
.await?;
|
||||
if exit_status.success() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"`git {}` failed with {}",
|
||||
args.join(" "),
|
||||
exit_status,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
|
||||
let output = new_smol_command("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args.clone())
|
||||
.output()
|
||||
.await?;
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"`git {}` failed with {}",
|
||||
args.join(" "),
|
||||
output.status
|
||||
))
|
||||
}
|
||||
}
|
241
crates/assistant_eval/src/headless_assistant.rs
Normal file
241
crates/assistant_eval/src/headless_assistant.rs
Normal file
|
@ -0,0 +1,241 @@
|
|||
use anyhow::anyhow;
|
||||
use assistant2::{Thread, ThreadEvent, ThreadStore};
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashMap;
|
||||
use futures::StreamExt;
|
||||
use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
LanguageModelRequest,
|
||||
};
|
||||
use node_runtime::NodeRuntime;
|
||||
use project::{Project, RealFs};
|
||||
use prompt_store::PromptBuilder;
|
||||
use settings::SettingsStore;
|
||||
use smol::channel;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
|
||||
pub struct HeadlessAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
pub client: Arc<Client>,
|
||||
pub user_store: Entity<UserStore>,
|
||||
pub fs: Arc<dyn fs::Fs>,
|
||||
pub node_runtime: NodeRuntime,
|
||||
|
||||
// Additional fields not present in `workspace::AppState`.
|
||||
pub prompt_builder: Arc<PromptBuilder>,
|
||||
}
|
||||
|
||||
pub struct HeadlessAssistant {
|
||||
pub thread: Entity<Thread>,
|
||||
pub project: Entity<Project>,
|
||||
#[allow(dead_code)]
|
||||
pub thread_store: Entity<ThreadStore>,
|
||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||
pub done_tx: channel::Sender<anyhow::Result<()>>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl HeadlessAssistant {
|
||||
pub fn new(
|
||||
app_state: Arc<HeadlessAppState>,
|
||||
cx: &mut App,
|
||||
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
|
||||
let env = None;
|
||||
let project = Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
env,
|
||||
cx,
|
||||
);
|
||||
|
||||
let tools = Arc::new(ToolWorkingSet::default());
|
||||
let thread_store =
|
||||
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
|
||||
|
||||
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
|
||||
|
||||
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
|
||||
|
||||
let headless_thread = cx.new(move |cx| Self {
|
||||
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
|
||||
thread,
|
||||
project,
|
||||
thread_store,
|
||||
tool_use_counts: HashMap::default(),
|
||||
done_tx,
|
||||
});
|
||||
|
||||
Ok((headless_thread, done_rx))
|
||||
}
|
||||
|
||||
fn handle_thread_event(
|
||||
&mut self,
|
||||
thread: Entity<Thread>,
|
||||
event: &ThreadEvent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
ThreadEvent::ShowError(err) => self
|
||||
.done_tx
|
||||
.send_blocking(Err(anyhow!("{:?}", err)))
|
||||
.unwrap(),
|
||||
ThreadEvent::DoneStreaming => {
|
||||
let thread = thread.read(cx);
|
||||
if let Some(message) = thread.messages().last() {
|
||||
println!("Message: {}", message.text,);
|
||||
}
|
||||
if thread.all_tools_finished() {
|
||||
self.done_tx.send_blocking(Ok(())).unwrap()
|
||||
}
|
||||
}
|
||||
ThreadEvent::UsePendingTools => {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.use_pending_tools(cx);
|
||||
});
|
||||
}
|
||||
ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
pending_tool_use,
|
||||
} => {
|
||||
if let Some(pending_tool_use) = pending_tool_use {
|
||||
println!(
|
||||
"Used tool {} with input: {}",
|
||||
pending_tool_use.name, pending_tool_use.input
|
||||
);
|
||||
*self
|
||||
.tool_use_counts
|
||||
.entry(pending_tool_use.name.clone())
|
||||
.or_insert(0) += 1;
|
||||
}
|
||||
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
|
||||
println!("Tool result: {:?}", tool_result);
|
||||
}
|
||||
if thread.read(cx).all_tools_finished() {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(model) = model_registry.active_model() {
|
||||
thread.update(cx, |thread, cx| {
|
||||
// Currently evals do not support specifying context.
|
||||
let updated_context = vec![];
|
||||
thread.send_tool_results_to_model(model, updated_context, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
ThreadEvent::StreamedCompletion
|
||||
| ThreadEvent::SummaryChanged
|
||||
| ThreadEvent::StreamedAssistantText(_, _)
|
||||
| ThreadEvent::MessageAdded(_)
|
||||
| ThreadEvent::MessageEdited(_)
|
||||
| ThreadEvent::MessageDeleted(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
gpui_tokio::init(cx);
|
||||
|
||||
let mut settings_store = SettingsStore::new(cx);
|
||||
settings_store
|
||||
.set_default_settings(settings::default_settings().as_ref(), cx)
|
||||
.unwrap();
|
||||
cx.set_global(settings_store);
|
||||
client::init_settings(cx);
|
||||
Project::init_settings(cx);
|
||||
|
||||
let client = Client::production(cx);
|
||||
cx.set_http_client(client.http_client().clone());
|
||||
|
||||
let git_binary_path = None;
|
||||
let fs = Arc::new(RealFs::new(git_binary_path));
|
||||
|
||||
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
|
||||
language::init(cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
assistant_tools::init(cx);
|
||||
context_server::init(cx);
|
||||
let stdout_is_a_pty = false;
|
||||
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
|
||||
assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
|
||||
|
||||
Arc::new(HeadlessAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
fs,
|
||||
node_runtime: NodeRuntime::unavailable(),
|
||||
prompt_builder,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model = model_registry
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == model_name);
|
||||
|
||||
let Some(model) = model else {
|
||||
return Err(anyhow!(
|
||||
"No language model named {} was available. Available models: {}",
|
||||
model_name,
|
||||
model_registry
|
||||
.available_models(cx)
|
||||
.map(|model| model.id().0.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
));
|
||||
};
|
||||
|
||||
Ok(model)
|
||||
}
|
||||
|
||||
pub fn authenticate_model_provider(
|
||||
provider_id: LanguageModelProviderId,
|
||||
cx: &mut App,
|
||||
) -> Task<std::result::Result<(), AuthenticateError>> {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let model_provider = model_registry.provider(&provider_id).unwrap();
|
||||
model_provider.authenticate(cx)
|
||||
}
|
||||
|
||||
pub async fn send_language_model_request(
|
||||
model: Arc<dyn LanguageModel>,
|
||||
request: LanguageModelRequest,
|
||||
cx: AsyncApp,
|
||||
) -> anyhow::Result<String> {
|
||||
match model.stream_completion_text(request, &cx).await {
|
||||
Ok(mut stream) => {
|
||||
let mut full_response = String::new();
|
||||
|
||||
// Process the response stream
|
||||
while let Some(chunk_result) = stream.stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk_str) => {
|
||||
full_response.push_str(&chunk_str);
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(anyhow!(
|
||||
"Error receiving response from language model: {err}"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(full_response)
|
||||
}
|
||||
Err(err) => Err(anyhow!(
|
||||
"Failed to get response from language model. Error was: {err}"
|
||||
)),
|
||||
}
|
||||
}
|
121
crates/assistant_eval/src/judge.rs
Normal file
121
crates/assistant_eval/src/judge.rs
Normal file
|
@ -0,0 +1,121 @@
|
|||
use crate::eval::EvalOutput;
|
||||
use crate::headless_assistant::send_language_model_request;
|
||||
use anyhow::anyhow;
|
||||
use gpui::{App, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
|
||||
};
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
pub struct Judge {
|
||||
pub original_diff: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
pub original_message: Option<String>,
|
||||
pub model: Arc<dyn LanguageModel>,
|
||||
}
|
||||
|
||||
impl Judge {
|
||||
pub async fn load(eval_path: &Path, model: Arc<dyn LanguageModel>) -> anyhow::Result<Judge> {
|
||||
let original_diff_path = eval_path.join("original.diff");
|
||||
let original_diff = smol::unblock(move || {
|
||||
if std::fs::exists(&original_diff_path)? {
|
||||
anyhow::Ok(Some(std::fs::read_to_string(&original_diff_path)?))
|
||||
} else {
|
||||
anyhow::Ok(None)
|
||||
}
|
||||
});
|
||||
|
||||
let original_message_path = eval_path.join("original_message.txt");
|
||||
let original_message = smol::unblock(move || {
|
||||
if std::fs::exists(&original_message_path)? {
|
||||
anyhow::Ok(Some(std::fs::read_to_string(&original_message_path)?))
|
||||
} else {
|
||||
anyhow::Ok(None)
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
original_diff: original_diff.await?,
|
||||
original_message: original_message.await?,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run(&self, eval_output: &EvalOutput, cx: &mut App) -> Task<anyhow::Result<String>> {
|
||||
let Some(original_diff) = self.original_diff.as_ref() else {
|
||||
return Task::ready(Err(anyhow!("No original.diff found")));
|
||||
};
|
||||
|
||||
// TODO: check for empty diff?
|
||||
let prompt = diff_comparison_prompt(&original_diff, &eval_output.diff);
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(prompt)],
|
||||
cache: false,
|
||||
}],
|
||||
temperature: Some(0.0),
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
};
|
||||
|
||||
let model = self.model.clone();
|
||||
cx.spawn(move |cx| send_language_model_request(model, request, cx))
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds the prompt that asks a model to score how functionally similar two diffs are.
///
/// `original_diff` is the reference diff and `new_diff` is the diff under evaluation; each
/// is embedded verbatim inside its own fenced `git` block. The prompt instructs the model
/// to answer with nothing but a score between 0.0 and 1.0.
pub fn diff_comparison_prompt(original_diff: &str, new_diff: &str) -> String {
    format!(
        r#"# Git Diff Similarity Evaluation Template

## Instructions

Compare the two diffs and score them between 0.0 and 1.0 based on their functional similarity.
- 1.0 = Perfect functional match (achieves identical results)
- 0.0 = No functional similarity whatsoever

## Evaluation Criteria

Please consider the following aspects in order of importance:

1. **Functional Equivalence (60%)**
- Do both diffs achieve the same end result?
- Are the changes functionally equivalent despite possibly using different approaches?
- Do the modifications address the same issues or implement the same features?

2. **Logical Structure (20%)**
- Are the logical flows similar?
- Do the modifications affect the same code paths?
- Are control structures (if/else, loops, etc.) modified in similar ways?

3. **Code Content (15%)**
- Are similar lines added/removed?
- Are the same variables, functions, or methods being modified?
- Are the same APIs or libraries being used?

4. **File Layout (5%)**
- Are the same files being modified?
- Are changes occurring in similar locations within files?

## Input

Original Diff:
```git
{original_diff}
```

New Diff:
```git
{new_diff}
```

## Output Format

THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.

Example output:
0.85"#
    )
}
|
234
crates/assistant_eval/src/main.rs
Normal file
234
crates/assistant_eval/src/main.rs
Normal file
|
@ -0,0 +1,234 @@
|
|||
mod eval;
|
||||
mod headless_assistant;
|
||||
mod judge;
|
||||
|
||||
use clap::Parser;
|
||||
use eval::{Eval, EvalOutput};
|
||||
use futures::{stream, StreamExt};
|
||||
use gpui::{Application, AsyncApp};
|
||||
use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
|
||||
use itertools::Itertools;
|
||||
use judge::Judge;
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use regex::Regex;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use std::{cmp, path::PathBuf, sync::Arc};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(
|
||||
name = "assistant_eval",
|
||||
disable_version_flag = true,
|
||||
before_help = "Tool eval runner"
|
||||
)]
|
||||
struct Args {
|
||||
/// Regexes to match the names of evals to run.
|
||||
eval_name_regexes: Vec<String>,
|
||||
/// Runs all evals in `evaluation_data`, causes the regex to be ignored.
|
||||
#[arg(long)]
|
||||
all: bool,
|
||||
/// Name of the model (default: "claude-3-7-sonnet-latest")
|
||||
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
||||
model_name: String,
|
||||
/// Name of the editor model (default: value of `--model_name`).
|
||||
#[arg(long)]
|
||||
editor_model_name: Option<String>,
|
||||
/// Name of the judge model (default: value of `--model_name`).
|
||||
#[arg(long)]
|
||||
judge_model_name: Option<String>,
|
||||
/// Number of evaluations to run concurrently (default: 10)
|
||||
#[arg(short, long, default_value = "10")]
|
||||
concurrency: usize,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
let args = Args::parse();
|
||||
let http_client = Arc::new(ReqwestClient::new());
|
||||
let app = Application::headless().with_http_client(http_client.clone());
|
||||
|
||||
let crate_dir = PathBuf::from("../zed-agent-bench");
|
||||
let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
|
||||
let repos_dir = crate_dir.join("repos").canonicalize().unwrap();
|
||||
|
||||
let all_evals = std::fs::read_dir(&evaluation_data_dir)
|
||||
.unwrap()
|
||||
.map(|path| path.unwrap().file_name().to_string_lossy().to_string())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let evals_to_run = if args.all {
|
||||
all_evals
|
||||
} else {
|
||||
args.eval_name_regexes
|
||||
.into_iter()
|
||||
.map(|regex_string| Regex::new(®ex_string).unwrap())
|
||||
.flat_map(|regex| {
|
||||
all_evals
|
||||
.iter()
|
||||
.filter(|eval_name| regex.is_match(eval_name))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
if evals_to_run.is_empty() {
|
||||
panic!("Names of evals to run must be provided or `--all` specified");
|
||||
}
|
||||
|
||||
println!("Will run the following evals: {evals_to_run:?}");
|
||||
println!("Running up to {} evals concurrently", args.concurrency);
|
||||
|
||||
let editor_model_name = if let Some(model_name) = args.editor_model_name {
|
||||
model_name
|
||||
} else {
|
||||
args.model_name.clone()
|
||||
};
|
||||
|
||||
let judge_model_name = if let Some(model_name) = args.judge_model_name {
|
||||
model_name
|
||||
} else {
|
||||
args.model_name.clone()
|
||||
};
|
||||
|
||||
app.run(move |cx| {
|
||||
let app_state = headless_assistant::init(cx);
|
||||
|
||||
let model = find_model(&args.model_name, cx).unwrap();
|
||||
let editor_model = find_model(&editor_model_name, cx).unwrap();
|
||||
let judge_model = find_model(&judge_model_name, cx).unwrap();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_active_model(Some(model.clone()), cx);
|
||||
registry.set_editor_model(Some(editor_model.clone()), cx);
|
||||
});
|
||||
|
||||
let model_provider_id = model.provider_id();
|
||||
let editor_model_provider_id = editor_model.provider_id();
|
||||
let judge_model_provider_id = judge_model.provider_id();
|
||||
|
||||
cx.spawn(move |cx| async move {
|
||||
// Authenticate all model providers first
|
||||
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let loaded_evals = stream::iter(evals_to_run)
|
||||
.map(|eval_name| {
|
||||
let eval_path = evaluation_data_dir.join(&eval_name);
|
||||
let repos_dir = repos_dir.clone();
|
||||
async move {
|
||||
match Eval::load(eval_name.clone(), eval_path, &repos_dir).await {
|
||||
Ok(eval) => Some(eval),
|
||||
Err(err) => {
|
||||
// TODO: Persist errors / surface errors at the end.
|
||||
println!("Error loading {eval_name}: {err}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.buffer_unordered(args.concurrency)
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// The evals need to be loaded and grouped by URL before concurrently running, since
|
||||
// evals that use the same remote URL will use the same working directory.
|
||||
let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
|
||||
.into_iter()
|
||||
.map(|eval| (eval.eval_setup.url.clone(), eval))
|
||||
.into_group_map()
|
||||
.into_values()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Sort groups in descending order, so that bigger groups start first.
|
||||
evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
|
||||
|
||||
let results = stream::iter(evals_grouped_by_url)
|
||||
.map(|evals| {
|
||||
let model = model.clone();
|
||||
let judge_model = judge_model.clone();
|
||||
let app_state = app_state.clone();
|
||||
let cx = cx.clone();
|
||||
|
||||
async move {
|
||||
let mut results = Vec::new();
|
||||
for eval in evals {
|
||||
let name = eval.name.clone();
|
||||
println!("Starting eval named {}", name);
|
||||
let result = run_eval(
|
||||
eval,
|
||||
model.clone(),
|
||||
judge_model.clone(),
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await;
|
||||
results.push((name, result));
|
||||
}
|
||||
results
|
||||
}
|
||||
})
|
||||
.buffer_unordered(args.concurrency)
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Process results in order of completion
|
||||
for (eval_name, result) in results {
|
||||
match result {
|
||||
Ok((eval_output, judge_output)) => {
|
||||
println!("Generated diff for {eval_name}:\n");
|
||||
println!("{}\n", eval_output.diff);
|
||||
println!("Last message for {eval_name}:\n");
|
||||
println!("{}\n", eval_output.last_message);
|
||||
println!("Elapsed time: {:?}", eval_output.elapsed_time);
|
||||
println!(
|
||||
"Assistant response count: {}",
|
||||
eval_output.assistant_response_count
|
||||
);
|
||||
println!("Tool use counts: {:?}", eval_output.tool_use_counts);
|
||||
println!("Judge output for {eval_name}: {judge_output}");
|
||||
}
|
||||
Err(err) => {
|
||||
// TODO: Persist errors / surface errors at the end.
|
||||
println!("Error running {eval_name}: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cx.update(|cx| cx.quit()).unwrap();
|
||||
})
|
||||
.detach();
|
||||
});
|
||||
|
||||
println!("Done running evals");
|
||||
}
|
||||
|
||||
async fn run_eval(
|
||||
eval: Eval,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
judge_model: Arc<dyn LanguageModel>,
|
||||
app_state: Arc<HeadlessAppState>,
|
||||
cx: AsyncApp,
|
||||
) -> anyhow::Result<(EvalOutput, String)> {
|
||||
let path = eval.path.clone();
|
||||
let judge = Judge::load(&path, judge_model).await?;
|
||||
let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
|
||||
let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
|
||||
eval_output.save_to_directory(&path, judge_output.to_string())?;
|
||||
Ok((eval_output, judge_output))
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue