diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index 8fdcbbcb58..bf2f42640d 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -1,7 +1,7 @@ use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent}; use anyhow::Result; use buffer_diff::DiffHunkStatus; -use collections::HashSet; +use collections::{HashMap, HashSet}; use editor::{ Direction, Editor, EditorEvent, MultiBuffer, ToPoint, actions::{GoToHunk, GoToPreviousHunk}, @@ -355,16 +355,24 @@ impl AgentDiff { self.update_selection(&diff_hunks_in_ranges, window, cx); } + let mut ranges_by_buffer = HashMap::default(); for hunk in &diff_hunks_in_ranges { let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id); if let Some(buffer) = buffer { - self.thread - .update(cx, |thread, cx| { - thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx) - }) - .detach_and_log_err(cx); + ranges_by_buffer + .entry(buffer.clone()) + .or_insert_with(Vec::new) + .push(hunk.buffer_range.clone()); } } + + for (buffer, ranges) in ranges_by_buffer { + self.thread + .update(cx, |thread, cx| { + thread.reject_edits_in_ranges(buffer, ranges, cx) + }) + .detach_and_log_err(cx); + } } fn update_selection( diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 4dfefdc3e6..af81499b5f 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1801,14 +1801,14 @@ impl Thread { .update(cx, |action_log, cx| action_log.keep_all_edits(cx)); } - pub fn reject_edits_in_range( + pub fn reject_edits_in_ranges( &mut self, buffer: Entity, - buffer_range: Range, + buffer_ranges: Vec>, cx: &mut Context, ) -> Task> { self.action_log.update(cx, |action_log, cx| { - action_log.reject_edits_in_range(buffer, buffer_range, cx) + action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx) }) } diff --git a/crates/assistant_tool/src/action_log.rs b/crates/assistant_tool/src/action_log.rs index fa305a512e..3a10d5ec74 100644 --- a/crates/assistant_tool/src/action_log.rs +++ b/crates/assistant_tool/src/action_log.rs @@ -3,7 +3,7 @@ use buffer_diff::BufferDiff; use collections::BTreeMap; use futures::{StreamExt, channel::mpsc}; use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity}; -use language::{Anchor, Buffer, BufferEvent, DiskState, Point}; +use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint}; use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle}; use std::{cmp, ops::Range, sync::Arc}; use text::{Edit, Patch, Rope}; @@ -363,10 +363,10 @@ impl ActionLog { } } - pub fn reject_edits_in_range( + pub fn reject_edits_in_ranges( &mut self, buffer: Entity, - buffer_range: Range, + buffer_ranges: Vec>, cx: &mut Context, ) -> Task> { let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else { @@ -403,29 +403,15 @@ impl ActionLog { } TrackedBufferStatus::Modified => { buffer.update(cx, |buffer, cx| { - let buffer_range = - buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer); + let mut buffer_row_ranges = buffer_ranges + .into_iter() + .map(|range| { + range.start.to_point(buffer).row..range.end.to_point(buffer).row + }) + .peekable(); let mut edits_to_revert = Vec::new(); for edit in tracked_buffer.unreviewed_changes.edits() { - if buffer_range.end.row < edit.new.start { - break; - } else if buffer_range.start.row > edit.new.end { - continue; - } - - let old_range = tracked_buffer - .base_text - .point_to_offset(Point::new(edit.old.start, 0)) - ..tracked_buffer.base_text.point_to_offset(cmp::min( - Point::new(edit.old.end, 0), - tracked_buffer.base_text.max_point(), - )); - let old_text = tracked_buffer - .base_text - .chunks_in_range(old_range) - .collect::(); - let new_range = tracked_buffer .snapshot .anchor_before(Point::new(edit.new.start, 0)) @@ -433,7 +419,35 @@ impl ActionLog { Point::new(edit.new.end, 0), tracked_buffer.snapshot.max_point(), )); - edits_to_revert.push((new_range, old_text)); + let new_row_range = new_range.start.to_point(buffer).row + ..new_range.end.to_point(buffer).row; + + let mut revert = false; + while let Some(buffer_row_range) = buffer_row_ranges.peek() { + if buffer_row_range.end < new_row_range.start { + buffer_row_ranges.next(); + } else if buffer_row_range.start > new_row_range.end { + break; + } else { + revert = true; + break; + } + } + + if revert { + let old_range = tracked_buffer + .base_text + .point_to_offset(Point::new(edit.old.start, 0)) + ..tracked_buffer.base_text.point_to_offset(cmp::min( + Point::new(edit.old.end, 0), + tracked_buffer.base_text.max_point(), + )); + let old_text = tracked_buffer + .base_text + .chunks_in_range(old_range) + .collect::(); + edits_to_revert.push((new_range, old_text)); + } } buffer.edit(edits_to_revert, None, cx); @@ -599,6 +613,7 @@ fn point_to_row_edit(edit: Edit, old_text: &Rope, new_text: &Rope) -> Edi } } +#[derive(Copy, Clone, Debug)] enum ChangeAuthor { User, Agent, @@ -1135,9 +1150,48 @@ mod tests { )] ); + // If the rejected range doesn't overlap with any hunk, we ignore it. action_log .update(cx, |log, cx| { - log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx) + log.reject_edits_in_ranges( + buffer.clone(), + vec![Point::new(4, 0)..Point::new(4, 0)], + cx, + ) + }) + .await + .unwrap(); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndE\nXYZf\nghi\njkl\nmnO" + ); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![ + HunkStatus { + range: Point::new(1, 0)..Point::new(3, 0), + diff_status: DiffHunkStatusKind::Modified, + old_text: "def\n".into(), + }, + HunkStatus { + range: Point::new(5, 0)..Point::new(5, 3), + diff_status: DiffHunkStatusKind::Modified, + old_text: "mno".into(), + } + ], + )] + ); + + action_log + .update(cx, |log, cx| { + log.reject_edits_in_ranges( + buffer.clone(), + vec![Point::new(0, 0)..Point::new(1, 0)], + cx, + ) }) .await .unwrap(); @@ -1160,7 +1214,11 @@ mod tests { action_log .update(cx, |log, cx| { - log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx) + log.reject_edits_in_ranges( + buffer.clone(), + vec![Point::new(4, 0)..Point::new(4, 0)], + cx, + ) }) .await .unwrap(); @@ -1172,6 +1230,82 @@ mod tests { assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); } + #[gpui::test(iterations = 10)] + async fn test_reject_multiple_edits(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx) + .unwrap() + }); + buffer.update(cx, |buffer, cx| { + buffer + .edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx) + .unwrap() + }); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndE\nXYZf\nghi\njkl\nmnO" + ); + assert_eq!( + unreviewed_hunks(&action_log, cx), + vec![( + buffer.clone(), + vec![ + HunkStatus { + range: Point::new(1, 0)..Point::new(3, 0), + diff_status: DiffHunkStatusKind::Modified, + old_text: "def\n".into(), + }, + HunkStatus { + range: Point::new(5, 0)..Point::new(5, 3), + diff_status: DiffHunkStatusKind::Modified, + old_text: "mno".into(), + } + ], + )] + ); + + action_log.update(cx, |log, cx| { + let range_1 = buffer.read(cx).anchor_before(Point::new(0, 0)) + ..buffer.read(cx).anchor_before(Point::new(1, 0)); + let range_2 = buffer.read(cx).anchor_before(Point::new(5, 0)) + ..buffer.read(cx).anchor_before(Point::new(5, 3)); + + log.reject_edits_in_ranges(buffer.clone(), vec![range_1, range_2], cx) + .detach(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndef\nghi\njkl\nmno" + ); + }); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.text()), + "abc\ndef\nghi\njkl\nmno" + ); + assert_eq!(unreviewed_hunks(&action_log, cx), vec![]); + } + #[gpui::test(iterations = 10)] async fn test_reject_deleted_file(cx: &mut TestAppContext) { init_test(cx); @@ -1215,7 +1349,11 @@ mod tests { action_log .update(cx, |log, cx| { - log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx) + log.reject_edits_in_ranges( + buffer.clone(), + vec![Point::new(0, 0)..Point::new(0, 0)], + cx, + ) }) .await .unwrap(); @@ -1266,7 +1404,11 @@ mod tests { action_log .update(cx, |log, cx| { - log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx) + log.reject_edits_in_ranges( + buffer.clone(), + vec![Point::new(0, 0)..Point::new(0, 11)], + cx, + ) }) .await .unwrap(); @@ -1312,7 +1454,7 @@ mod tests { .update(cx, |log, cx| { let range = buffer.read(cx).random_byte_range(0, &mut rng); log::info!("rejecting edits in range {:?}", range); - log.reject_edits_in_range(buffer.clone(), range, cx) + log.reject_edits_in_ranges(buffer.clone(), vec![range], cx) }) .await .unwrap();