Allow StreamingEditFileTool to also create files (#29785)

Refs #29733 

This pull request introduces a new field to the `StreamingEditFileTool`
that lets the model create or overwrite a file in a streaming way. When
one of the `assistant.stream_edits` setting / `agent-stream-edits`
feature flag is enabled, we are going to disable the `CreateFileTool` so
that the agent model can only use `StreamingEditFileTool` for file
creation.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-05-02 11:57:04 +02:00 committed by GitHub
parent f619d5f02a
commit 35539847a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 2914 additions and 140 deletions

View file

@ -4,10 +4,11 @@ use crate::{
streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use anyhow::anyhow;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
@ -71,14 +72,15 @@ fn eval_extract_handle_command_output() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -126,14 +128,15 @@ fn eval_delete_run_git_blame() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -240,14 +243,15 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
assertion: EvalAssertion::judge_diff(indoc! {"
- The compile_parser_to_wasm method has been changed to use wasi-sdk
- ureq is used to download the SDK for current platform and architecture
"}),
@ -315,14 +319,15 @@ fn eval_disable_cursor_blinking() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -504,14 +509,15 @@ fn eval_from_pixels_constructor() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
assertion: EvalAssertion::assert_eq(indoc! {"
- The diff contains a new `from_pixels` constructor
- The diff contains new tests for the `from_pixels` constructor
"}),
@ -519,6 +525,104 @@ fn eval_from_pixels_constructor() {
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
let input_file_path = "root/zode.py";
let edit_description = "Create the main Zode CLI script";
eval(
200,
1.,
EvalInput {
conversation: vec![
message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
message(
Assistant,
[
tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: "root/eval/react.py".into(),
start_line: None,
end_line: None,
},
),
tool_use(
"tool_2",
"read_file",
ReadFileToolInput {
path: "root/eval/react_test.py".into(),
start_line: None,
end_line: None,
},
),
],
),
message(
User,
[
tool_result(
"tool_1",
"read_file",
include_str!("evals/fixtures/zode/react.py"),
),
tool_result(
"tool_2",
"read_file",
include_str!("evals/fixtures/zode/react_test.py"),
),
],
),
message(
Assistant,
[
text(
"Now that I understand what we need to build, I'll create the main Python script:",
),
tool_use(
"tool_3",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: true,
},
),
],
),
],
input_path: input_file_path.into(),
input_content: None,
edit_description: edit_description.into(),
assertion: EvalAssertion::new(async move |sample, _, _cx| {
let invalid_starts = [' ', '`', '\n'];
let mut message = String::new();
for start in invalid_starts {
if sample.text.starts_with(start) {
message.push_str(&format!("The sample starts with a {:?}\n", start));
break;
}
}
// Remove trailing newline.
message.pop();
if message.is_empty() {
Ok(EvalAssertionOutcome {
score: 100,
message: None,
})
} else {
Ok(EvalAssertionOutcome {
score: 0,
message: Some(message),
})
}
}),
},
);
}
fn message(
role: Role,
contents: impl IntoIterator<Item = MessageContent>,
@ -574,11 +678,135 @@ fn tool_result(
struct EvalInput {
conversation: Vec<LanguageModelRequestMessage>,
input_path: PathBuf,
input_content: String,
input_content: Option<String>,
edit_description: String,
assertion: EvalAssertion,
}
#[derive(Clone)]
struct EvalSample {
text: String,
edit_output: EditAgentOutput,
diff: String,
}
trait AssertionFn: 'static + Send + Sync {
fn assert<'a>(
&'a self,
sample: &'a EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &'a mut TestAppContext,
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
impl<F> AssertionFn for F
where
F: 'static
+ Send
+ Sync
+ AsyncFn(
&EvalSample,
Arc<dyn LanguageModel>,
&mut TestAppContext,
) -> Result<EvalAssertionOutcome>,
{
fn assert<'a>(
&'a self,
sample: &'a EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &'a mut TestAppContext,
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
(self)(sample, judge_model, cx).boxed_local()
}
}
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
impl EvalAssertion {
fn new<F>(f: F) -> Self
where
F: 'static
+ Send
+ Sync
+ AsyncFn(
&EvalSample,
Arc<dyn LanguageModel>,
&mut TestAppContext,
) -> Result<EvalAssertionOutcome>,
{
EvalAssertion(Arc::new(f))
}
fn assert_eq(expected: impl Into<String>) -> Self {
let expected = expected.into();
Self::new(async move |sample, _judge, _cx| {
Ok(EvalAssertionOutcome {
score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
100
} else {
0
},
message: None,
})
})
}
fn judge_diff(assertions: &'static str) -> Self {
Self::new(async move |sample, judge, cx| {
let prompt = DiffJudgeTemplate {
diff: sample.diff.clone(),
assertions,
}
.render(&Templates::new())
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
};
let mut response = judge
.stream_completion_text(request, &cx.to_async())
.await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionOutcome {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
})
}
async fn run(
&self,
input: &EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &mut TestAppContext,
) -> Result<EvalAssertionOutcome> {
self.0.assert(input, judge_model, cx).await
}
}
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
let mut evaluated_count = 0;
report_progress(evaluated_count, iterations);
@ -606,12 +834,12 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
while let Ok(output) = rx.recv() {
match output {
Ok(output) => {
cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
eval_outputs.push(output.clone());
if output.assertion.score < 80 {
failed_count += 1;
failed_evals
.entry(output.buffer_text.clone())
.entry(output.sample.text.clone())
.or_insert(Vec::new())
.push(output);
}
@ -671,10 +899,8 @@ fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
#[derive(Clone)]
struct EvalOutput {
assertion: EvalAssertionResult,
buffer_text: String,
edit_output: EditAgentOutput,
diff: String,
sample: EvalSample,
assertion: EvalAssertionOutcome,
}
impl Display for EvalOutput {
@ -684,14 +910,14 @@ impl Display for EvalOutput {
writeln!(f, "Message: {}", message)?;
}
writeln!(f, "Diff:\n{}", self.diff)?;
writeln!(f, "Diff:\n{}", self.sample.diff)?;
writeln!(
f,
"Parser Metrics:\n{:#?}",
self.edit_output._parser_metrics
self.sample.edit_output._parser_metrics
)?;
writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
Ok(())
}
}
@ -777,96 +1003,45 @@ impl EditAgentTest {
.update(cx, |project, cx| project.open_buffer(path, cx))
.await
.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text(eval.input_content.clone(), cx)
});
let (edit_output, _events) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
let edit_output = edit_output.await?;
let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
let (edit_output, _) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
edit_output.await?
} else {
let (edit_output, _) = self.agent.overwrite(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
edit_output.await?
};
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
let assertion = match eval.assertion {
EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
100
} else {
0
},
message: None,
},
EvalAssertion::JudgeDiff(assertions) => self
.judge_diff(&actual_diff, assertions, &cx.to_async())
.await
.context("failed comparing diffs")?,
};
Ok(EvalOutput {
assertion,
diff: actual_diff,
buffer_text,
let sample = EvalSample {
edit_output,
})
}
async fn judge_diff(
&self,
diff: &str,
assertions: &'static str,
cx: &AsyncApp,
) -> Result<EvalAssertionResult> {
let prompt = DiffJudgeTemplate {
diff: diff.to_string(),
assertions,
}
.render(&self.agent.templates)
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
diff: language::unified_diff(
eval.input_content.as_deref().unwrap_or_default(),
&buffer_text,
),
text: buffer_text,
};
let mut response = self.judge_model.stream_completion_text(request, cx).await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
let assertion = eval
.assertion
.run(&sample, self.judge_model.clone(), cx)
.await?;
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionResult {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
Ok(EvalOutput { assertion, sample })
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
AssertEqual(String),
JudgeDiff(&'static str),
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
struct EvalAssertionOutcome {
score: usize,
message: Option<String>,
}

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,14 @@
class InputCell:
def __init__(self, initial_value):
self.value = None
class ComputeCell:
def __init__(self, inputs, compute_function):
self.value = None
def add_callback(self, callback):
pass
def remove_callback(self, callback):
pass

View file

@ -0,0 +1,271 @@
# These tests are auto-generated with test data from:
# https://github.com/exercism/problem-specifications/tree/main/exercises/react/canonical-data.json
# File last updated on 2023-07-19
from functools import partial
import unittest
from react import (
InputCell,
ComputeCell,
)
class ReactTest(unittest.TestCase):
def test_input_cells_have_a_value(self):
input = InputCell(10)
self.assertEqual(input.value, 10)
def test_an_input_cell_s_value_can_be_set(self):
input = InputCell(4)
input.value = 20
self.assertEqual(input.value, 20)
def test_compute_cells_calculate_initial_value(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
self.assertEqual(output.value, 2)
def test_compute_cells_take_inputs_in_the_right_order(self):
one = InputCell(1)
two = InputCell(2)
output = ComputeCell(
[
one,
two,
],
lambda inputs: inputs[0] + inputs[1] * 10,
)
self.assertEqual(output.value, 21)
def test_compute_cells_update_value_when_dependencies_are_changed(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
input.value = 3
self.assertEqual(output.value, 4)
def test_compute_cells_can_depend_on_other_compute_cells(self):
input = InputCell(1)
times_two = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 2,
)
times_thirty = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 30,
)
output = ComputeCell(
[
times_two,
times_thirty,
],
lambda inputs: inputs[0] + inputs[1],
)
self.assertEqual(output.value, 32)
input.value = 3
self.assertEqual(output.value, 96)
def test_compute_cells_fire_callbacks(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callback_cells_only_fire_on_change(self):
input = InputCell(1)
output = ComputeCell([input], lambda inputs: 111 if inputs[0] < 3 else 222)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer[-1], 222)
def test_callbacks_do_not_report_already_reported_values(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer[-1], 3)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callbacks_can_fire_from_multiple_cells(self):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
plus_one.add_callback(callback1)
minus_one.add_callback(callback2)
input.value = 10
self.assertEqual(cb1_observer[-1], 11)
self.assertEqual(cb2_observer[-1], 9)
def test_callbacks_can_be_added_and_removed(self):
input = InputCell(11)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
cb3_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
callback3 = self.callback_factory(cb3_observer)
output.add_callback(callback1)
output.add_callback(callback2)
input.value = 31
self.assertEqual(cb1_observer[-1], 32)
self.assertEqual(cb2_observer[-1], 32)
output.remove_callback(callback1)
output.add_callback(callback3)
input.value = 41
self.assertEqual(len(cb1_observer), 1)
self.assertEqual(cb2_observer[-1], 42)
self.assertEqual(cb3_observer[-1], 42)
def test_removing_a_callback_multiple_times_doesn_t_interfere_with_other_callbacks(
self,
):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
output.add_callback(callback1)
output.add_callback(callback2)
output.remove_callback(callback1)
output.remove_callback(callback1)
output.remove_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
self.assertEqual(cb2_observer[-1], 3)
def test_callbacks_should_only_be_called_once_even_if_multiple_dependencies_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one1 = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
minus_one2 = ComputeCell(
[
minus_one1,
],
lambda inputs: inputs[0] - 1,
)
output = ComputeCell(
[
plus_one,
minus_one2,
],
lambda inputs: inputs[0] * inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 4
self.assertEqual(cb1_observer[-1], 10)
def test_callbacks_should_not_be_called_if_dependencies_change_but_output_value_doesn_t_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
always_two = ComputeCell(
[
plus_one,
minus_one,
],
lambda inputs: inputs[0] - inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
always_two.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 3
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer, [])
input.value = 5
self.assertEqual(cb1_observer, [])
# Utility functions.
def callback_factory(self, observer):
def callback(observer, value):
observer.append(value)
return partial(callback, observer)