Fix eval judging missing final response (#29638)

Fixed issue where eval thread judges were not considering the last
response in the thread.

The problem was that they were getting the full list of messages from
`last_request`, which (being a request!) did not have the response yet.

Release Notes:

- N/A
This commit is contained in:
Richard Feldman 2025-04-29 23:02:46 -04:00 committed by GitHub
parent d566864891
commit c8685dc90f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 56 additions and 26 deletions

View file

@ -1,4 +1,4 @@
use agent::ThreadStore;
use agent::{Message, MessageSegment, ThreadStore};
use anyhow::{Context, Result, anyhow, bail};
use assistant_tool::ToolWorkingSet;
use client::proto::LspWorkProgress;
@ -60,7 +60,7 @@ pub struct RunOutput {
pub response_count: usize,
pub token_usage: TokenUsage,
pub tool_metrics: ToolMetrics,
pub last_request: LanguageModelRequest,
pub all_messages: String,
pub programmatic_assertions: AssertionsReport,
}
@ -309,19 +309,15 @@ impl ExampleInstance {
let thread_store = thread_store.await?;
let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
let last_request = Rc::new(RefCell::new(None));
thread.update(cx, |thread, _cx| {
let mut request_count = 0;
let last_request = Rc::clone(&last_request);
let previous_diff = Rc::new(RefCell::new("".to_string()));
let example_output_dir = this.run_directory.clone();
let last_diff_file_path = last_diff_file_path.clone();
let messages_json_file_path = example_output_dir.join("last.messages.json");
let this = this.clone();
thread.set_request_callback(move |request, response_events| {
*last_request.borrow_mut() = Some(request.clone());
request_count += 1;
let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md"));
let diff_file_path = example_output_dir.join(format!("{request_count}.diff"));
@ -397,10 +393,6 @@ impl ExampleInstance {
}
let Some(last_request) = last_request.borrow_mut().take() else {
return Err(anyhow!("No requests ran."));
};
if let Some(diagnostics_before) = &diagnostics_before {
fs::write(this.run_directory.join("diagnostics_before.txt"), diagnostics_before)?;
}
@ -423,7 +415,7 @@ impl ExampleInstance {
response_count,
token_usage: thread.cumulative_token_usage(),
tool_metrics: example_cx.tool_metrics.lock().unwrap().clone(),
last_request,
all_messages: messages_to_markdown(thread.messages()),
programmatic_assertions: example_cx.assertions,
}
})
@ -526,23 +518,23 @@ impl ExampleInstance {
if thread_assertions.is_empty() {
return (
"No diff assertions".to_string(),
"No thread assertions".to_string(),
AssertionsReport::default(),
);
}
let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
let judge_diff_prompt_name = "judge_thread_prompt";
let judge_thread_prompt_name = "judge_thread_prompt";
let mut hbs = Handlebars::new();
hbs.register_template_string(judge_diff_prompt_name, judge_thread_prompt)
hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)
.unwrap();
let request_markdown = RequestMarkdown::new(&run_output.last_request);
let complete_messages = &run_output.all_messages;
let to_prompt = |assertion: String| {
hbs.render(
judge_diff_prompt_name,
judge_thread_prompt_name,
&JudgeThreadInput {
messages: request_markdown.messages.clone(),
messages: complete_messages.clone(),
assertion,
},
)
@ -817,6 +809,51 @@ pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
}
}
fn messages_to_markdown<'a>(message_iter: impl IntoIterator<Item = &'a Message>) -> String {
let mut messages = String::new();
let mut assistant_message_number: u32 = 1;
for message in message_iter {
push_role(&message.role, &mut messages, &mut assistant_message_number);
for segment in &message.segments {
match segment {
MessageSegment::Text(text) => {
messages.push_str(&text);
messages.push_str("\n\n");
}
MessageSegment::Thinking { text, signature } => {
messages.push_str("**Thinking**:\n\n");
if let Some(sig) = signature {
messages.push_str(&format!("Signature: {}\n\n", sig));
}
messages.push_str(&text);
messages.push_str("\n");
}
MessageSegment::RedactedThinking(items) => {
messages.push_str(&format!(
"**Redacted Thinking**: {} item(s)\n\n",
items.len()
));
}
}
}
}
messages
}
fn push_role(role: &Role, buf: &mut String, assistant_message_number: &mut u32) {
match role {
Role::System => buf.push_str("# ⚙️ SYSTEM\n\n"),
Role::User => buf.push_str("# 👤 USER\n\n"),
Role::Assistant => {
buf.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
*assistant_message_number = *assistant_message_number + 1;
}
}
}
pub async fn send_language_model_request(
model: Arc<dyn LanguageModel>,
request: LanguageModelRequest,
@ -875,14 +912,7 @@ impl RequestMarkdown {
// Print the messages
for message in &request.messages {
match message.role {
Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"),
Role::User => messages.push_str("# 👤 USER\n\n"),
Role::Assistant => {
messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
assistant_message_number += 1;
}
};
push_role(&message.role, &mut messages, &mut assistant_message_number);
for content in &message.content {
match content {