From 277a3f8d6ffb6f36aab8a1a45b4f8e52d607b8c2 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 4 Apr 2025 13:20:18 +0200 Subject: [PATCH] Implement edit rejection in `ActionLog` (#28080) Release Notes: - Fixed a bug that would prevent rejecting certain agent edits. --- crates/agent/src/agent_diff.rs | 21 +- crates/agent/src/thread.rs | 17 +- crates/assistant_tool/src/action_log.rs | 369 ++++++++++++++++++++++-- 3 files changed, 368 insertions(+), 39 deletions(-) diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index ec1043e6fb..6f432ccbb2 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -3,7 +3,7 @@ use anyhow::Result; use buffer_diff::DiffHunkStatus; use collections::HashSet; use editor::{ - AnchorRangeExt, Direction, Editor, EditorEvent, MultiBuffer, ToPoint, + Direction, Editor, EditorEvent, MultiBuffer, ToPoint, actions::{GoToHunk, GoToPreviousHunk}, scroll::Autoscroll, }; @@ -350,13 +350,16 @@ impl AgentDiff { self.update_selection(&diff_hunks_in_ranges, window, cx); } - let point_ranges = ranges - .into_iter() - .map(|range| range.to_point(&snapshot)) - .collect(); - self.editor.update(cx, |editor, cx| { - editor.restore_hunks_in_ranges(point_ranges, window, cx) - }); + for hunk in &diff_hunks_in_ranges { + let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id); + if let Some(buffer) = buffer { + self.thread + .update(cx, |thread, cx| { + thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx) + }) + .detach_and_log_err(cx); + } + } } fn update_selection( @@ -986,7 +989,7 @@ mod tests { Point::new(3, 0)..Point::new(3, 0) ); - // Restoring a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end. + // Rejecting a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end. editor.update_in(cx, |editor, window, cx| { editor.change_selections(None, window, cx, |selections| { selections.select_ranges([Point::new(10, 0)..Point::new(10, 0)]) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 207f47cae0..611c93e9b4 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -290,7 +290,7 @@ impl Thread { last_restore_checkpoint: None, pending_checkpoint: None, tool_use: ToolUseState::new(tools.clone()), - action_log: cx.new(|_| ActionLog::new()), + action_log: cx.new(|_| ActionLog::new(project.clone())), initial_project_snapshot: { let project_snapshot = Self::project_snapshot(project, cx); cx.foreground_executor() @@ -354,11 +354,11 @@ impl Thread { pending_completions: Vec::new(), last_restore_checkpoint: None, pending_checkpoint: None, - project, + project: project.clone(), prompt_builder, tools, tool_use, - action_log: cx.new(|_| ActionLog::new()), + action_log: cx.new(|_| ActionLog::new(project)), initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), cumulative_token_usage: serialized.cumulative_token_usage, feedback: None, @@ -1757,6 +1757,17 @@ impl Thread { .update(cx, |action_log, cx| action_log.keep_all_edits(cx)); } + pub fn reject_edits_in_range( + &mut self, + buffer: Entity, + buffer_range: Range, + cx: &mut Context, + ) -> Task> { + self.action_log.update(cx, |action_log, cx| { + action_log.reject_edits_in_range(buffer, buffer_range, cx) + }) + } + pub fn action_log(&self) -> &Entity { &self.action_log } diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index b8dd354b09..f839570f94 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -4,6 +4,7 @@ use collections::BTreeMap; use futures::{StreamExt, channel::mpsc}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity}; use language::{Anchor, Buffer, BufferEvent, DiskState, Point}; +use project::{Project, ProjectItem}; use std::{cmp, ops::Range, sync::Arc}; use text::{Edit, Patch, Rope}; use util::RangeExt; @@ -14,14 +15,17 @@ pub struct ActionLog { tracked_buffers: BTreeMap, TrackedBuffer>, /// Has the model edited a file since it last checked diagnostics? edited_since_project_diagnostics_check: bool, + /// The project this action log is associated with + project: Entity, } impl ActionLog { - /// Creates a new, empty action log. - pub fn new() -> Self { + /// Creates a new, empty action log associated with the given project. + pub fn new(project: Entity) -> Self { Self { tracked_buffers: BTreeMap::default(), edited_since_project_diagnostics_check: false, + project, } } @@ -324,14 +328,14 @@ impl ActionLog { { true } else { - let old_bytes = tracked_buffer + let old_range = tracked_buffer .base_text .point_to_offset(Point::new(edit.old.start, 0)) ..tracked_buffer.base_text.point_to_offset(cmp::min( Point::new(edit.old.end, 0), tracked_buffer.base_text.max_point(), )); - let new_bytes = tracked_buffer + let new_range = tracked_buffer .snapshot .point_to_offset(Point::new(edit.new.start, 0)) ..tracked_buffer.snapshot.point_to_offset(cmp::min( @@ -339,10 +343,10 @@ impl ActionLog { tracked_buffer.snapshot.max_point(), )); tracked_buffer.base_text.replace( - old_bytes, + old_range, &tracked_buffer .snapshot - .text_for_range(new_bytes) + .text_for_range(new_range) .collect::(), ); delta += edit.new_len() as i32 - edit.old_len() as i32; @@ -354,6 +358,87 @@ impl ActionLog { } } + pub fn reject_edits_in_range( + &mut self, + buffer: Entity, + buffer_range: Range, + cx: &mut Context, + ) -> Task> { + let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else { + return Task::ready(Ok(())); + }; + + match tracked_buffer.status { + TrackedBufferStatus::Created => { + let delete = buffer + .read(cx) + .entry_id(cx) + .and_then(|entry_id| { + self.project + .update(cx, |project, cx| project.delete_entry(entry_id, false, cx)) + }) + .unwrap_or(Task::ready(Ok(()))); + self.tracked_buffers.remove(&buffer); + cx.notify(); + delete + } + TrackedBufferStatus::Deleted => { + buffer.update(cx, |buffer, cx| { + buffer.set_text(tracked_buffer.base_text.to_string(), cx) + }); + let save = self + .project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); + + // Clear all tracked changes for this buffer and start over as if we just read it. + self.tracked_buffers.remove(&buffer); + self.track_buffer(buffer.clone(), false, cx); + cx.notify(); + save + } + TrackedBufferStatus::Modified => { + buffer.update(cx, |buffer, cx| { + let buffer_range = + buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer); + + let mut edits_to_revert = Vec::new(); + for edit in tracked_buffer.unreviewed_changes.edits() { + if buffer_range.end.row < edit.new.start { + break; + } else if buffer_range.start.row > edit.new.end { + continue; + } + + let old_range = tracked_buffer + .base_text + .point_to_offset(Point::new(edit.old.start, 0)) + ..tracked_buffer.base_text.point_to_offset(cmp::min( + Point::new(edit.old.end, 0), + tracked_buffer.base_text.max_point(), + )); + let old_text = tracked_buffer + .base_text + .chunks_in_range(old_range) + .collect::(); + + let new_range = tracked_buffer + .snapshot + .anchor_before(Point::new(edit.new.start, 0)) + ..tracked_buffer.snapshot.anchor_after(cmp::min( + Point::new(edit.new.end, 0), + tracked_buffer.snapshot.max_point(), + )); + edits_to_revert.push((new_range, old_text)); + } + + buffer.edit(edits_to_revert, None, cx); + }); + self.project + .update(cx, |project, cx| project.save_buffer(buffer, cx)) + } + } + } + pub fn keep_all_edits(&mut self, cx: &mut Context) { self.tracked_buffers .retain(|_buffer, tracked_buffer| match tracked_buffer.status { @@ -575,9 +660,22 @@ mod tests { } } + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } + #[gpui::test(iterations = 10)] async fn test_keep_edits(cx: &mut TestAppContext) { - let action_log = cx.new(|_| ActionLog::new()); + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx)); cx.update(|cx| { @@ -643,7 +741,11 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_deletions(cx: &mut TestAppContext) { - let action_log = cx.new(|_| ActionLog::new()); + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx)); cx.update(|cx| { @@ -713,7 +815,11 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_overlapping_user_edits(cx: &mut TestAppContext) { - let action_log = cx.new(|_| ActionLog::new()); + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx)); cx.update(|cx| { @@ -797,15 +903,12 @@ mod tests { } #[gpui::test(iterations = 10)] - async fn test_creation(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - }); + async fn test_creating_files(cx: &mut TestAppContext) { + init_test(cx); - let action_log = cx.new(|_| ActionLog::new()); + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); let fs = FakeFs::new(cx.executor()); fs.insert_tree(path!("/dir"), json!({})).await; @@ -864,12 +967,7 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_deleting_files(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - }); + init_test(cx); let fs = FakeFs::new(cx.executor()); fs.insert_tree( @@ -886,7 +984,7 @@ mod tests { .read_with(cx, |project, cx| project.find_project_path("dir/file2", cx)) .unwrap(); - let action_log = cx.new(|_| ActionLog::new()); + let action_log = cx.new(|_| ActionLog::new(project.clone())); let buffer1 = project .update(cx, |project, cx| { project.open_buffer(file1_path.clone(), cx) @@ -976,15 +1074,222 @@ mod tests { assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); } + #[gpui::test(iterations = 10)] + async fn test_reject_edits(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx) + .unwrap() + }); + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx) + .unwrap() + }); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndE\nXYZf\nghi\njkl\nmnO" + ); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![ + HunkStatus { + range: Point::new(1, 0)..Point::new(3, 0), + diff_status: DiffHunkStatusKind::Modified, + old_text: "def\n".into(), + }, + HunkStatus { + range: Point::new(5, 0)..Point::new(5, 3), + diff_status: DiffHunkStatusKind::Modified, + old_text: "mno".into(), + } + ], + )] + ); + + action_log + .update(cx, |log, cx| { + log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndef\nghi\njkl\nmnO" + ); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![HunkStatus { + range: Point::new(4, 0)..Point::new(4, 3), + diff_status: DiffHunkStatusKind::Modified, + old_text: "mno".into(), + }], + )] + ); + + action_log + .update(cx, |log, cx| { + log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndef\nghi\njkl\nmno" + ); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + + #[gpui::test(iterations = 10)] + async fn test_reject_deleted_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "content"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path.clone(), cx)) + .await + .unwrap(); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.will_delete_buffer(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| { + project.delete_file(file_path.clone(), false, cx) + }) + .unwrap() + .await + .unwrap(); + cx.run_until_parked(); + assert!(!fs.is_file(path!("/dir/file").as_ref()).await); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![HunkStatus { + range: Point::new(0, 0)..Point::new(0, 0), + diff_status: DiffHunkStatusKind::Deleted, + old_text: "content".into(), + }] + )] + ); + + action_log + .update(cx, |log, cx| { + log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "content"); + assert!(fs.is_file(path!("/dir/file").as_ref()).await); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + + #[gpui::test(iterations = 10)] + async fn test_reject_created_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/new_file", cx) + }) + .unwrap(); + + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + cx.update(|cx| { + buffer.update(cx, |buffer, cx| buffer.set_text("content", cx)); + action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx)); + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + assert!(fs.is_file(path!("/dir/new_file").as_ref()).await); + cx.run_until_parked(); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![HunkStatus { + range: Point::new(0, 0)..Point::new(0, 7), + diff_status: DiffHunkStatusKind::Added, + old_text: "".into(), + }], + )] + ); + + action_log + .update(cx, |log, cx| { + log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx) + }) + .await + .unwrap(); + cx.run_until_parked(); + assert!(!fs.is_file(path!("/dir/new_file").as_ref()).await); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + #[gpui::test(iterations = 100)] async fn test_random_diffs(mut rng: StdRng, cx: &mut TestAppContext) { + init_test(cx); + let operations = env::var("OPERATIONS") .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) .unwrap_or(20); - let action_log = cx.new(|_| ActionLog::new()); let text = RandomCharIter::new(&mut rng).take(50).collect::(); - let buffer = cx.new(|cx| Buffer::local(text, cx)); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": text})).await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); for _ in 0..operations { @@ -992,10 +1297,20 @@ mod tests { 0..25 => { action_log.update(cx, |log, cx| { let range = buffer.read(cx).random_byte_range(0, &mut rng); - log::info!("keeping all edits in range {:?}", range); + log::info!("keeping edits in range {:?}", range); log.keep_edits_in_range(buffer.clone(), range, cx) }); } + 25..50 => { + action_log + .update(cx, |log, cx| { + let range = buffer.read(cx).random_byte_range(0, &mut rng); + log::info!("rejecting edits in range {:?}", range); + log.reject_edits_in_range(buffer.clone(), range, cx) + }) + .await + .unwrap(); + } _ => { let is_agent_change = rng.gen_bool(0.5); if is_agent_change {