Add the ability to follow the agent as it makes edits (#29839)

Nathan here: I also tacked on a bunch of UI refinements.

Release Notes:

- Introduced the ability to follow the agent around as it reads and
edits files.

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Antonio Scandurra 2025-05-04 10:28:39 +02:00 committed by GitHub
parent 425f32e068
commit 545ae27079
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 1255 additions and 567 deletions

View file

@ -19,6 +19,7 @@ use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
MessageContent, Role,
};
use project::{AgentLocation, Project};
use serde::Serialize;
use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
@ -59,17 +60,20 @@ pub struct EditAgentOutput {
/// Agent that applies streamed, model-produced edits to buffers.
///
/// While it edits, it reports changes to the action log and publishes its
/// current position through the project's agent location.
pub struct EditAgent {
/// Language model supplied at construction.
model: Arc<dyn LanguageModel>,
/// Log that is told which buffers are tracked and when they were edited.
action_log: Entity<ActionLog>,
/// Project whose agent location is updated as edits are applied.
project: Entity<Project>,
/// Templates supplied at construction.
templates: Arc<Templates>,
}
impl EditAgent {
pub fn new(
model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
) -> Self {
EditAgent {
model,
project,
action_log,
templates,
}
@ -118,39 +122,74 @@ impl EditAgent {
let (output_events_tx, output_events_rx) = mpsc::unbounded();
let this = self.clone();
let task = cx.spawn(async move |cx| {
// Ensure the buffer is tracked by the action log.
this.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
let mut raw_edits = String::new();
pin_mut!(edit_chunks);
while let Some(chunk) = edit_chunks.next().await {
let chunk = chunk?;
raw_edits.push_str(&chunk);
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
this.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
})?;
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
}
Ok(EditAgentOutput {
_raw_edits: raw_edits,
_parser_metrics: EditParserMetrics::default(),
})
let output = this
.replace_text_with_chunks_internal(buffer, edit_chunks, output_events_tx, cx)
.await;
this.project
.update(cx, |project, cx| project.set_agent_location(None, cx))?;
output
});
(task, output_events_rx)
}
/// Empties `buffer`, then appends every chunk from `edit_chunks` as it arrives.
///
/// Each buffer mutation is reported to the action log inside the same
/// `cx.update` call that performs it, and the project's agent location is set
/// to `Anchor::MAX` of `buffer` after every step. An
/// `EditAgentOutputEvent::Edited` is sent after the initial clear and after
/// each appended chunk; a failed send is ignored.
///
/// Returns the concatenated chunks as `_raw_edits` together with default
/// parser metrics. The first chunk error, or a failing `cx.update`, ends the
/// stream early and is returned to the caller.
async fn replace_text_with_chunks_internal(
    &self,
    buffer: Entity<Buffer>,
    edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
    output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
    cx: &mut AsyncApp,
) -> Result<EditAgentOutput> {
    // Start from an empty buffer; the clear itself counts as the first edit.
    cx.update(|app| {
        buffer.update(app, |buf, app| buf.set_text("", app));
        self.action_log.update(app, |action_log, app| {
            action_log.buffer_edited(buffer.clone(), app);
        });
        self.project.update(app, |project, app| {
            let end_of_buffer = AgentLocation {
                buffer: buffer.downgrade(),
                position: language::Anchor::MAX,
            };
            project.set_agent_location(Some(end_of_buffer), app)
        });
        output_events_tx
            .unbounded_send(EditAgentOutputEvent::Edited)
            .ok();
    })?;

    let mut accumulated = String::new();
    pin_mut!(edit_chunks);
    loop {
        let Some(next) = edit_chunks.next().await else {
            break;
        };
        let text = next?;
        accumulated.push_str(&text);

        // Append and report in one update so the log sees the edit immediately.
        cx.update(|app| {
            buffer.update(app, |buf, app| buf.append(text, app));
            self.action_log.update(app, |action_log, app| {
                action_log.buffer_edited(buffer.clone(), app)
            });
            self.project.update(app, |project, app| {
                let end_of_buffer = AgentLocation {
                    buffer: buffer.downgrade(),
                    position: language::Anchor::MAX,
                };
                project.set_agent_location(Some(end_of_buffer), app)
            });
        })?;
        output_events_tx
            .unbounded_send(EditAgentOutputEvent::Edited)
            .ok();
    }

    Ok(EditAgentOutput {
        _raw_edits: accumulated,
        _parser_metrics: EditParserMetrics::default(),
    })
}
pub fn edit(
&self,
buffer: Entity<Buffer>,
@ -161,6 +200,18 @@ impl EditAgent {
Task<Result<EditAgentOutput>>,
mpsc::UnboundedReceiver<EditAgentOutputEvent>,
) {
self.project
.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MIN,
}),
cx,
);
})
.ok();
let this = self.clone();
let (events_tx, events_rx) = mpsc::unbounded();
let output = cx.spawn(async move |cx| {
@ -194,8 +245,14 @@ impl EditAgent {
let (output_events_tx, output_events_rx) = mpsc::unbounded();
let this = self.clone();
let task = cx.spawn(async move |mut cx| {
this.apply_edits_internal(buffer, edit_chunks, output_events_tx, &mut cx)
.await
this.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
let output = this
.apply_edits_internal(buffer, edit_chunks, output_events_tx, &mut cx)
.await;
this.project
.update(cx, |project, cx| project.set_agent_location(None, cx))?;
output
});
(task, output_events_rx)
}
@ -207,10 +264,6 @@ impl EditAgent {
output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
cx: &mut AsyncApp,
) -> Result<EditAgentOutput> {
// Ensure the buffer is tracked by the action log.
self.action_log
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
while let Some(edit_event) = edit_events.next().await {
let EditParserEvent::OldText(old_text_query) = edit_event? else {
@ -275,14 +328,15 @@ impl EditAgent {
match op {
CharOperation::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
edits_tx.unbounded_send((edit_start..edit_start, text))?;
edits_tx
.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
}
CharOperation::Delete { bytes } => {
let edit_end = edit_start + bytes;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
edits_tx.unbounded_send((edit_range, String::new()))?;
edits_tx.unbounded_send((edit_range, Arc::from("")))?;
}
CharOperation::Keep { bytes } => edit_start += bytes,
}
@ -296,13 +350,35 @@ impl EditAgent {
// TODO: group all edits into one transaction
let mut edits_rx = edits_rx.ready_chunks(32);
while let Some(edits) = edits_rx.next().await {
if edits.is_empty() {
continue;
}
// Edit the buffer and report edits to the action log as part of the
// same effect cycle, otherwise the edit will be reported as if the
// user made it.
cx.update(|cx| {
buffer.update(cx, |buffer, cx| buffer.edit(edits, None, cx));
let max_edit_end = buffer.update(cx, |buffer, cx| {
buffer.edit(edits.iter().cloned(), None, cx);
let max_edit_end = buffer
.summaries_for_anchors::<Point, _>(
edits.iter().map(|(range, _)| &range.end),
)
.max()
.unwrap();
buffer.anchor_before(max_edit_end)
});
self.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
self.project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: max_edit_end,
}),
cx,
);
});
})?;
output_events
.unbounded_send(EditAgentOutputEvent::Edited)
@ -657,7 +733,7 @@ mod tests {
use gpui::{App, AppContext, TestAppContext};
use indoc::indoc;
use language_model::fake_provider::FakeLanguageModel;
use project::Project;
use project::{AgentLocation, Project};
use rand::prelude::*;
use rand::rngs::StdRng;
use std::cmp;
@ -775,8 +851,11 @@ mod tests {
}
#[gpui::test]
async fn test_events(cx: &mut TestAppContext) {
async fn test_edit_events(cx: &mut TestAppContext) {
let agent = init_test(cx).await;
let project = agent
.action_log
.read_with(cx, |log, _| log.project().clone());
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
let (apply, mut events) = agent.apply_edit_chunks(
@ -792,6 +871,10 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abc\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
None
);
chunks_tx.unbounded_send("bc</old_text>").unwrap();
cx.run_until_parked();
@ -800,6 +883,10 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abc\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
None
);
chunks_tx.unbounded_send("<new_text>abX").unwrap();
cx.run_until_parked();
@ -808,6 +895,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXc\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
})
);
chunks_tx.unbounded_send("cY").unwrap();
cx.run_until_parked();
@ -816,6 +910,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("</new_text>").unwrap();
chunks_tx.unbounded_send("<old_text>hall").unwrap();
@ -825,6 +926,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("ucinated old</old_text>").unwrap();
chunks_tx.unbounded_send("<new_text>").unwrap();
@ -839,6 +947,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("hallucinated new</new_").unwrap();
chunks_tx.unbounded_send("text>").unwrap();
@ -848,6 +963,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("<old_text>gh").unwrap();
chunks_tx.unbounded_send("i</old_text>").unwrap();
@ -858,6 +980,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
})
);
chunks_tx.unbounded_send("GHI</new_text>").unwrap();
cx.run_until_parked();
@ -869,6 +998,13 @@ mod tests {
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nGHI"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
})
);
drop(chunks_tx);
apply.await.unwrap();
@ -877,16 +1013,108 @@ mod tests {
"abXcY\ndef\nGHI"
);
assert_eq!(drain_events(&mut events), vec![]);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
None
);
}
/// Collects every event that is immediately available on `stream`.
///
/// Never awaits: collection stops at the first `try_next` call that does not
/// yield `Ok(Some(_))`.
fn drain_events(
    stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
) -> Vec<EditAgentOutputEvent> {
    std::iter::from_fn(|| stream.try_next().ok().flatten()).collect()
}
// Drives `replace_text_with_chunks` chunk by chunk and checks, after every
// step, the emitted events, the buffer text, and the project's agent location.
#[gpui::test]
async fn test_overwrite_events(cx: &mut TestAppContext) {
let agent = init_test(cx).await;
let project = agent
.action_log
.read_with(cx, |log, _| log.project().clone());
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
let (chunks_tx, chunks_rx) = mpsc::unbounded();
let (apply, mut events) = agent.replace_text_with_chunks(
buffer.clone(),
chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
&mut cx.to_async(),
);
// Before any chunk is sent: the buffer has been emptied, one `Edited` event
// was emitted, and the agent location points at `Anchor::MAX`.
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
""
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX
})
);
// First chunk: text is appended, one event, location still `Anchor::MAX`.
chunks_tx.unbounded_send("jkl\n").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"jkl\n"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX
})
);
// Second chunk: appended after the first.
chunks_tx.unbounded_send("mno\n").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"jkl\nmno\n"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX
})
);
// Third chunk, without a trailing newline.
chunks_tx.unbounded_send("pqr").unwrap();
cx.run_until_parked();
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"jkl\nmno\npqr"
);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MAX
})
);
// Closing the chunk stream lets the task finish: the text is unchanged, no
// further events are pending, and the agent location has been cleared.
drop(chunks_tx);
apply.await.unwrap();
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"jkl\nmno\npqr"
);
assert_eq!(drain_events(&mut events), vec![]);
assert_eq!(
project.read_with(cx, |project, _| project.agent_location()),
None
);
}
#[gpui::test]
@ -1173,7 +1401,17 @@ mod tests {
cx.update(Project::init_settings);
let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
let model = Arc::new(FakeLanguageModel::default());
let action_log = cx.new(|_| ActionLog::new(project));
EditAgent::new(model, action_log, Templates::new())
let action_log = cx.new(|_| ActionLog::new(project.clone()));
EditAgent::new(model, project, action_log, Templates::new())
}
/// Collects every event that is immediately available on `stream`.
///
/// Never awaits: collection stops at the first `try_next` call that does not
/// yield `Ok(Some(_))`.
fn drain_events(
    stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
) -> Vec<EditAgentOutputEvent> {
    std::iter::from_fn(|| stream.try_next().ok().flatten()).collect()
}
}

View file

@ -517,7 +517,7 @@ fn eval_from_pixels_constructor() {
input_path: input_file_path.into(),
input_content: Some(input_file_content.into()),
edit_description: edit_description.into(),
assertion: EvalAssertion::assert_eq(indoc! {"
assertion: EvalAssertion::judge_diff(indoc! {"
- The diff contains a new `from_pixels` constructor
- The diff contains new tests for the `from_pixels` constructor
"}),
@ -957,7 +957,7 @@ impl EditAgentTest {
cx.spawn(async move |cx| {
let agent_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
Self::load_model("google", "gemini-2.5-pro-preview-03-25", cx).await;
let judge_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
(agent_model.unwrap(), judge_model.unwrap())
@ -967,7 +967,7 @@ impl EditAgentTest {
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self {
agent: EditAgent::new(agent_model, action_log, Templates::new()),
agent: EditAgent::new(agent_model, project.clone(), action_log, Templates::new()),
project,
judge_model,
}

View file

@ -15,7 +15,7 @@ use language::{
language_settings::SoftWrap,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use project::{AgentLocation, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
@ -164,6 +164,19 @@ impl Tool for EditFileTool {
})?
.await?;
// Set the agent's location to the top of the file
project
.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: language::Anchor::MIN,
}),
cx,
);
})
.ok();
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
if input.old_string.is_empty() {
@ -226,6 +239,7 @@ impl Tool for EditFileTool {
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
let base_version = diff.base_version.clone();
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.apply_diff(diff, cx);
@ -233,6 +247,21 @@ impl Tool for EditFileTool {
buffer.snapshot()
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
// Set the agent's location to the position of the first edit
if let Some(first_edit) = snapshot.edits_since::<usize>(&base_version).next() {
let position = snapshot.anchor_before(first_edit.new.start);
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position,
}),
cx,
);
})
}
snapshot
})?;

View file

@ -6,8 +6,9 @@ use gpui::{AnyWindowHandle, App, Entity, Task};
use indoc::formatdoc;
use itertools::Itertools;
use language::{Anchor, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use project::{AgentLocation, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
@ -35,11 +36,11 @@ pub struct ReadFileToolInput {
/// Optional line number to start reading on (1-based index)
#[serde(default)]
pub start_line: Option<usize>,
pub start_line: Option<u32>,
/// Optional line number to end reading on (1-based index, inclusive)
#[serde(default)]
pub end_line: Option<usize>,
pub end_line: Option<u32>,
}
pub struct ReadFileTool;
@ -109,7 +110,7 @@ impl Tool for ReadFileTool {
let file_path = input.path.clone();
cx.spawn(async move |cx| {
if !exists.await? {
return Err(anyhow!("{} not found", file_path))
return Err(anyhow!("{} not found", file_path));
}
let buffer = cx
@ -118,25 +119,54 @@ impl Tool for ReadFileTool {
})?
.await?;
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: Anchor::MIN,
}),
cx,
);
})?;
// Check if specific line ranges are provided
if input.start_line.is_some() || input.end_line.is_some() {
let mut anchor = None;
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
let start = input.start_line.unwrap_or(1).max(1);
let lines = text.split('\n').skip(start - 1);
let start_row = start - 1;
if start_row <= buffer.max_point().row {
let column = buffer.line_indent_for_row(start_row).raw_len();
anchor = Some(buffer.anchor_before(Point::new(start_row, column)));
}
let lines = text.split('\n').skip(start_row as usize);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
Itertools::intersperse(lines.take(count), "\n").collect()
Itertools::intersperse(lines.take(count as usize), "\n").collect()
} else {
Itertools::intersperse(lines, "\n").collect()
}
})?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.track_buffer(buffer.clone(), cx);
})?;
if let Some(anchor) = anchor {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor,
}),
cx,
);
})?;
}
Ok(result)
} else {
// No line ranges specified, so check file size to see if it's too big.
@ -165,7 +195,8 @@ impl Tool for ReadFileTool {
})
}
}
}).into()
})
.into()
}
}

View file

@ -170,7 +170,7 @@ impl Tool for StreamingEditFileTool {
.update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
.context("default model not set")?
.model;
let edit_agent = EditAgent::new(model, action_log, Templates::new());
let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
let buffer = project
.update(cx, |project, cx| {