Systematically optimize agentic editing performance (#28961)
Now that we've established a proper eval in tree, this PR is reboots of our agent loop back to a set of minimal tools and simpler prompts. We should aim to get this branch feeling subjectively competitive with what's on main and then merge it, and build from there. Let's invest in our eval and use it to drive better performance of the agent loop. How you can help: Pick an example, and then make the outcome faster or better. It's fine to even use your own subjective judgment, as our evaluation criteria likely need tuning as well at this point. Focus on making the agent work better in your own subjective experience first. Let's focus on simple/practical improvements to make this thing work better, then determine how we can craft our judgment criteria to lock those improvements in. Release Notes: - N/A --------- Co-authored-by: Max <max@zed.dev> Co-authored-by: Antonio <antonio@zed.dev> Co-authored-by: Agus <agus@zed.dev> Co-authored-by: Richard <richard@zed.dev> Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com> Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Michael Sloan <mgsloan@gmail.com>
This commit is contained in:
parent
8102a16747
commit
bab28560ef
68 changed files with 1575 additions and 478 deletions
|
@ -28,7 +28,7 @@ language.workspace = true
|
|||
language_extension.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
languages.workspace = true
|
||||
languages = { workspace = true, features = ["load-grammars"] }
|
||||
node_runtime.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
|
@ -36,6 +36,7 @@ prompt_store.workspace = true
|
|||
release_channel.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
shellexpand.workspace = true
|
||||
telemetry.workspace = true
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
url = "https://github.com/dani-garcia/vaultwarden.git"
|
||||
revision = "3a1f1bae002bebf26ce3a38b879c1ba26529af1e"
|
||||
language_extension = "rs"
|
||||
allow_preexisting_diagnostics = true
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation.
|
||||
Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should implement the `Render` trait.
|
||||
|
||||
The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line.
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
1. The first tool call should be to path search including "find_replace_file_tool.rs" in the string. (*Not* regex_search, for example, or reading the file based on a guess at the path.) This is because we gave the model a filename and it needs to turn that into a real path.
|
||||
2. After obtaining the correct path of "zed/crates/assistant_tools/src/find_replace_file_tool.rs", it should read the contents of that path.
|
||||
3. When trying to find information about the Render trait, it should *not* begin with a path search, because it doesn't yet have any information on what path the Render trait might be in.
|
|
@ -1,3 +1,4 @@
|
|||
url = "https://github.com/huggingface/candle.git"
|
||||
revision = "3164a19a5dc18f5e0f7a063ae85a0cfd289e98f1"
|
||||
language_extension = "rs"
|
||||
allow_preexisting_diagnostics = true
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
url = "https://github.com/firecracker-microvm/firecracker.git"
|
||||
revision = "5eaa6e08e350cd38c8102848913a096312e59097"
|
||||
language_extension = "rs"
|
||||
allow_preexisting_diagnostics = true
|
||||
|
|
|
@ -9,8 +9,7 @@ use ::fs::RealFs;
|
|||
use anyhow::{Result, anyhow};
|
||||
use clap::Parser;
|
||||
use extension::ExtensionHostProxy;
|
||||
use futures::future;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::http_client::{Uri, read_proxy_from_env};
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
|
||||
use gpui_tokio::Tokio;
|
||||
|
@ -183,7 +182,7 @@ fn main() {
|
|||
println!(
|
||||
"{}Logging to: {}",
|
||||
example.log_prefix,
|
||||
example.output_file_path.display()
|
||||
example.example_output_directory().display()
|
||||
);
|
||||
|
||||
let repo_url = example.base.url.clone();
|
||||
|
@ -192,7 +191,7 @@ fn main() {
|
|||
|
||||
if !repo_path.join(".git").is_dir() {
|
||||
println!(
|
||||
"{:<width$} < {}",
|
||||
"{:<width$} < {}",
|
||||
"↓ Cloning",
|
||||
repo_url,
|
||||
width = max_name_width
|
||||
|
@ -235,22 +234,20 @@ fn main() {
|
|||
let judge_repetitions = args.judge_repetitions;
|
||||
let concurrency = args.concurrency;
|
||||
|
||||
let tasks = examples
|
||||
.into_iter()
|
||||
.map(|example| {
|
||||
let app_state = app_state.clone();
|
||||
let model = model.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let result =
|
||||
run_example(&example, model, app_state, judge_repetitions, cx).await;
|
||||
(result, example)
|
||||
})
|
||||
let tasks = examples.iter().map(|example| {
|
||||
let app_state = app_state.clone();
|
||||
let model = model.clone();
|
||||
let example = example.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let result =
|
||||
run_example(&example, model, app_state, judge_repetitions, cx).await;
|
||||
(result, example)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
});
|
||||
|
||||
let results = futures::stream::iter(tasks)
|
||||
.buffer_unordered(concurrency)
|
||||
.collect::<Vec<(Result<Vec<Result<JudgeOutput>>>, Example)>>()
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
println!("\n\n");
|
||||
|
@ -259,26 +256,41 @@ fn main() {
|
|||
println!("========================================");
|
||||
println!("");
|
||||
|
||||
let mut judge_scores = Vec::new();
|
||||
let mut diff_scores = Vec::new();
|
||||
let mut thread_scores = Vec::new();
|
||||
let mut error_count = 0;
|
||||
|
||||
for (result, example) in results {
|
||||
match result {
|
||||
Err(err) => {
|
||||
println!("💥 {}{:?}", example.log_prefix, err);
|
||||
error_count += 1;
|
||||
}
|
||||
Ok(judge_results) => {
|
||||
for judge_result in judge_results {
|
||||
match judge_result {
|
||||
Ok(judge_output) => {
|
||||
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
|
||||
let score: u32 = judge_output.score;
|
||||
let score_index = (score.min(5)) as usize;
|
||||
let diff_score: u32 = judge_output.diff.score;
|
||||
let score_index = (diff_score.min(5)) as usize;
|
||||
|
||||
println!(
|
||||
"{} {}{}",
|
||||
SCORES[score_index], example.log_prefix, judge_output.score,
|
||||
"{} {}{} (Diff)",
|
||||
SCORES[score_index],
|
||||
example.log_prefix,
|
||||
judge_output.diff.score,
|
||||
);
|
||||
judge_scores.push(judge_output.score);
|
||||
diff_scores.push(judge_output.diff.score);
|
||||
|
||||
if let Some(thread) = judge_output.thread {
|
||||
let process_score: u32 = thread.score;
|
||||
let score_index = (process_score.min(5)) as usize;
|
||||
println!(
|
||||
"{} {}{} (Thread)",
|
||||
SCORES[score_index], example.log_prefix, thread.score,
|
||||
);
|
||||
thread_scores.push(thread.score);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
println!("💥 {}{:?}", example.log_prefix, err);
|
||||
|
@ -290,17 +302,39 @@ fn main() {
|
|||
println!(
|
||||
"{} > {}",
|
||||
" ".repeat(max_name_width),
|
||||
example.output_file_path.display()
|
||||
example.example_output_directory().display()
|
||||
);
|
||||
}
|
||||
|
||||
let score_count = judge_scores.len();
|
||||
let average_score = judge_scores
|
||||
let diff_score_count = diff_scores.len();
|
||||
let average_diff_score = diff_scores
|
||||
.into_iter()
|
||||
.map(|score| score as f32)
|
||||
.sum::<f32>()
|
||||
/ (score_count as f32);
|
||||
println!("\nAverage score: {average_score}");
|
||||
/ (diff_score_count as f32);
|
||||
|
||||
if error_count > 0 {
|
||||
println!("\n{error_count} examples failed to run!");
|
||||
}
|
||||
|
||||
if diff_score_count > 0 {
|
||||
println!("\nAverage code diff score: {average_diff_score}");
|
||||
}
|
||||
|
||||
let thread_score_count = thread_scores.len();
|
||||
|
||||
// We might have gotten no thread scores if we weren't asked to judge the thread.
|
||||
if thread_score_count > 0 {
|
||||
let average_thread_score = thread_scores
|
||||
.into_iter()
|
||||
.map(|score| score as f32)
|
||||
.sum::<f32>()
|
||||
/ (thread_score_count as f32);
|
||||
|
||||
if diff_score_count > 0 {
|
||||
println!("\nAverage thread score: {average_thread_score}");
|
||||
}
|
||||
}
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_secs(2));
|
||||
|
||||
|
@ -322,45 +356,11 @@ async fn run_example(
|
|||
let run_output = cx
|
||||
.update(|cx| example.run(model.clone(), app_state.clone(), cx))?
|
||||
.await?;
|
||||
let diff = example.repository_diff().await?;
|
||||
|
||||
// Run judge for each repetition
|
||||
let mut results = Vec::new();
|
||||
for round in 0..judge_repetitions {
|
||||
let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
|
||||
let judge_tasks = (0..judge_repetitions)
|
||||
.map(|round| run_judge_repetition(example.clone(), model.clone(), &run_output, round, cx));
|
||||
|
||||
if let Ok(judge_output) = &judge_result {
|
||||
let cohort_id = example
|
||||
.output_file_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|name| name.to_string_lossy().to_string())
|
||||
.unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
|
||||
|
||||
let path = std::path::Path::new(".");
|
||||
let commit_id = get_current_commit_id(path).await.unwrap_or_default();
|
||||
|
||||
telemetry::event!(
|
||||
"Agent Eval Completed",
|
||||
cohort_id = cohort_id,
|
||||
example_name = example.name.clone(),
|
||||
round = round,
|
||||
score = judge_output.score,
|
||||
analysis = judge_output.analysis,
|
||||
tool_use_counts = run_output.tool_use_counts,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.base.url.clone(),
|
||||
repository_revision = example.base.revision.clone(),
|
||||
diagnostics_summary = run_output.diagnostics,
|
||||
commit_id = commit_id
|
||||
);
|
||||
}
|
||||
|
||||
results.push(judge_result);
|
||||
}
|
||||
let results = future::join_all(judge_tasks).await;
|
||||
|
||||
app_state.client.telemetry().flush_events();
|
||||
|
||||
|
@ -537,3 +537,68 @@ pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
|
|||
get_current_commit_id(repo_path).await.unwrap_or_default()
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_judge_repetition(
|
||||
example: Example,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
run_output: &RunOutput,
|
||||
round: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
|
||||
|
||||
if let Ok(judge_output) = &judge_result {
|
||||
let cohort_id = example
|
||||
.run_directory_path
|
||||
.file_name()
|
||||
.map(|name| name.to_string_lossy().to_string())
|
||||
.unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
|
||||
|
||||
let path = std::path::Path::new(".");
|
||||
let commit_id = get_current_commit_id(path).await.unwrap_or_default();
|
||||
|
||||
if let Some(thread) = &judge_output.thread {
|
||||
telemetry::event!(
|
||||
"Agent Eval Completed",
|
||||
cohort_id = cohort_id,
|
||||
example_name = example.name.clone(),
|
||||
round = round,
|
||||
diff_score = judge_output.diff.score,
|
||||
diff_analysis = judge_output.diff.analysis,
|
||||
thread_score = thread.score,
|
||||
thread_analysis = thread.analysis,
|
||||
tool_use_counts = run_output.tool_use_counts,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.base.url.clone(),
|
||||
repository_revision = example.base.revision.clone(),
|
||||
diagnostics_before = run_output.diagnostics_before,
|
||||
diagnostics_after = run_output.diagnostics_after,
|
||||
commit_id = commit_id
|
||||
);
|
||||
} else {
|
||||
telemetry::event!(
|
||||
"Agent Eval Completed",
|
||||
cohort_id = cohort_id,
|
||||
example_name = example.name.clone(),
|
||||
round = round,
|
||||
diff_score = judge_output.diff.score,
|
||||
diff_analysis = judge_output.diff.analysis,
|
||||
tool_use_counts = run_output.tool_use_counts,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.base.url.clone(),
|
||||
repository_revision = example.base.revision.clone(),
|
||||
diagnostics_before = run_output.diagnostics_before,
|
||||
diagnostics_after = run_output.diagnostics_after,
|
||||
commit_id = commit_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
judge_result
|
||||
}
|
||||
|
|
|
@ -10,14 +10,16 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
|
|||
use handlebars::Handlebars;
|
||||
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
|
||||
StopReason, TokenUsage,
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
MessageContent, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::{LspStore, Project, ProjectPath};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Write as _;
|
||||
use std::fs::File;
|
||||
use std::io::Write as _;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
|
@ -45,6 +47,19 @@ pub struct ExampleBase {
|
|||
pub insert_id: Option<String>,
|
||||
#[serde(default = "default_true")]
|
||||
pub require_lsp: bool,
|
||||
#[serde(default)]
|
||||
pub allow_preexisting_diagnostics: bool,
|
||||
}
|
||||
|
||||
impl ExampleBase {
|
||||
pub fn repo_name(&self) -> String {
|
||||
self.url
|
||||
.split('/')
|
||||
.next_back()
|
||||
.unwrap_or(&"")
|
||||
.trim_end_matches(".git")
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -54,14 +69,12 @@ pub struct Example {
|
|||
pub base: ExampleBase,
|
||||
/// Content of `prompt.md`
|
||||
pub prompt: String,
|
||||
/// Content of `criteria.md`
|
||||
pub criteria: String,
|
||||
/// Markdown output file to append to
|
||||
pub output_file: Option<Arc<Mutex<File>>>,
|
||||
/// Path to the output run directory.
|
||||
pub run_dir: PathBuf,
|
||||
/// Path to markdown output file
|
||||
pub output_file_path: PathBuf,
|
||||
/// Content of `diff_criteria.md`
|
||||
pub diff_criteria: String,
|
||||
/// Content of `thread_criteria.md`, if that file exists (it's optional)
|
||||
pub thread_criteria: Option<String>,
|
||||
/// Path to the directory containing the requests and responses for the agentic loop
|
||||
pub run_directory_path: PathBuf,
|
||||
/// Prefix used for logging that identifies this example
|
||||
pub log_prefix: String,
|
||||
}
|
||||
|
@ -69,41 +82,65 @@ pub struct Example {
|
|||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct RunOutput {
|
||||
pub repository_diff: String,
|
||||
pub diagnostics: String,
|
||||
pub ran_diagnostics_check: bool,
|
||||
pub diagnostics_before: Option<String>,
|
||||
pub diagnostics_after: Option<String>,
|
||||
pub response_count: usize,
|
||||
pub token_usage: TokenUsage,
|
||||
pub tool_use_counts: HashMap<Arc<str>, u32>,
|
||||
pub last_request: LanguageModelRequest,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JudgeInput {
|
||||
pub struct JudgeDiffInput {
|
||||
pub repository_diff: String,
|
||||
pub ran_diagnostics_check: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub diagnostics_before: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub diagnostics_after: Option<String>,
|
||||
pub criteria: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JudgeOutput {
|
||||
pub struct JudgeThreadInput {
|
||||
pub messages: String,
|
||||
pub criteria: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JudgeResponse {
|
||||
pub analysis: String,
|
||||
pub score: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JudgeOutput {
|
||||
pub thread: Option<JudgeResponse>,
|
||||
pub diff: JudgeResponse,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
/// Load an example from a directory containing base.toml, prompt.md, and criteria.md
|
||||
pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
|
||||
let name = Self::name_from_path(dir_path);
|
||||
let base_path = dir_path.join("base.toml");
|
||||
let prompt_path = dir_path.join("prompt.md");
|
||||
let criteria_path = dir_path.join("criteria.md");
|
||||
let output_file_path = run_dir.join(format!("{}.md", name));
|
||||
let diff_criteria_path = dir_path.join("diff_criteria.md");
|
||||
let thread_criteria_path = dir_path.join("thread_criteria.md");
|
||||
let thread_criteria = if thread_criteria_path.exists() {
|
||||
Some(fs::read_to_string(thread_criteria_path.clone())?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Example {
|
||||
name: name.clone(),
|
||||
base: toml::from_str(&fs::read_to_string(&base_path)?)?,
|
||||
prompt: fs::read_to_string(prompt_path.clone())?,
|
||||
criteria: fs::read_to_string(criteria_path.clone())?,
|
||||
run_dir: run_dir.to_path_buf(),
|
||||
output_file: None,
|
||||
output_file_path,
|
||||
thread_criteria,
|
||||
diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
|
||||
run_directory_path: run_dir.to_path_buf(),
|
||||
log_prefix: name,
|
||||
})
|
||||
}
|
||||
|
@ -111,10 +148,13 @@ impl Example {
|
|||
pub fn set_repetition_number(&mut self, repetition_number: u32) {
|
||||
if repetition_number > 0 {
|
||||
self.name = format!("{}-{}", self.name, repetition_number);
|
||||
self.output_file_path = self.run_dir.join(format!("{}.md", self.name));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn example_output_directory(&self) -> PathBuf {
|
||||
self.run_directory_path.join(&self.name)
|
||||
}
|
||||
|
||||
pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
|
||||
self.log_prefix = format!(
|
||||
"{}{:<width$}\x1b[0m | ",
|
||||
|
@ -134,6 +174,7 @@ impl Example {
|
|||
.context(format!("No such directory {WORKTREES_DIR}"))
|
||||
.unwrap()
|
||||
.join(&self.name)
|
||||
.join(self.base.repo_name())
|
||||
}
|
||||
|
||||
/// Set up the example by checking out the specified Git revision
|
||||
|
@ -187,20 +228,11 @@ impl Example {
|
|||
.await?;
|
||||
}
|
||||
|
||||
// Create the output file
|
||||
let output_file = Arc::new(Mutex::new(File::create(&self.output_file_path)?));
|
||||
self.output_file = Some(output_file);
|
||||
std::fs::create_dir_all(self.example_output_directory())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the output file, panicking if it's not set
|
||||
fn output_file(&self) -> Arc<Mutex<File>> {
|
||||
self.output_file
|
||||
.clone()
|
||||
.expect("Output file not created. Call setup() first.")
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
|
@ -305,6 +337,11 @@ impl Example {
|
|||
None
|
||||
};
|
||||
|
||||
let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
|
||||
if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
|
||||
return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
|
||||
}
|
||||
|
||||
if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
|
||||
return Err(anyhow!("Setup only mode"));
|
||||
}
|
||||
|
@ -312,15 +349,32 @@ impl Example {
|
|||
let thread_store = thread_store.await?;
|
||||
let thread =
|
||||
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
|
||||
let last_request = Rc::new(RefCell::new(None));
|
||||
|
||||
{
|
||||
let output_file_ref = this.output_file();
|
||||
let mut output_file = output_file_ref.lock().unwrap();
|
||||
writeln!(&mut output_file, "👤 USER:").log_err();
|
||||
writeln!(&mut output_file, "{}", this.prompt).log_err();
|
||||
writeln!(&mut output_file, "🤖 ASSISTANT:").log_err();
|
||||
output_file.flush().log_err();
|
||||
}
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let mut request_count = 0;
|
||||
let example_dir_path = this.example_output_directory();
|
||||
|
||||
let last_request = Rc::clone(&last_request);
|
||||
thread.set_request_callback(move |request, response_events| {
|
||||
*last_request.borrow_mut() = Some(request.clone());
|
||||
|
||||
request_count += 1;
|
||||
let messages_file_path = example_dir_path.join(format!("{request_count}.messages.md"));
|
||||
let last_messages_file_path = example_dir_path.join("last.messages.md");
|
||||
let request_markdown = RequestMarkdown::new(request);
|
||||
let response_events_markdown = response_events_to_markdown(response_events);
|
||||
|
||||
let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
|
||||
fs::write(messages_file_path, messages.clone()).expect("failed to write messages file");
|
||||
fs::write(last_messages_file_path, messages).expect("failed to write last messages file");
|
||||
|
||||
if request_count == 1 {
|
||||
let tools_file_path = example_dir_path.join("tools.md");
|
||||
fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
|
||||
}
|
||||
});
|
||||
})?;
|
||||
|
||||
let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
|
||||
Mutex::new(HashMap::default()).into();
|
||||
|
@ -332,8 +386,6 @@ impl Example {
|
|||
});
|
||||
|
||||
let event_handler_task = cx.spawn({
|
||||
// Need to clone the Arc here because the reference from output_file() won't live long enough
|
||||
let output_file = this.output_file.clone().unwrap();
|
||||
let log_prefix = this.log_prefix.clone();
|
||||
let tool_use_counts = tool_use_counts.clone();
|
||||
let thread = thread.downgrade();
|
||||
|
@ -349,8 +401,6 @@ impl Example {
|
|||
return Err(anyhow!("ThreadEvent channel ended early"));
|
||||
};
|
||||
|
||||
let mut output_file = output_file.lock().unwrap();
|
||||
|
||||
match event {
|
||||
ThreadEvent::Stopped(reason) => match reason {
|
||||
Ok(StopReason::EndTurn) => {
|
||||
|
@ -371,18 +421,7 @@ impl Example {
|
|||
ThreadEvent::ShowError(thread_error) => {
|
||||
break Err(anyhow!(thread_error.clone()));
|
||||
}
|
||||
ThreadEvent::StreamedAssistantText(_, chunk) => {
|
||||
write!(&mut output_file, "{}", chunk).log_err();
|
||||
}
|
||||
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
|
||||
write!(&mut output_file, "{}", chunk).log_err();
|
||||
}
|
||||
ThreadEvent::UsePendingTools { tool_uses } => {
|
||||
writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err();
|
||||
for tool_use in tool_uses {
|
||||
writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input)
|
||||
.log_err();
|
||||
}
|
||||
ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
|
||||
}
|
||||
ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
|
@ -398,8 +437,6 @@ impl Example {
|
|||
format!("TOOL FINISHED: {}", tool_use.name)
|
||||
};
|
||||
println!("{log_prefix}{message}");
|
||||
writeln!(&mut output_file, "\n{}", message).log_err();
|
||||
writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
|
||||
let mut tool_use_counts = tool_use_counts.lock().unwrap();
|
||||
*tool_use_counts
|
||||
.entry(tool_result.tool_name.clone())
|
||||
|
@ -407,7 +444,6 @@ impl Example {
|
|||
} else {
|
||||
let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
|
||||
println!("{log_prefix}{message}");
|
||||
writeln!(&mut output_file, "\n{}", message).log_err();
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
@ -428,8 +464,6 @@ impl Example {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
output_file.flush().log_err();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -451,21 +485,35 @@ impl Example {
|
|||
println!("{}Getting repository diff", this.log_prefix);
|
||||
let repository_diff = this.repository_diff().await?;
|
||||
|
||||
let repository_diff_path = this.run_dir.join(format!("{}.diff", this.name));
|
||||
let example_output_dir = this.example_output_directory();
|
||||
let repository_diff_path = example_output_dir.join("patch.diff");
|
||||
let mut repository_diff_output_file = File::create(&repository_diff_path)?;
|
||||
writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();
|
||||
|
||||
println!("{}Getting diagnostics", this.log_prefix);
|
||||
let diagnostics = cx
|
||||
let diagnostics_after = cx
|
||||
.update(move |cx| {
|
||||
cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
|
||||
})?
|
||||
.await?;
|
||||
println!("{}Got diagnostics", this.log_prefix);
|
||||
|
||||
let Some(last_request) = last_request.borrow_mut().take() else {
|
||||
return Err(anyhow!("No requests ran."));
|
||||
};
|
||||
|
||||
drop(subscription);
|
||||
drop(lsp_open_handle_and_store);
|
||||
|
||||
if let Some(diagnostics_before) = &diagnostics_before {
|
||||
fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
|
||||
}
|
||||
|
||||
if let Some(diagnostics_after) = &diagnostics_after {
|
||||
fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
|
||||
}
|
||||
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let response_count = thread
|
||||
.messages()
|
||||
|
@ -473,31 +521,38 @@ impl Example {
|
|||
.count();
|
||||
RunOutput {
|
||||
repository_diff,
|
||||
diagnostics,
|
||||
ran_diagnostics_check: this.base.require_lsp,
|
||||
diagnostics_before,
|
||||
diagnostics_after,
|
||||
response_count,
|
||||
token_usage: thread.cumulative_token_usage(),
|
||||
tool_use_counts: tool_use_counts.lock().unwrap().clone(),
|
||||
last_request,
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn judge(
|
||||
async fn judge_diff(
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
repository_diff: String,
|
||||
judge_repetitions: u32,
|
||||
run_output: &RunOutput,
|
||||
judge_number: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
let judge_prompt = include_str!("judge_prompt.hbs");
|
||||
let judge_prompt_name = "judge_prompt";
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_template_string(judge_prompt_name, judge_prompt)?;
|
||||
let prompt = handlebars.render(
|
||||
judge_prompt_name,
|
||||
&JudgeInput {
|
||||
repository_diff,
|
||||
criteria: self.criteria.clone(),
|
||||
) -> Result<(String, JudgeResponse)> {
|
||||
let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
|
||||
let judge_diff_prompt_name = "judge_diff_prompt";
|
||||
let mut hbs = Handlebars::new();
|
||||
hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
|
||||
|
||||
let diff_prompt = hbs.render(
|
||||
judge_diff_prompt_name,
|
||||
&JudgeDiffInput {
|
||||
repository_diff: run_output.repository_diff.clone(),
|
||||
ran_diagnostics_check: run_output.ran_diagnostics_check,
|
||||
diagnostics_before: run_output.diagnostics_before.clone(),
|
||||
diagnostics_after: run_output.diagnostics_after.clone(),
|
||||
criteria: self.diff_criteria.clone(),
|
||||
},
|
||||
)?;
|
||||
|
||||
|
@ -506,7 +561,7 @@ impl Example {
|
|||
prompt_id: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(prompt)],
|
||||
content: vec![MessageContent::Text(diff_prompt)],
|
||||
cache: false,
|
||||
}],
|
||||
temperature: None,
|
||||
|
@ -514,24 +569,106 @@ impl Example {
|
|||
stop: Vec::new(),
|
||||
};
|
||||
|
||||
let response = send_language_model_request(model, request, cx).await?;
|
||||
let diff_response = send_language_model_request(model, request, cx).await?;
|
||||
let diff_output = JudgeResponse::parse(&diff_response)?;
|
||||
|
||||
let judge_file_path = self.run_dir.join(format!(
|
||||
"{}_judge_{}.md",
|
||||
self.name, // This is the eval_name
|
||||
judge_repetitions
|
||||
));
|
||||
println!(
|
||||
"{}Judge #{judge_number} - Diff score: {}",
|
||||
self.log_prefix, diff_output.score
|
||||
);
|
||||
|
||||
let mut judge_output_file = File::create(&judge_file_path)?;
|
||||
writeln!(&mut judge_output_file, "{}", &response).log_err();
|
||||
|
||||
parse_judge_output(&response)
|
||||
Ok((diff_response, diff_output))
|
||||
}
|
||||
|
||||
pub async fn repository_diff(&self) -> Result<String> {
|
||||
async fn judge_thread(
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
run_output: &RunOutput,
|
||||
judge_number: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<(String, Option<JudgeResponse>)> {
|
||||
if let Some(criteria) = self.thread_criteria.clone() {
|
||||
let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
|
||||
let judge_thread_prompt_name = "judge_thread_prompt";
|
||||
let mut hbs = Handlebars::new();
|
||||
hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
|
||||
|
||||
let request_markdown = RequestMarkdown::new(&run_output.last_request);
|
||||
let thread_prompt = hbs.render(
|
||||
judge_thread_prompt_name,
|
||||
&JudgeThreadInput {
|
||||
messages: request_markdown.messages,
|
||||
criteria,
|
||||
},
|
||||
)?;
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(thread_prompt)],
|
||||
cache: false,
|
||||
}],
|
||||
temperature: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
};
|
||||
|
||||
let thread_response = send_language_model_request(model, request, cx).await?;
|
||||
let thread_output = JudgeResponse::parse(&thread_response)?;
|
||||
|
||||
println!(
|
||||
"{}Judge #{judge_number} - Thread score: {}",
|
||||
self.log_prefix, thread_output.score
|
||||
);
|
||||
|
||||
Ok((thread_response, Some(thread_output)))
|
||||
} else {
|
||||
let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string();
|
||||
Ok((msg, None))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn judge(
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
run_output: &RunOutput,
|
||||
judge_number: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
let mut output_file = File::create(
|
||||
self.example_output_directory()
|
||||
.join(format!("judge_{}.md", judge_number)),
|
||||
)
|
||||
.expect("failed to create judge.md");
|
||||
|
||||
println!("{}Running judge #{judge_number}", self.log_prefix);
|
||||
|
||||
let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx);
|
||||
let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx);
|
||||
|
||||
let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
|
||||
|
||||
let (diff_response, diff_output) = diff_result?;
|
||||
let (thread_response, thread_output) = thread_result?;
|
||||
|
||||
writeln!(
|
||||
&mut output_file,
|
||||
"# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
|
||||
)
|
||||
.log_err();
|
||||
|
||||
Ok(JudgeOutput {
|
||||
thread: thread_output,
|
||||
diff: diff_output,
|
||||
})
|
||||
}
|
||||
|
||||
async fn repository_diff(&self) -> Result<String> {
|
||||
let worktree_path = self.worktree_path();
|
||||
run_git(&worktree_path, &["add", "-N"]).await?;
|
||||
run_git(&worktree_path, &["diff"]).await
|
||||
run_git(&worktree_path, &["add", "."]).await?;
|
||||
run_git(&worktree_path, &["diff", "--staged"]).await
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -599,7 +736,10 @@ fn has_pending_lang_server_work(lsp_store: &Entity<LspStore>, cx: &App) -> bool
|
|||
.any(|(_, status)| !status.pending_work.is_empty())
|
||||
}
|
||||
|
||||
async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> Result<String> {
|
||||
async fn query_lsp_diagnostics(
|
||||
project: Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Option<String>> {
|
||||
let paths_with_diagnostics = project.update(cx, |project, cx| {
|
||||
project
|
||||
.diagnostic_summaries(true, cx)
|
||||
|
@ -608,6 +748,10 @@ async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> R
|
|||
.collect::<Vec<_>>()
|
||||
})?;
|
||||
|
||||
if paths_with_diagnostics.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut output = String::new();
|
||||
for project_path in paths_with_diagnostics {
|
||||
let buffer = project
|
||||
|
@ -633,16 +777,18 @@ async fn query_lsp_diagnostics(project: Entity<Project>, cx: &mut AsyncApp) -> R
|
|||
)?;
|
||||
}
|
||||
}
|
||||
anyhow::Ok(output)
|
||||
anyhow::Ok(Some(output))
|
||||
}
|
||||
|
||||
fn parse_judge_output(response: &str) -> Result<JudgeOutput> {
|
||||
let analysis = get_tag("analysis", response)?.to_string();
|
||||
let score = get_tag("score", response)?
|
||||
.parse()
|
||||
.context("error parsing score")?;
|
||||
impl JudgeResponse {
|
||||
fn parse(response: &str) -> Result<Self> {
|
||||
let analysis = get_tag("analysis", response)?.to_string();
|
||||
let score = get_tag("score", response)?
|
||||
.parse()
|
||||
.context("error parsing score")?;
|
||||
|
||||
Ok(JudgeOutput { analysis, score })
|
||||
Ok(Self { analysis, score })
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tag(name: &'static str, response: &str) -> Result<String> {
|
||||
|
@ -724,9 +870,135 @@ pub async fn send_language_model_request(
|
|||
}
|
||||
}
|
||||
|
||||
struct RequestMarkdown {
|
||||
tools: String,
|
||||
messages: String,
|
||||
}
|
||||
|
||||
impl RequestMarkdown {
|
||||
fn new(request: &LanguageModelRequest) -> Self {
|
||||
let mut tools = String::new();
|
||||
let mut messages = String::new();
|
||||
|
||||
// Print the tools
|
||||
if !request.tools.is_empty() {
|
||||
for tool in &request.tools {
|
||||
write!(&mut tools, "# {}\n\n", tool.name).unwrap();
|
||||
write!(&mut tools, "{}\n\n", tool.description).unwrap();
|
||||
write!(
|
||||
&mut tools,
|
||||
"```json\n{}\n```\n\n",
|
||||
serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default()
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// Print the messages
|
||||
for message in &request.messages {
|
||||
let role_str = match message.role {
|
||||
Role::User => "👤 USER",
|
||||
Role::Assistant => "🤖 ASSISTANT",
|
||||
Role::System => "⚙️ SYSTEM",
|
||||
};
|
||||
|
||||
messages.push_str(&format!("# {}\n\n", role_str));
|
||||
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
messages.push_str(text);
|
||||
messages.push_str("\n\n");
|
||||
}
|
||||
MessageContent::Image(_) => {
|
||||
messages.push_str("[IMAGE DATA]\n\n");
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
messages.push_str(&format!(
|
||||
"**Tool Use**: {} (ID: {})\n",
|
||||
tool_use.name, tool_use.id
|
||||
));
|
||||
messages.push_str(&format!("```json\n{}\n```\n\n", tool_use.input));
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
messages.push_str(&format!(
|
||||
"**Tool Result**: {} (ID: {})\n\n",
|
||||
tool_result.tool_name, tool_result.tool_use_id
|
||||
));
|
||||
if tool_result.is_error {
|
||||
messages.push_str("**ERROR:**\n");
|
||||
}
|
||||
messages.push_str(&format!("{}\n", tool_result.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self { tools, messages }
|
||||
}
|
||||
}
|
||||
|
||||
fn response_events_to_markdown(
|
||||
response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
|
||||
) -> String {
|
||||
let mut response = String::new();
|
||||
// Print the response events if any
|
||||
response.push_str("# Response\n\n");
|
||||
let mut text_buffer = String::new();
|
||||
let mut thinking_buffer = String::new();
|
||||
|
||||
let flush_buffers =
|
||||
|output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
|
||||
if !text_buffer.is_empty() {
|
||||
output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
|
||||
text_buffer.clear();
|
||||
}
|
||||
if !thinking_buffer.is_empty() {
|
||||
output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
|
||||
thinking_buffer.clear();
|
||||
}
|
||||
};
|
||||
|
||||
for event in response_events {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
text_buffer.push_str(text);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Thinking(text)) => {
|
||||
thinking_buffer.push_str(text);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
response.push_str(&format!("**Stop**: {:?}\n\n", reason));
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
response.push_str(&format!(
|
||||
"**Tool Use**: {} (ID: {})\n",
|
||||
tool_use.name, tool_use.id
|
||||
));
|
||||
response.push_str(&format!("```json\n{}\n```\n\n", tool_use.input));
|
||||
}
|
||||
Ok(
|
||||
LanguageModelCompletionEvent::UsageUpdate(_)
|
||||
| LanguageModelCompletionEvent::StartMessage { .. },
|
||||
) => {}
|
||||
Err(error) => {
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
response.push_str(&format!("**Error**: {}\n\n", error));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use handlebars::Handlebars;
|
||||
|
||||
#[test]
|
||||
fn test_parse_judge_output() {
|
||||
|
@ -736,7 +1008,7 @@ mod test {
|
|||
"#
|
||||
.unindent();
|
||||
|
||||
let output = parse_judge_output(&response).unwrap();
|
||||
let output = JudgeResponse::parse(&response).unwrap();
|
||||
assert_eq!(
|
||||
output.analysis,
|
||||
"The model did a good job but there were still compilations errors."
|
||||
|
@ -756,8 +1028,158 @@ mod test {
|
|||
"#
|
||||
.unindent();
|
||||
|
||||
let output = parse_judge_output(&response).unwrap();
|
||||
let output = JudgeResponse::parse(&response).unwrap();
|
||||
assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2");
|
||||
assert_eq!(output.score, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_judge_prompt_with_diagnostics() {
|
||||
// Case 1: Both diagnostics before and after are present
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: true,
|
||||
diagnostics_before: Some("Error at line 10: variable not found".to_string()),
|
||||
diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
let expected_diagnostics_section = r#"
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
Error at line 10: variable not found
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
Error at line 15: missing semicolon
|
||||
</diagnostics_after>
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert!(rendered.contains(&expected_diagnostics_section));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_judge_prompt_with_empty_diagnostics() {
|
||||
// Case 2: Diagnostics check run but no diagnostics found
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: true,
|
||||
diagnostics_before: None,
|
||||
diagnostics_after: None,
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
let expected_diagnostics_section = r#"
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
No diagnostics before applying the edits.
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
No diagnostics after applying the edits.
|
||||
</diagnostics_after>
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert!(rendered.contains(&expected_diagnostics_section));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_judge_prompt_with_mixed_diagnostics() {
|
||||
let templates = templates();
|
||||
|
||||
// Case 3: Before diagnostics present, after diagnostics absent
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: true,
|
||||
diagnostics_before: Some("Error at line 10: variable not found".to_string()),
|
||||
diagnostics_after: None,
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
let expected_diagnostics_section = r#"
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
Error at line 10: variable not found
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
No diagnostics after applying the edits.
|
||||
</diagnostics_after>
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert!(rendered.contains(&expected_diagnostics_section));
|
||||
|
||||
// Case 4: Before diagnostics absent, after diagnostics present
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: true,
|
||||
diagnostics_before: None,
|
||||
diagnostics_after: Some("Error at line 15: missing semicolon".to_string()),
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
let expected_diagnostics_section = r#"
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
No diagnostics before applying the edits.
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
Error at line 15: missing semicolon
|
||||
</diagnostics_after>
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert!(rendered.contains(&expected_diagnostics_section));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_judge_prompt_without_diagnostics() {
|
||||
let templates = templates();
|
||||
|
||||
// Case 5: No diagnostics check run
|
||||
let input = JudgeDiffInput {
|
||||
repository_diff: "diff content goes here".to_string(),
|
||||
ran_diagnostics_check: false,
|
||||
diagnostics_before: None,
|
||||
diagnostics_after: None,
|
||||
criteria: "Fix all bugs".to_string(),
|
||||
};
|
||||
|
||||
let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap();
|
||||
|
||||
// Check for the message when no diagnostics were performed
|
||||
let diagnostics_message = "No diagnostic checks were performed.";
|
||||
|
||||
assert!(rendered.contains(diagnostics_message));
|
||||
assert!(!rendered.contains("<diagnostics_before>"));
|
||||
assert!(!rendered.contains("<diagnostics_after>"));
|
||||
}
|
||||
|
||||
const JUDGE_PROMPT_NAME: &str = "judge_prompt";
|
||||
|
||||
fn templates() -> Handlebars<'static> {
|
||||
let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string();
|
||||
language::LineEnding::normalize(&mut judge_prompt);
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars
|
||||
.register_template_string(JUDGE_PROMPT_NAME, judge_prompt)
|
||||
.unwrap();
|
||||
handlebars
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,28 @@ Use the following criteria to score the above changes.
|
|||
{{criteria}}
|
||||
</criteria>
|
||||
|
||||
{{#if ran_diagnostics_check}}
|
||||
Take into account the diagnostics before and after applying the change:
|
||||
|
||||
<diagnostics_before>
|
||||
{{#if diagnostics_before}}
|
||||
{{{diagnostics_before}}}
|
||||
{{else}}
|
||||
No diagnostics before applying the edits.
|
||||
{{/if}}
|
||||
</diagnostics_before>
|
||||
|
||||
<diagnostics_after>
|
||||
{{#if diagnostics_after}}
|
||||
{{{diagnostics_after}}}
|
||||
{{else}}
|
||||
No diagnostics after applying the edits.
|
||||
{{/if}}
|
||||
</diagnostics_after>
|
||||
{{else}}
|
||||
No diagnostic checks were performed.
|
||||
{{/if}}
|
||||
|
||||
Based on these criteria, give the test output a score between 0 and 5.
|
||||
The output score should ONLY INCLUDE whole numbers. DO NOT return decimals or floats.
|
||||
|
22
crates/eval/src/judge_thread_prompt.hbs
Normal file
22
crates/eval/src/judge_thread_prompt.hbs
Normal file
|
@ -0,0 +1,22 @@
|
|||
You are an expert software developer tasked with evaluating an AI agent's messages and tool calls in this conversation:
|
||||
|
||||
<messages>
|
||||
{{{messages}}}
|
||||
</messages>
|
||||
|
||||
Use the following criteria to score the above messages.
|
||||
|
||||
<criteria>
|
||||
{{criteria}}
|
||||
</criteria>
|
||||
|
||||
Based on these criteria, give the messages a score between 0 and 5.
|
||||
The output score should ONLY INCLUDE whole numbers. DO NOT return decimals or floats.
|
||||
|
||||
- 5 means: messages meet all criteria
|
||||
- 0 means: messages don't meet any criteria
|
||||
|
||||
```
|
||||
<analysis>{YOUR ANALYSIS HERE}</analysis>
|
||||
<score>{YOUR SCORE HERE}</score>
|
||||
```
|
Loading…
Add table
Add a link
Reference in a new issue