Allow StreamingEditFileTool
to also create files (#29785)
Refs #29733 This pull request introduces a new field to the `StreamingEditFileTool` that lets the model create or overwrite a file in a streaming way. When one of the `assistant.stream_edits` setting / `agent-stream-edits` feature flag is enabled, we are going to disable the `CreateFileTool` so that the agent model can only use `StreamingEditFileTool` for file creation. Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com> Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
This commit is contained in:
parent
f619d5f02a
commit
35539847a4
11 changed files with 2914 additions and 140 deletions
|
@ -10,6 +10,7 @@ use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
|
|||
use futures::{
|
||||
Stream, StreamExt,
|
||||
channel::mpsc::{self, UnboundedReceiver},
|
||||
pin_mut,
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
|
||||
|
@ -23,19 +24,29 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
|
|||
use streaming_diff::{CharOperation, StreamingDiff};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct EditAgentTemplate {
|
||||
struct CreateFilePromptTemplate {
|
||||
path: Option<PathBuf>,
|
||||
edit_description: String,
|
||||
}
|
||||
|
||||
impl Template for EditAgentTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
|
||||
impl Template for CreateFilePromptTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EditFilePromptTemplate {
|
||||
path: Option<PathBuf>,
|
||||
edit_description: String,
|
||||
}
|
||||
|
||||
impl Template for EditFilePromptTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum EditAgentOutputEvent {
|
||||
Edited,
|
||||
HallucinatedOldText(SharedString),
|
||||
OldTextNotFound(SharedString),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -64,6 +75,82 @@ impl EditAgent {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn overwrite(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edit_description: String,
|
||||
previous_messages: Vec<LanguageModelRequestMessage>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (
|
||||
Task<Result<EditAgentOutput>>,
|
||||
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
|
||||
) {
|
||||
let this = self.clone();
|
||||
let (events_tx, events_rx) = mpsc::unbounded();
|
||||
let output = cx.spawn(async move |cx| {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
let prompt = CreateFilePromptTemplate {
|
||||
path,
|
||||
edit_description,
|
||||
}
|
||||
.render(&this.templates)?;
|
||||
let new_chunks = this.request(previous_messages, prompt, cx).await?;
|
||||
|
||||
let (output, mut inner_events) = this.replace_text_with_chunks(buffer, new_chunks, cx);
|
||||
while let Some(event) = inner_events.next().await {
|
||||
events_tx.unbounded_send(event).ok();
|
||||
}
|
||||
output.await
|
||||
});
|
||||
(output, events_rx)
|
||||
}
|
||||
|
||||
fn replace_text_with_chunks(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (
|
||||
Task<Result<EditAgentOutput>>,
|
||||
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
|
||||
) {
|
||||
let (output_events_tx, output_events_rx) = mpsc::unbounded();
|
||||
let this = self.clone();
|
||||
let task = cx.spawn(async move |cx| {
|
||||
// Ensure the buffer is tracked by the action log.
|
||||
this.action_log
|
||||
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
|
||||
|
||||
cx.update(|cx| {
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
|
||||
this.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
})?;
|
||||
|
||||
let mut raw_edits = String::new();
|
||||
pin_mut!(edit_chunks);
|
||||
while let Some(chunk) = edit_chunks.next().await {
|
||||
let chunk = chunk?;
|
||||
raw_edits.push_str(&chunk);
|
||||
cx.update(|cx| {
|
||||
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
|
||||
this.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
})?;
|
||||
output_events_tx
|
||||
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||
.ok();
|
||||
}
|
||||
|
||||
Ok(EditAgentOutput {
|
||||
_raw_edits: raw_edits,
|
||||
_parser_metrics: EditParserMetrics::default(),
|
||||
})
|
||||
});
|
||||
(task, output_events_rx)
|
||||
}
|
||||
|
||||
pub fn edit(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
|
@ -78,10 +165,15 @@ impl EditAgent {
|
|||
let (events_tx, events_rx) = mpsc::unbounded();
|
||||
let output = cx.spawn(async move |cx| {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let edit_chunks = this
|
||||
.request_edits(snapshot, edit_description, previous_messages, cx)
|
||||
.await?;
|
||||
let (output, mut inner_events) = this.apply_edits(buffer, edit_chunks, cx);
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
let prompt = EditFilePromptTemplate {
|
||||
path,
|
||||
edit_description,
|
||||
}
|
||||
.render(&this.templates)?;
|
||||
let edit_chunks = this.request(previous_messages, prompt, cx).await?;
|
||||
|
||||
let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
|
||||
while let Some(event) = inner_events.next().await {
|
||||
events_tx.unbounded_send(event).ok();
|
||||
}
|
||||
|
@ -90,7 +182,7 @@ impl EditAgent {
|
|||
(output, events_rx)
|
||||
}
|
||||
|
||||
fn apply_edits(
|
||||
fn apply_edit_chunks(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
|
||||
|
@ -138,7 +230,7 @@ impl EditAgent {
|
|||
let Some(old_range) = old_range else {
|
||||
// We couldn't find the old text in the buffer. Report the error.
|
||||
output_events
|
||||
.unbounded_send(EditAgentOutputEvent::HallucinatedOldText(old_text_query))
|
||||
.unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
|
||||
.ok();
|
||||
continue;
|
||||
};
|
||||
|
@ -232,7 +324,7 @@ impl EditAgent {
|
|||
) {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let output = cx.background_spawn(async move {
|
||||
futures::pin_mut!(chunks);
|
||||
pin_mut!(chunks);
|
||||
|
||||
let mut parser = EditParser::new();
|
||||
let mut raw_edits = String::new();
|
||||
|
@ -336,20 +428,12 @@ impl EditAgent {
|
|||
})
|
||||
}
|
||||
|
||||
async fn request_edits(
|
||||
async fn request(
|
||||
&self,
|
||||
snapshot: BufferSnapshot,
|
||||
edit_description: String,
|
||||
mut messages: Vec<LanguageModelRequestMessage>,
|
||||
prompt: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
let prompt = EditAgentTemplate {
|
||||
path,
|
||||
edit_description,
|
||||
}
|
||||
.render(&self.templates)?;
|
||||
|
||||
let mut message_content = Vec::new();
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
|
@ -611,7 +695,8 @@ mod tests {
|
|||
&mut rng,
|
||||
cx,
|
||||
);
|
||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
let (apply, _events) =
|
||||
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
apply.await.unwrap();
|
||||
pretty_assertions::assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
|
@ -648,7 +733,8 @@ mod tests {
|
|||
&mut rng,
|
||||
cx,
|
||||
);
|
||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
let (apply, _events) =
|
||||
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
apply.await.unwrap();
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
|
@ -679,7 +765,8 @@ mod tests {
|
|||
&mut rng,
|
||||
cx,
|
||||
);
|
||||
let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
let (apply, _events) =
|
||||
agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
|
||||
apply.await.unwrap();
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
|
@ -692,7 +779,7 @@ mod tests {
|
|||
let agent = init_test(cx).await;
|
||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
let (apply, mut events) = agent.apply_edits(
|
||||
let (apply, mut events) = agent.apply_edit_chunks(
|
||||
buffer.clone(),
|
||||
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
|
||||
&mut cx.to_async(),
|
||||
|
@ -744,7 +831,7 @@ mod tests {
|
|||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::HallucinatedOldText(
|
||||
vec![EditAgentOutputEvent::OldTextNotFound(
|
||||
"hallucinated old".into()
|
||||
)]
|
||||
);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue