Add initial implementation of evaluating changes generated by the assistant (#26799)
Release Notes: - N/A --------- Co-authored-by: Richard Feldman <oss@rtfeldman.com> Co-authored-by: Thomas <thomas@zed.dev>
This commit is contained in:
parent
e9b4fa1465
commit
7a888de9f5
14 changed files with 1113 additions and 24 deletions
252
crates/assistant_eval/src/eval.rs
Normal file
252
crates/assistant_eval/src/eval.rs
Normal file
|
@ -0,0 +1,252 @@
|
|||
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
|
||||
use anyhow::anyhow;
|
||||
use assistant2::RequestKind;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Task};
|
||||
use language_model::{LanguageModel, TokenUsage};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fs,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use util::command::new_smol_command;
|
||||
|
||||
pub struct Eval {
|
||||
pub name: String,
|
||||
pub path: PathBuf,
|
||||
pub repo_path: PathBuf,
|
||||
pub eval_setup: EvalSetup,
|
||||
pub user_prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct EvalOutput {
|
||||
pub diff: String,
|
||||
pub last_message: String,
|
||||
pub elapsed_time: Duration,
|
||||
pub assistant_response_count: usize,
|
||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||
pub token_usage: TokenUsage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct EvalSetup {
|
||||
pub url: String,
|
||||
pub base_sha: String,
|
||||
}
|
||||
|
||||
impl Eval {
|
||||
/// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
|
||||
/// if necessary.
|
||||
pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
|
||||
let prompt_path = path.join("prompt.txt");
|
||||
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
|
||||
let setup_path = path.join("setup.json");
|
||||
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
|
||||
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
|
||||
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
|
||||
Ok(Eval {
|
||||
name,
|
||||
path,
|
||||
repo_path,
|
||||
eval_setup,
|
||||
user_prompt,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
self,
|
||||
app_state: Arc<HeadlessAppState>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut App,
|
||||
) -> Task<anyhow::Result<EvalOutput>> {
|
||||
cx.spawn(move |mut cx| async move {
|
||||
checkout_repo(&self.eval_setup, &self.repo_path).await?;
|
||||
|
||||
let (assistant, done_rx) =
|
||||
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
|
||||
|
||||
let _worktree = assistant
|
||||
.update(&mut cx, |assistant, cx| {
|
||||
assistant.project.update(cx, |project, cx| {
|
||||
project.create_worktree(&self.repo_path, true, cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let start_time = std::time::SystemTime::now();
|
||||
|
||||
assistant.update(&mut cx, |assistant, cx| {
|
||||
assistant.thread.update(cx, |thread, cx| {
|
||||
let context = vec![];
|
||||
thread.insert_user_message(self.user_prompt.clone(), context, cx);
|
||||
thread.send_to_model(model, RequestKind::Chat, cx);
|
||||
});
|
||||
})?;
|
||||
|
||||
done_rx.recv().await??;
|
||||
|
||||
let elapsed_time = start_time.elapsed()?;
|
||||
|
||||
let diff = query_git(&self.repo_path, vec!["diff"]).await?;
|
||||
|
||||
assistant.update(&mut cx, |assistant, cx| {
|
||||
let thread = assistant.thread.read(cx);
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
if last_message.role != language_model::Role::Assistant {
|
||||
return Err(anyhow!("Last message is not from assistant"));
|
||||
}
|
||||
let assistant_response_count = thread
|
||||
.messages()
|
||||
.filter(|message| message.role == language_model::Role::Assistant)
|
||||
.count();
|
||||
Ok(EvalOutput {
|
||||
diff,
|
||||
last_message: last_message.text.clone(),
|
||||
elapsed_time,
|
||||
assistant_response_count,
|
||||
tool_use_counts: assistant.tool_use_counts.clone(),
|
||||
token_usage: thread.cumulative_token_usage(),
|
||||
})
|
||||
})?
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EvalOutput {
|
||||
// Method to save the output to a directory
|
||||
pub fn save_to_directory(
|
||||
&self,
|
||||
output_dir: &Path,
|
||||
eval_output_value: String,
|
||||
) -> anyhow::Result<()> {
|
||||
// Create the output directory if it doesn't exist
|
||||
fs::create_dir_all(&output_dir)?;
|
||||
|
||||
// Save the diff to a file
|
||||
let diff_path = output_dir.join("diff.patch");
|
||||
let mut diff_file = fs::File::create(&diff_path)?;
|
||||
diff_file.write_all(self.diff.as_bytes())?;
|
||||
|
||||
// Save the last message to a file
|
||||
let message_path = output_dir.join("assistant_response.txt");
|
||||
let mut message_file = fs::File::create(&message_path)?;
|
||||
message_file.write_all(self.last_message.as_bytes())?;
|
||||
|
||||
// Current metrics for this run
|
||||
let current_metrics = serde_json::json!({
|
||||
"elapsed_time_ms": self.elapsed_time.as_millis(),
|
||||
"assistant_response_count": self.assistant_response_count,
|
||||
"tool_use_counts": self.tool_use_counts,
|
||||
"token_usage": self.token_usage,
|
||||
"eval_output_value": eval_output_value,
|
||||
});
|
||||
|
||||
// Get current timestamp in milliseconds
|
||||
let timestamp = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)?
|
||||
.as_millis()
|
||||
.to_string();
|
||||
|
||||
// Path to metrics file
|
||||
let metrics_path = output_dir.join("metrics.json");
|
||||
|
||||
// Load existing metrics if the file exists, or create a new object
|
||||
let mut historical_metrics = if metrics_path.exists() {
|
||||
let metrics_content = fs::read_to_string(&metrics_path)?;
|
||||
serde_json::from_str::<serde_json::Value>(&metrics_content)
|
||||
.unwrap_or_else(|_| serde_json::json!({}))
|
||||
} else {
|
||||
serde_json::json!({})
|
||||
};
|
||||
|
||||
// Add new run with timestamp as key
|
||||
if let serde_json::Value::Object(ref mut map) = historical_metrics {
|
||||
map.insert(timestamp, current_metrics);
|
||||
}
|
||||
|
||||
// Write updated metrics back to file
|
||||
let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
|
||||
let mut metrics_file = fs::File::create(&metrics_path)?;
|
||||
metrics_file.write_all(metrics_json.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn repo_dir_name(url: &str) -> String {
|
||||
url.trim_start_matches("https://")
|
||||
.replace(|c: char| !c.is_alphanumeric(), "_")
|
||||
}
|
||||
|
||||
async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
|
||||
if !repo_path.exists() {
|
||||
smol::unblock({
|
||||
let repo_path = repo_path.to_path_buf();
|
||||
|| std::fs::create_dir_all(repo_path)
|
||||
})
|
||||
.await?;
|
||||
run_git(repo_path, vec!["init"]).await?;
|
||||
run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
|
||||
} else {
|
||||
let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
|
||||
if actual_origin != eval_setup.url {
|
||||
return Err(anyhow!(
|
||||
"remote origin {} does not match expected origin {}",
|
||||
actual_origin,
|
||||
eval_setup.url
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: consider including "-x" to remove ignored files. The downside of this is that it will
|
||||
// also remove build artifacts, and so prevent incremental reuse there.
|
||||
run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
|
||||
run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
|
||||
}
|
||||
|
||||
run_git(
|
||||
repo_path,
|
||||
vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
|
||||
)
|
||||
.await?;
|
||||
run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
|
||||
let exit_status = new_smol_command("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args.clone())
|
||||
.status()
|
||||
.await?;
|
||||
if exit_status.success() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"`git {}` failed with {}",
|
||||
args.join(" "),
|
||||
exit_status,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
|
||||
let output = new_smol_command("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args.clone())
|
||||
.output()
|
||||
.await?;
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"`git {}` failed with {}",
|
||||
args.join(" "),
|
||||
output.status
|
||||
))
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue