Allow StreamingEditFileTool to also create files (#29785)

Refs #29733 

This pull request introduces a new field to the `StreamingEditFileTool`
that lets the model create or overwrite a file in a streaming way. When
one of the `assistant.stream_edits` setting / `agent-stream-edits`
feature flag is enabled, we are going to disable the `CreateFileTool` so
that the agent model can only use `StreamingEditFileTool` for file
creation.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-05-02 11:57:04 +02:00 committed by GitHub
parent f619d5f02a
commit 35539847a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 2914 additions and 140 deletions

View file

@ -77,7 +77,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(SymbolInfoTool);
@ -125,12 +124,14 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx);
registry.unregister_tool(CreateFileTool);
registry.unregister_tool(EditFileTool);
registry.unregister_tool(StreamingEditFileTool);
if AssistantSettings::get_global(cx).stream_edits(cx) {
registry.register_tool(StreamingEditFileTool);
} else {
registry.register_tool(CreateFileTool);
registry.register_tool(EditFileTool);
}
}

View file

@ -10,6 +10,7 @@ use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
use futures::{
Stream, StreamExt,
channel::mpsc::{self, UnboundedReceiver},
pin_mut,
stream::BoxStream,
};
use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
@ -23,19 +24,29 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
#[derive(Serialize)]
pub struct EditAgentTemplate {
struct CreateFilePromptTemplate {
path: Option<PathBuf>,
edit_description: String,
}
impl Template for EditAgentTemplate {
const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
impl Template for CreateFilePromptTemplate {
const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
#[derive(Serialize)]
struct EditFilePromptTemplate {
path: Option<PathBuf>,
edit_description: String,
}
impl Template for EditFilePromptTemplate {
const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
Edited,
HallucinatedOldText(SharedString),
OldTextNotFound(SharedString),
}
#[derive(Clone, Debug)]
@ -64,6 +75,82 @@ impl EditAgent {
}
}
pub fn overwrite(
&self,
buffer: Entity<Buffer>,
edit_description: String,
previous_messages: Vec<LanguageModelRequestMessage>,
cx: &mut AsyncApp,
) -> (
Task<Result<EditAgentOutput>>,
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
) {
let this = self.clone();
let (events_tx, events_rx) = mpsc::unbounded();
let output = cx.spawn(async move |cx| {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = CreateFilePromptTemplate {
path,
edit_description,
}
.render(&this.templates)?;
let new_chunks = this.request(previous_messages, prompt, cx).await?;
let (output, mut inner_events) = this.replace_text_with_chunks(buffer, new_chunks, cx);
while let Some(event) = inner_events.next().await {
events_tx.unbounded_send(event).ok();
}
output.await
});
(output, events_rx)
}
fn replace_text_with_chunks(
&self,
buffer: Entity<Buffer>,
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
cx: &mut AsyncApp,
) -> (
Task<Result<EditAgentOutput>>,
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
) {
let (output_events_tx, output_events_rx) = mpsc::unbounded();
let this = self.clone();
let task = cx.spawn(async move |cx| {
// Ensure the buffer is tracked by the action log.
this.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
let mut raw_edits = String::new();
pin_mut!(edit_chunks);
while let Some(chunk) = edit_chunks.next().await {
let chunk = chunk?;
raw_edits.push_str(&chunk);
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
}
Ok(EditAgentOutput {
_raw_edits: raw_edits,
_parser_metrics: EditParserMetrics::default(),
})
});
(task, output_events_rx)
}
pub fn edit(
&self,
buffer: Entity<Buffer>,
@ -78,10 +165,15 @@ impl EditAgent {
let (events_tx, events_rx) = mpsc::unbounded();
let output = cx.spawn(async move |cx| {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let edit_chunks = this
.request_edits(snapshot, edit_description, previous_messages, cx)
.await?;
let (output, mut inner_events) = this.apply_edits(buffer, edit_chunks, cx);
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = EditFilePromptTemplate {
path,
edit_description,
}
.render(&this.templates)?;
let edit_chunks = this.request(previous_messages, prompt, cx).await?;
let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
while let Some(event) = inner_events.next().await {
events_tx.unbounded_send(event).ok();
}
@ -90,7 +182,7 @@ impl EditAgent {
(output, events_rx)
}
fn apply_edits(
fn apply_edit_chunks(
&self,
buffer: Entity<Buffer>,
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
@ -138,7 +230,7 @@ impl EditAgent {
let Some(old_range) = old_range else {
// We couldn't find the old text in the buffer. Report the error.
output_events
.unbounded_send(EditAgentOutputEvent::HallucinatedOldText(old_text_query))
.unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
.ok();
continue;
};
@ -232,7 +324,7 @@ impl EditAgent {
) {
let (tx, rx) = mpsc::unbounded();
let output = cx.background_spawn(async move {
futures::pin_mut!(chunks);
pin_mut!(chunks);
let mut parser = EditParser::new();
let mut raw_edits = String::new();
@ -336,20 +428,12 @@ impl EditAgent {
})
}
async fn request_edits(
async fn request(
&self,
snapshot: BufferSnapshot,
edit_description: String,
mut messages: Vec<LanguageModelRequestMessage>,
prompt: String,
cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = EditAgentTemplate {
path,
edit_description,
}
.render(&self.templates)?;
let mut message_content = Vec::new();
if let Some(last_message) = messages.last_mut() {
if last_message.role == Role::Assistant {
@ -611,7 +695,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
pretty_assertions::assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -648,7 +733,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -679,7 +765,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -692,7 +779,7 @@ mod tests {
let agent = init_test(cx).await;
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
let (apply, mut events) = agent.apply_edits(
let (apply, mut events) = agent.apply_edit_chunks(
buffer.clone(),
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
&mut cx.to_async(),
@ -744,7 +831,7 @@ mod tests {
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::HallucinatedOldText(
vec![EditAgentOutputEvent::OldTextNotFound(
"hallucinated old".into()
)]
);

View file

@ -4,10 +4,11 @@ use crate::{
streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use anyhow::anyhow;
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
@ -71,14 +72,15 @@ fn eval_extract_handle_command_output() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -126,14 +128,15 @@ fn eval_delete_run_git_blame() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -240,14 +243,15 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
assertion: EvalAssertion::judge_diff(indoc! {"
- The compile_parser_to_wasm method has been changed to use wasi-sdk
- ureq is used to download the SDK for current platform and architecture
"}),
@ -315,14 +319,15 @@ fn eval_disable_cursor_blinking() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
assertion: EvalAssertion::assert_eq(output_file_content),
},
);
}
@ -504,14 +509,15 @@ fn eval_from_pixels_constructor() {
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: false,
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
assertion: EvalAssertion::assert_eq(indoc! {"
- The diff contains a new `from_pixels` constructor
- The diff contains new tests for the `from_pixels` constructor
"}),
@ -519,6 +525,104 @@ fn eval_from_pixels_constructor() {
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_zode() {
let input_file_path = "root/zode.py";
let edit_description = "Create the main Zode CLI script";
eval(
200,
1.,
EvalInput {
conversation: vec![
message(User, [text(include_str!("evals/fixtures/zode/prompt.md"))]),
message(
Assistant,
[
tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: "root/eval/react.py".into(),
start_line: None,
end_line: None,
},
),
tool_use(
"tool_2",
"read_file",
ReadFileToolInput {
path: "root/eval/react_test.py".into(),
start_line: None,
end_line: None,
},
),
],
),
message(
User,
[
tool_result(
"tool_1",
"read_file",
include_str!("evals/fixtures/zode/react.py"),
),
tool_result(
"tool_2",
"read_file",
include_str!("evals/fixtures/zode/react_test.py"),
),
],
),
message(
Assistant,
[
text(
"Now that I understand what we need to build, I'll create the main Python script:",
),
tool_use(
"tool_3",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
create_or_overwrite: true,
},
),
],
),
],
input_path: input_file_path.into(),
input_content: None,
edit_description: edit_description.into(),
assertion: EvalAssertion::new(async move |sample, _, _cx| {
let invalid_starts = [' ', '`', '\n'];
let mut message = String::new();
for start in invalid_starts {
if sample.text.starts_with(start) {
message.push_str(&format!("The sample starts with a {:?}\n", start));
break;
}
}
// Remove trailing newline.
message.pop();
if message.is_empty() {
Ok(EvalAssertionOutcome {
score: 100,
message: None,
})
} else {
Ok(EvalAssertionOutcome {
score: 0,
message: Some(message),
})
}
}),
},
);
}
fn message(
role: Role,
contents: impl IntoIterator<Item = MessageContent>,
@ -574,11 +678,135 @@ fn tool_result(
struct EvalInput {
conversation: Vec<LanguageModelRequestMessage>,
input_path: PathBuf,
input_content: String,
input_content: Option<String>,
edit_description: String,
assertion: EvalAssertion,
}
#[derive(Clone)]
struct EvalSample {
text: String,
edit_output: EditAgentOutput,
diff: String,
}
trait AssertionFn: 'static + Send + Sync {
fn assert<'a>(
&'a self,
sample: &'a EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &'a mut TestAppContext,
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>>;
}
impl<F> AssertionFn for F
where
F: 'static
+ Send
+ Sync
+ AsyncFn(
&EvalSample,
Arc<dyn LanguageModel>,
&mut TestAppContext,
) -> Result<EvalAssertionOutcome>,
{
fn assert<'a>(
&'a self,
sample: &'a EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &'a mut TestAppContext,
) -> LocalBoxFuture<'a, Result<EvalAssertionOutcome>> {
(self)(sample, judge_model, cx).boxed_local()
}
}
#[derive(Clone)]
struct EvalAssertion(Arc<dyn AssertionFn>);
impl EvalAssertion {
fn new<F>(f: F) -> Self
where
F: 'static
+ Send
+ Sync
+ AsyncFn(
&EvalSample,
Arc<dyn LanguageModel>,
&mut TestAppContext,
) -> Result<EvalAssertionOutcome>,
{
EvalAssertion(Arc::new(f))
}
fn assert_eq(expected: impl Into<String>) -> Self {
let expected = expected.into();
Self::new(async move |sample, _judge, _cx| {
Ok(EvalAssertionOutcome {
score: if strip_empty_lines(&sample.text) == strip_empty_lines(&expected) {
100
} else {
0
},
message: None,
})
})
}
fn judge_diff(assertions: &'static str) -> Self {
Self::new(async move |sample, judge, cx| {
let prompt = DiffJudgeTemplate {
diff: sample.diff.clone(),
assertions,
}
.render(&Templates::new())
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
};
let mut response = judge
.stream_completion_text(request, &cx.to_async())
.await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionOutcome {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
})
}
async fn run(
&self,
input: &EvalSample,
judge_model: Arc<dyn LanguageModel>,
cx: &mut TestAppContext,
) -> Result<EvalAssertionOutcome> {
self.0.assert(input, judge_model, cx).await
}
}
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
let mut evaluated_count = 0;
report_progress(evaluated_count, iterations);
@ -606,12 +834,12 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
while let Ok(output) = rx.recv() {
match output {
Ok(output) => {
cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
eval_outputs.push(output.clone());
if output.assertion.score < 80 {
failed_count += 1;
failed_evals
.entry(output.buffer_text.clone())
.entry(output.sample.text.clone())
.or_insert(Vec::new())
.push(output);
}
@ -671,10 +899,8 @@ fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
#[derive(Clone)]
struct EvalOutput {
assertion: EvalAssertionResult,
buffer_text: String,
edit_output: EditAgentOutput,
diff: String,
sample: EvalSample,
assertion: EvalAssertionOutcome,
}
impl Display for EvalOutput {
@ -684,14 +910,14 @@ impl Display for EvalOutput {
writeln!(f, "Message: {}", message)?;
}
writeln!(f, "Diff:\n{}", self.diff)?;
writeln!(f, "Diff:\n{}", self.sample.diff)?;
writeln!(
f,
"Parser Metrics:\n{:#?}",
self.edit_output._parser_metrics
self.sample.edit_output._parser_metrics
)?;
writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
Ok(())
}
}
@ -777,96 +1003,45 @@ impl EditAgentTest {
.update(cx, |project, cx| project.open_buffer(path, cx))
.await
.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text(eval.input_content.clone(), cx)
});
let (edit_output, _events) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
let edit_output = edit_output.await?;
let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
let (edit_output, _) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
edit_output.await?
} else {
let (edit_output, _) = self.agent.overwrite(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
edit_output.await?
};
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
let assertion = match eval.assertion {
EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
100
} else {
0
},
message: None,
},
EvalAssertion::JudgeDiff(assertions) => self
.judge_diff(&actual_diff, assertions, &cx.to_async())
.await
.context("failed comparing diffs")?,
};
Ok(EvalOutput {
assertion,
diff: actual_diff,
buffer_text,
let sample = EvalSample {
edit_output,
})
}
async fn judge_diff(
&self,
diff: &str,
assertions: &'static str,
cx: &AsyncApp,
) -> Result<EvalAssertionResult> {
let prompt = DiffJudgeTemplate {
diff: diff.to_string(),
assertions,
}
.render(&self.agent.templates)
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
diff: language::unified_diff(
eval.input_content.as_deref().unwrap_or_default(),
&buffer_text,
),
text: buffer_text,
};
let mut response = self.judge_model.stream_completion_text(request, cx).await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
let assertion = eval
.assertion
.run(&sample, self.judge_model.clone(), cx)
.await?;
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionResult {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
Ok(EvalOutput { assertion, sample })
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
AssertEqual(String),
JudgeDiff(&'static str),
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
struct EvalAssertionOutcome {
score: usize,
message: Option<String>,
}

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,14 @@
class InputCell:
def __init__(self, initial_value):
self.value = None
class ComputeCell:
def __init__(self, inputs, compute_function):
self.value = None
def add_callback(self, callback):
pass
def remove_callback(self, callback):
pass

View file

@ -0,0 +1,271 @@
# These tests are auto-generated with test data from:
# https://github.com/exercism/problem-specifications/tree/main/exercises/react/canonical-data.json
# File last updated on 2023-07-19
from functools import partial
import unittest
from react import (
InputCell,
ComputeCell,
)
class ReactTest(unittest.TestCase):
def test_input_cells_have_a_value(self):
input = InputCell(10)
self.assertEqual(input.value, 10)
def test_an_input_cell_s_value_can_be_set(self):
input = InputCell(4)
input.value = 20
self.assertEqual(input.value, 20)
def test_compute_cells_calculate_initial_value(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
self.assertEqual(output.value, 2)
def test_compute_cells_take_inputs_in_the_right_order(self):
one = InputCell(1)
two = InputCell(2)
output = ComputeCell(
[
one,
two,
],
lambda inputs: inputs[0] + inputs[1] * 10,
)
self.assertEqual(output.value, 21)
def test_compute_cells_update_value_when_dependencies_are_changed(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
input.value = 3
self.assertEqual(output.value, 4)
def test_compute_cells_can_depend_on_other_compute_cells(self):
input = InputCell(1)
times_two = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 2,
)
times_thirty = ComputeCell(
[
input,
],
lambda inputs: inputs[0] * 30,
)
output = ComputeCell(
[
times_two,
times_thirty,
],
lambda inputs: inputs[0] + inputs[1],
)
self.assertEqual(output.value, 32)
input.value = 3
self.assertEqual(output.value, 96)
def test_compute_cells_fire_callbacks(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callback_cells_only_fire_on_change(self):
input = InputCell(1)
output = ComputeCell([input], lambda inputs: 111 if inputs[0] < 3 else 222)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer[-1], 222)
def test_callbacks_do_not_report_already_reported_values(self):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer[-1], 3)
input.value = 3
self.assertEqual(cb1_observer[-1], 4)
def test_callbacks_can_fire_from_multiple_cells(self):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
plus_one.add_callback(callback1)
minus_one.add_callback(callback2)
input.value = 10
self.assertEqual(cb1_observer[-1], 11)
self.assertEqual(cb2_observer[-1], 9)
def test_callbacks_can_be_added_and_removed(self):
input = InputCell(11)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
cb3_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
callback3 = self.callback_factory(cb3_observer)
output.add_callback(callback1)
output.add_callback(callback2)
input.value = 31
self.assertEqual(cb1_observer[-1], 32)
self.assertEqual(cb2_observer[-1], 32)
output.remove_callback(callback1)
output.add_callback(callback3)
input.value = 41
self.assertEqual(len(cb1_observer), 1)
self.assertEqual(cb2_observer[-1], 42)
self.assertEqual(cb3_observer[-1], 42)
def test_removing_a_callback_multiple_times_doesn_t_interfere_with_other_callbacks(
self,
):
input = InputCell(1)
output = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
cb1_observer = []
cb2_observer = []
callback1 = self.callback_factory(cb1_observer)
callback2 = self.callback_factory(cb2_observer)
output.add_callback(callback1)
output.add_callback(callback2)
output.remove_callback(callback1)
output.remove_callback(callback1)
output.remove_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
self.assertEqual(cb2_observer[-1], 3)
def test_callbacks_should_only_be_called_once_even_if_multiple_dependencies_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one1 = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
minus_one2 = ComputeCell(
[
minus_one1,
],
lambda inputs: inputs[0] - 1,
)
output = ComputeCell(
[
plus_one,
minus_one2,
],
lambda inputs: inputs[0] * inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
output.add_callback(callback1)
input.value = 4
self.assertEqual(cb1_observer[-1], 10)
def test_callbacks_should_not_be_called_if_dependencies_change_but_output_value_doesn_t_change(
self,
):
input = InputCell(1)
plus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] + 1,
)
minus_one = ComputeCell(
[
input,
],
lambda inputs: inputs[0] - 1,
)
always_two = ComputeCell(
[
plus_one,
minus_one,
],
lambda inputs: inputs[0] - inputs[1],
)
cb1_observer = []
callback1 = self.callback_factory(cb1_observer)
always_two.add_callback(callback1)
input.value = 2
self.assertEqual(cb1_observer, [])
input.value = 3
self.assertEqual(cb1_observer, [])
input.value = 4
self.assertEqual(cb1_observer, [])
input.value = 5
self.assertEqual(cb1_observer, [])
# Utility functions.
def callback_factory(self, observer):
def callback(observer, value):
observer.append(value)
return partial(callback, observer)

View file

@ -38,7 +38,7 @@ pub struct StreamingEditFileToolInput {
/// so that we can display it immediately.
pub display_description: String,
/// The full path of the file to modify in the project.
/// The full path of the file to create or modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
@ -58,6 +58,10 @@ pub struct StreamingEditFileToolInput {
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// If true, this tool will recreate the file from scratch.
/// If false, this tool will produce granular edits to an existing file.
pub create_or_overwrite: bool,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@ -158,7 +162,7 @@ impl Tool for StreamingEditFileTool {
let card_clone = card.clone();
let messages = messages.to_vec();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
if !exists.await? {
if !input.create_or_overwrite && !exists.await? {
return Err(anyhow!("{} not found", input.path.display()));
}
@ -182,12 +186,21 @@ impl Tool for StreamingEditFileTool {
})
.await;
let (output, mut events) = edit_agent.edit(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
);
let (output, mut events) = if input.create_or_overwrite {
edit_agent.overwrite(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
)
} else {
edit_agent.edit(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
)
};
let mut hallucinated_old_text = false;
while let Some(event) = events.next().await {
@ -213,7 +226,7 @@ impl Tool for StreamingEditFileTool {
.log_err();
}
}
EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
}
}
output.await?;

View file

@ -1,4 +1,4 @@
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
Before using this tool:

View file

@ -0,0 +1,12 @@
You are an expert engineer and your task is to write a new file from scratch.
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>
You MUST respond directly with the file's content, without explanations, additional text or triple backticks.
The text you output will be saved verbatim as the content of the file.