Allow StreamingEditFileTool to also create files (#29785)

Refs #29733 

This pull request introduces a new field to the `StreamingEditFileTool`
that lets the model create or overwrite a file in a streaming way. When
one of the `assistant.stream_edits` setting / `agent-stream-edits`
feature flag is enabled, we are going to disable the `CreateFileTool` so
that the agent model can only use `StreamingEditFileTool` for file
creation.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-05-02 11:57:04 +02:00 committed by GitHub
parent f619d5f02a
commit 35539847a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 2914 additions and 140 deletions

View file

@ -10,6 +10,7 @@ use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
use futures::{
Stream, StreamExt,
channel::mpsc::{self, UnboundedReceiver},
pin_mut,
stream::BoxStream,
};
use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
@ -23,19 +24,29 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
#[derive(Serialize)]
pub struct EditAgentTemplate {
struct CreateFilePromptTemplate {
path: Option<PathBuf>,
edit_description: String,
}
impl Template for EditAgentTemplate {
const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
impl Template for CreateFilePromptTemplate {
const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
#[derive(Serialize)]
struct EditFilePromptTemplate {
path: Option<PathBuf>,
edit_description: String,
}
impl Template for EditFilePromptTemplate {
const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
Edited,
HallucinatedOldText(SharedString),
OldTextNotFound(SharedString),
}
#[derive(Clone, Debug)]
@ -64,6 +75,82 @@ impl EditAgent {
}
}
pub fn overwrite(
&self,
buffer: Entity<Buffer>,
edit_description: String,
previous_messages: Vec<LanguageModelRequestMessage>,
cx: &mut AsyncApp,
) -> (
Task<Result<EditAgentOutput>>,
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
) {
let this = self.clone();
let (events_tx, events_rx) = mpsc::unbounded();
let output = cx.spawn(async move |cx| {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = CreateFilePromptTemplate {
path,
edit_description,
}
.render(&this.templates)?;
let new_chunks = this.request(previous_messages, prompt, cx).await?;
let (output, mut inner_events) = this.replace_text_with_chunks(buffer, new_chunks, cx);
while let Some(event) = inner_events.next().await {
events_tx.unbounded_send(event).ok();
}
output.await
});
(output, events_rx)
}
fn replace_text_with_chunks(
&self,
buffer: Entity<Buffer>,
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
cx: &mut AsyncApp,
) -> (
Task<Result<EditAgentOutput>>,
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
) {
let (output_events_tx, output_events_rx) = mpsc::unbounded();
let this = self.clone();
let task = cx.spawn(async move |cx| {
// Ensure the buffer is tracked by the action log.
this.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
let mut raw_edits = String::new();
pin_mut!(edit_chunks);
while let Some(chunk) = edit_chunks.next().await {
let chunk = chunk?;
raw_edits.push_str(&chunk);
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
}
Ok(EditAgentOutput {
_raw_edits: raw_edits,
_parser_metrics: EditParserMetrics::default(),
})
});
(task, output_events_rx)
}
pub fn edit(
&self,
buffer: Entity<Buffer>,
@ -78,10 +165,15 @@ impl EditAgent {
let (events_tx, events_rx) = mpsc::unbounded();
let output = cx.spawn(async move |cx| {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let edit_chunks = this
.request_edits(snapshot, edit_description, previous_messages, cx)
.await?;
let (output, mut inner_events) = this.apply_edits(buffer, edit_chunks, cx);
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = EditFilePromptTemplate {
path,
edit_description,
}
.render(&this.templates)?;
let edit_chunks = this.request(previous_messages, prompt, cx).await?;
let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
while let Some(event) = inner_events.next().await {
events_tx.unbounded_send(event).ok();
}
@ -90,7 +182,7 @@ impl EditAgent {
(output, events_rx)
}
fn apply_edits(
fn apply_edit_chunks(
&self,
buffer: Entity<Buffer>,
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
@ -138,7 +230,7 @@ impl EditAgent {
let Some(old_range) = old_range else {
// We couldn't find the old text in the buffer. Report the error.
output_events
.unbounded_send(EditAgentOutputEvent::HallucinatedOldText(old_text_query))
.unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
.ok();
continue;
};
@ -232,7 +324,7 @@ impl EditAgent {
) {
let (tx, rx) = mpsc::unbounded();
let output = cx.background_spawn(async move {
futures::pin_mut!(chunks);
pin_mut!(chunks);
let mut parser = EditParser::new();
let mut raw_edits = String::new();
@ -336,20 +428,12 @@ impl EditAgent {
})
}
async fn request_edits(
async fn request(
&self,
snapshot: BufferSnapshot,
edit_description: String,
mut messages: Vec<LanguageModelRequestMessage>,
prompt: String,
cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
let prompt = EditAgentTemplate {
path,
edit_description,
}
.render(&self.templates)?;
let mut message_content = Vec::new();
if let Some(last_message) = messages.last_mut() {
if last_message.role == Role::Assistant {
@ -611,7 +695,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
pretty_assertions::assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -648,7 +733,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -679,7 +765,8 @@ mod tests {
&mut rng,
cx,
);
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
let (apply, _events) =
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
apply.await.unwrap();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@ -692,7 +779,7 @@ mod tests {
let agent = init_test(cx).await;
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
let (apply, mut events) = agent.apply_edits(
let (apply, mut events) = agent.apply_edit_chunks(
buffer.clone(),
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
&mut cx.to_async(),
@ -744,7 +831,7 @@ mod tests {
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::HallucinatedOldText(
vec![EditAgentOutputEvent::OldTextNotFound(
"hallucinated old".into()
)]
);