Allow StreamingEditFileTool to also create files (#29785)

Refs #29733 

This pull request introduces a new field to the `StreamingEditFileTool`
that lets the model create or overwrite a file in a streaming way. When
one of the `assistant.stream_edits` setting / `agent-stream-edits`
feature flag is enabled, we are going to disable the `CreateFileTool` so
that the agent model can only use `StreamingEditFileTool` for file
creation.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-05-02 11:57:04 +02:00 committed by GitHub
parent f619d5f02a
commit 35539847a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 2914 additions and 140 deletions

View file

@ -4,10 +4,11 @@ use crate::{
streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use anyhow::anyhow;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
@ -71,14 +72,15 @@ fn eval_extract_handle_command_output() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -126,14 +128,15 @@ fn eval_delete_run_git_blame() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -240,14 +243,15 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
assertion: EvalAssertion::judge_diff(indoc! {"
- The compile_parser_to_wasm method has been changed to use wasi-sdk
- ureq is used to download the SDK for current platform and architecture
"}),
@ -315,14 +319,15 @@ fn eval_disable_cursor_blinking() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -504,14 +509,15 @@ fn eval_from_pixels_constructor() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
assertion: EvalAssertion::assert_eq(indoc! {"
- The diff contains a new `from_pixels` constructor
- The diff contains new tests for the `from_pixels` constructor
"}),
@ -519,6 +525,104 @@ fn eval_from_pixels_constructor() {
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
let input_file_path = "root/zode.py";
let edit_description = "Create the main Zode CLI script";
eval(
200,
1.,
EvalInput {
conversation: vec![
message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
message(
Assistant,
[
tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: "root/eval/react.py".into(),
start_line: None,
end_line: None,
},
),
tool_use(
"tool_2",
"read_file",
ReadFileToolInput {
path: "root/eval/react_test.py".into(),
start_line: None,
end_line: None,
},
),
],
),
message(
User,
[
tool_result(
"tool_1",
"read_file",
include_str!("evals/fixtures/zode/react.py"),
),
tool_result(
"tool_2",
"read_file",
include_str!("evals/fixtures/zode/react_test.py"),
),
],
),
message(
Assistant,
[
text(
"Now that I understand what we need to build, I'll create the main Python script:",
),
tool_use(
"tool_3",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: true,
},
),
],
),
],
input_path: input_file_path.into(),
input_content: None,
edit_description: edit_description.into(),
assertion: EvalAssertion::new(async move |sample, _, _cx| {
let invalid_starts = [' ', '`', '\n'];
let mut message = String::new();
for start in invalid_starts {
if sample.text.starts_with(start) {
message.push_str(&format!("The sample starts with a {:?}\n", start));
break;
}
}
// Remove trailing newline.
message.pop();
if message.is_empty() {
Ok(EvalAssertionOutcome {
score: 100,
message: None,
})
} else {
Ok(EvalAssertionOutcome {
score: 0,
message: Some(message),
})
}
}),
},
);
}
fn message(
role: Role,
contents: impl IntoIterator<Item = MessageContent>,
@ -574,11 +678,135 @@ fn tool_result(
struct EvalInput {
conversation: Vec<LanguageModelRequestMessage>,
input_path: PathBuf,
input_content: String,
input_content: Option<String>,
edit_description: String,
assertion: EvalAssertion,
}
#[derive(Clone)]
struct EvalSample {
text: String,
edit_output: EditAgentOutput,
diff: String,
}
trait AssertionFn: 'static + Send + Sync {
fn assert<'a>(
&'a self,
sample: &'a EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &'a mut TestAppContext,
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
impl<F> AssertionFn for F
where
F: 'static
+ Send
+ Sync
+ AsyncFn(
&EvalSample,
Arc<dyn LanguageModel>,
&mut TestAppContext,
) -> Result<EvalAssertionOutcome>,
{
fn assert<'a>(
&'a self,
sample: &'a EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &'a mut TestAppContext,
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
(self)(sample, judge_model, cx).boxed_local()
}
}
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
impl EvalAssertion {
fn new<F>(f: F) -> Self
where
F: 'static
+ Send
+ Sync
+ AsyncFn(
&EvalSample,
Arc<dyn LanguageModel>,
&mut TestAppContext,
) -> Result<EvalAssertionOutcome>,
{
EvalAssertion(Arc::new(f))
}
fn assert_eq(expected: impl Into<String>) -> Self {
let expected = expected.into();
Self::new(async move |sample, _judge, _cx| {
Ok(EvalAssertionOutcome {
score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
100
} else {
0
},
message: None,
})
})
}
fn judge_diff(assertions: &'static str) -> Self {
Self::new(async move |sample, judge, cx| {
let prompt = DiffJudgeTemplate {
diff: sample.diff.clone(),
assertions,
}
.render(&Templates::new())
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
};
let mut response = judge
.stream_completion_text(request, &cx.to_async())
.await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionOutcome {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
})
}
async fn run(
&self,
input: &EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &mut TestAppContext,
) -> Result<EvalAssertionOutcome> {
self.0.assert(input, judge_model, cx).await
}
}
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
let mut evaluated_count = 0;
report_progress(evaluated_count, iterations);
@ -606,12 +834,12 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
while let Ok(output) = rx.recv() {
match output {
Ok(output) => {
cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
eval_outputs.push(output.clone());
if output.assertion.score < 80 {
failed_count += 1;
failed_evals
.entry(output.buffer_text.clone())
.entry(output.sample.text.clone())
.or_insert(Vec::new())
.push(output);
}
@ -671,10 +899,8 @@ fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
#[derive(Clone)]
struct EvalOutput {
assertion: EvalAssertionResult,
buffer_text: String,
edit_output: EditAgentOutput,
diff: String,
sample: EvalSample,
assertion: EvalAssertionOutcome,
}
impl Display for EvalOutput {
@ -684,14 +910,14 @@ impl Display for EvalOutput {
writeln!(f, "Message: {}", message)?;
}
writeln!(f, "Diff:\n{}", self.diff)?;
writeln!(f, "Diff:\n{}", self.sample.diff)?;
writeln!(
f,
"Parser Metrics:\n{:#?}",
self.edit_output._parser_metrics
self.sample.edit_output._parser_metrics
)?;
writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
Ok(())
}
}
@ -777,96 +1003,45 @@ impl EditAgentTest {
.update(cx, |project, cx| project.open_buffer(path, cx))
.await
.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text(eval.input_content.clone(), cx)
});
let (edit_output, _events) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
let edit_output = edit_output.await?;
let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
let (edit_output, _) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
edit_output.await?
} else {
let (edit_output, _) = self.agent.overwrite(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
edit_output.await?
};
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
let assertion = match eval.assertion {
EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
100
} else {
0
},
message: None,
},
EvalAssertion::JudgeDiff(assertions) => self
.judge_diff(&actual_diff, assertions, &cx.to_async())
.await
.context("failed comparing diffs")?,
};
Ok(EvalOutput {
assertion,
diff: actual_diff,
buffer_text,
let sample = EvalSample {
edit_output,
})
}
async fn judge_diff(
&self,
diff: &str,
assertions: &'static str,
cx: &AsyncApp,
) -> Result<EvalAssertionResult> {
let prompt = DiffJudgeTemplate {
diff: diff.to_string(),
assertions,
}
.render(&self.agent.templates)
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
diff: language::unified_diff(
eval.input_content.as_deref().unwrap_or_default(),
&buffer_text,
),
text: buffer_text,
};
let mut response = self.judge_model.stream_completion_text(request, cx).await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
let assertion = eval
.assertion
.run(&sample, self.judge_model.clone(), cx)
.await?;
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionResult {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
Ok(EvalOutput { assertion, sample })
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
AssertEqual(String),
JudgeDiff(&'static str),
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
struct EvalAssertionOutcome {
score: usize,
message: Option<String>,
}