Systematically optimize agentic editing performance (#28961)
Now that we've established a proper eval in tree, this PR is reboots of our agent loop back to a set of minimal tools and simpler prompts. We should aim to get this branch feeling subjectively competitive with what's on main and then merge it, and build from there. Let's invest in our eval and use it to drive better performance of the agent loop. How you can help: Pick an example, and then make the outcome faster or better. It's fine to even use your own subjective judgment, as our evaluation criteria likely need tuning as well at this point. Focus on making the agent work better in your own subjective experience first. Let's focus on simple/practical improvements to make this thing work better, then determine how we can craft our judgment criteria to lock those improvements in. Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev> Co-authored-by: Antonio <antonio@zed.dev> Co-authored-by: Agus <agus@zed.dev> Co-authored-by: Richard <richard@zed.dev> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com> Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Michael Sloan <mgsloan@gmail.com>
This commit is contained in:
parent
8102a16747
commit
bab28560ef
68 changed files with 1575 additions and 478 deletions
|
@ -9,8 +9,7 @@ use ::fs::RealFs;
|
|||
use anyhow::{Result, anyhow};
|
||||
use clap::Parser;
|
||||
use extension::ExtensionHostProxy;
|
||||
use futures::future;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::http_client::{Uri, read_proxy_from_env};
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
|
||||
use gpui_tokio::Tokio;
|
||||
|
@ -183,7 +182,7 @@ fn main() {
|
|||
println!(
|
||||
"{}Logging to: {}",
|
||||
example.log_prefix,
|
||||
example.output_file_path.display()
|
||||
example.example_output_directory().display()
|
||||
);
|
||||
|
||||
let repo_url = example.base.url.clone();
|
||||
|
@ -192,7 +191,7 @@ fn main() {
|
|||
|
||||
if !repo_path.join(".git").is_dir() {
|
||||
println!(
|
||||
"{:<width$} < {}",
|
||||
"{:<width$} < {}",
|
||||
"↓ Cloning",
|
||||
repo_url,
|
||||
width = max_name_width
|
||||
|
@ -235,22 +234,20 @@ fn main() {
|
|||
let judge_repetitions = args.judge_repetitions;
|
||||
let concurrency = args.concurrency;
|
||||
|
||||
let tasks = examples
|
||||
.into_iter()
|
||||
.map(|example| {
|
||||
let app_state = app_state.clone();
|
||||
let model = model.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let result =
|
||||
run_example(&example, model, app_state, judge_repetitions, cx).await;
|
||||
(result, example)
|
||||
})
|
||||
let tasks = examples.iter().map(|example| {
|
||||
let app_state = app_state.clone();
|
||||
let model = model.clone();
|
||||
let example = example.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let result =
|
||||
run_example(&example, model, app_state, judge_repetitions, cx).await;
|
||||
(result, example)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
});
|
||||
|
||||
let results = futures::stream::iter(tasks)
|
||||
.buffer_unordered(concurrency)
|
||||
.collect::<Vec<(Result<Vec<Result<JudgeOutput>>>, Example)>>()
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
println!("\n\n");
|
||||
|
@ -259,26 +256,41 @@ fn main() {
|
|||
println!("========================================");
|
||||
println!("");
|
||||
|
||||
let mut judge_scores = Vec::new();
|
||||
let mut diff_scores = Vec::new();
|
||||
let mut thread_scores = Vec::new();
|
||||
let mut error_count = 0;
|
||||
|
||||
for (result, example) in results {
|
||||
match result {
|
||||
Err(err) => {
|
||||
println!("💥 {}{:?}", example.log_prefix, err);
|
||||
error_count += 1;
|
||||
}
|
||||
Ok(judge_results) => {
|
||||
for judge_result in judge_results {
|
||||
match judge_result {
|
||||
Ok(judge_output) => {
|
||||
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
|
||||
let score: u32 = judge_output.score;
|
||||
let score_index = (score.min(5)) as usize;
|
||||
let diff_score: u32 = judge_output.diff.score;
|
||||
let score_index = (diff_score.min(5)) as usize;
|
||||
|
||||
println!(
|
||||
"{} {}{}",
|
||||
SCORES[score_index], example.log_prefix, judge_output.score,
|
||||
"{} {}{} (Diff)",
|
||||
SCORES[score_index],
|
||||
example.log_prefix,
|
||||
judge_output.diff.score,
|
||||
);
|
||||
judge_scores.push(judge_output.score);
|
||||
diff_scores.push(judge_output.diff.score);
|
||||
|
||||
if let Some(thread) = judge_output.thread {
|
||||
let process_score: u32 = thread.score;
|
||||
let score_index = (process_score.min(5)) as usize;
|
||||
println!(
|
||||
"{} {}{} (Thread)",
|
||||
SCORES[score_index], example.log_prefix, thread.score,
|
||||
);
|
||||
thread_scores.push(thread.score);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
println!("💥 {}{:?}", example.log_prefix, err);
|
||||
|
@ -290,17 +302,39 @@ fn main() {
|
|||
println!(
|
||||
"{} > {}",
|
||||
" ".repeat(max_name_width),
|
||||
example.output_file_path.display()
|
||||
example.example_output_directory().display()
|
||||
);
|
||||
}
|
||||
|
||||
let score_count = judge_scores.len();
|
||||
let average_score = judge_scores
|
||||
let diff_score_count = diff_scores.len();
|
||||
let average_diff_score = diff_scores
|
||||
.into_iter()
|
||||
.map(|score| score as f32)
|
||||
.sum::<f32>()
|
||||
/ (score_count as f32);
|
||||
println!("\nAverage score: {average_score}");
|
||||
/ (diff_score_count as f32);
|
||||
|
||||
if error_count > 0 {
|
||||
println!("\n{error_count} examples failed to run!");
|
||||
}
|
||||
|
||||
if diff_score_count > 0 {
|
||||
println!("\nAverage code diff score: {average_diff_score}");
|
||||
}
|
||||
|
||||
let thread_score_count = thread_scores.len();
|
||||
|
||||
// We might have gotten no thread scores if we weren't asked to judge the thread.
|
||||
if thread_score_count > 0 {
|
||||
let average_thread_score = thread_scores
|
||||
.into_iter()
|
||||
.map(|score| score as f32)
|
||||
.sum::<f32>()
|
||||
/ (thread_score_count as f32);
|
||||
|
||||
if diff_score_count > 0 {
|
||||
println!("\nAverage thread score: {average_thread_score}");
|
||||
}
|
||||
}
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_secs(2));
|
||||
|
||||
|
@ -322,45 +356,11 @@ async fn run_example(
|
|||
let run_output = cx
|
||||
.update(|cx| example.run(model.clone(), app_state.clone(), cx))?
|
||||
.await?;
|
||||
let diff = example.repository_diff().await?;
|
||||
|
||||
// Run judge for each repetition
|
||||
let mut results = Vec::new();
|
||||
for round in 0..judge_repetitions {
|
||||
let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
|
||||
let judge_tasks = (0..judge_repetitions)
|
||||
.map(|round| run_judge_repetition(example.clone(), model.clone(), &run_output, round, cx));
|
||||
|
||||
if let Ok(judge_output) = &judge_result {
|
||||
let cohort_id = example
|
||||
.output_file_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|name| name.to_string_lossy().to_string())
|
||||
.unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
|
||||
|
||||
let path = std::path::Path::new(".");
|
||||
let commit_id = get_current_commit_id(path).await.unwrap_or_default();
|
||||
|
||||
telemetry::event!(
|
||||
"Agent Eval Completed",
|
||||
cohort_id = cohort_id,
|
||||
example_name = example.name.clone(),
|
||||
round = round,
|
||||
score = judge_output.score,
|
||||
analysis = judge_output.analysis,
|
||||
tool_use_counts = run_output.tool_use_counts,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.base.url.clone(),
|
||||
repository_revision = example.base.revision.clone(),
|
||||
diagnostics_summary = run_output.diagnostics,
|
||||
commit_id = commit_id
|
||||
);
|
||||
}
|
||||
|
||||
results.push(judge_result);
|
||||
}
|
||||
let results = future::join_all(judge_tasks).await;
|
||||
|
||||
app_state.client.telemetry().flush_events();
|
||||
|
||||
|
@ -537,3 +537,68 @@ pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
|
|||
get_current_commit_id(repo_path).await.unwrap_or_default()
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_judge_repetition(
|
||||
example: Example,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
run_output: &RunOutput,
|
||||
round: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
|
||||
|
||||
if let Ok(judge_output) = &judge_result {
|
||||
let cohort_id = example
|
||||
.run_directory_path
|
||||
.file_name()
|
||||
.map(|name| name.to_string_lossy().to_string())
|
||||
.unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
|
||||
|
||||
let path = std::path::Path::new(".");
|
||||
let commit_id = get_current_commit_id(path).await.unwrap_or_default();
|
||||
|
||||
if let Some(thread) = &judge_output.thread {
|
||||
telemetry::event!(
|
||||
"Agent Eval Completed",
|
||||
cohort_id = cohort_id,
|
||||
example_name = example.name.clone(),
|
||||
round = round,
|
||||
diff_score = judge_output.diff.score,
|
||||
diff_analysis = judge_output.diff.analysis,
|
||||
thread_score = thread.score,
|
||||
thread_analysis = thread.analysis,
|
||||
tool_use_counts = run_output.tool_use_counts,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.base.url.clone(),
|
||||
repository_revision = example.base.revision.clone(),
|
||||
diagnostics_before = run_output.diagnostics_before,
|
||||
diagnostics_after = run_output.diagnostics_after,
|
||||
commit_id = commit_id
|
||||
);
|
||||
} else {
|
||||
telemetry::event!(
|
||||
"Agent Eval Completed",
|
||||
cohort_id = cohort_id,
|
||||
example_name = example.name.clone(),
|
||||
round = round,
|
||||
diff_score = judge_output.diff.score,
|
||||
diff_analysis = judge_output.diff.analysis,
|
||||
tool_use_counts = run_output.tool_use_counts,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.base.url.clone(),
|
||||
repository_revision = example.base.revision.clone(),
|
||||
diagnostics_before = run_output.diagnostics_before,
|
||||
diagnostics_after = run_output.diagnostics_after,
|
||||
commit_id = commit_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
judge_result
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue