diff --git a/Cargo.lock b/Cargo.lock index 711122d5a0..8d581b6aec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -658,9 +658,9 @@ name = "assistant_tools" version = "0.1.0" dependencies = [ "agent_settings", - "aho-corasick", "anyhow", "assistant_tool", + "async-watch", "buffer_diff", "chrono", "client", diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 78a0f855ef..004a9ead7b 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -3414,8 +3414,8 @@ fn main() {{ }); cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief".into()); - fake_model.stream_last_completion_response(" Introduction".into()); + fake_model.stream_last_completion_response("Brief"); + fake_model.stream_last_completion_response(" Introduction"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -3508,7 +3508,7 @@ fn main() {{ }); cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary".into()); + fake_model.stream_last_completion_response("A successful summary"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -3550,7 +3550,7 @@ fn main() {{ fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response".into()); + fake_model.stream_last_completion_response("Assistant response"); fake_model.end_last_completion_stream(); cx.run_until_parked(); } diff --git a/crates/assistant_context_editor/src/context/context_tests.rs b/crates/assistant_context_editor/src/context/context_tests.rs index 2379ec4474..dba3bfde61 100644 --- a/crates/assistant_context_editor/src/context/context_tests.rs +++ b/crates/assistant_context_editor/src/context/context_tests.rs @@ -1210,8 +1210,8 @@ async fn test_summarization(cx: &mut TestAppContext) { }); cx.run_until_parked(); - fake_model.stream_last_completion_response("Brief".into()); - fake_model.stream_last_completion_response(" Introduction".into()); + fake_model.stream_last_completion_response("Brief"); + fake_model.stream_last_completion_response(" Introduction"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -1274,7 +1274,7 @@ async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { }); cx.run_until_parked(); - fake_model.stream_last_completion_response("A successful summary".into()); + fake_model.stream_last_completion_response("A successful summary"); fake_model.end_last_completion_stream(); cx.run_until_parked(); @@ -1356,7 +1356,7 @@ fn setup_context_editor_with_fake_model( fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { cx.run_until_parked(); - fake_model.stream_last_completion_response("Assistant response".into()); + fake_model.stream_last_completion_response("Assistant response"); fake_model.end_last_completion_stream(); cx.run_until_parked(); } diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 8ab8ab67db..8f81f1a695 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -16,9 +16,9 @@ eval = [] [dependencies] agent_settings.workspace = true -aho-corasick.workspace = true anyhow.workspace = true assistant_tool.workspace = true +async-watch.workspace = true buffer_diff.workspace = true chrono.workspace = true collections.workspace = true diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/assistant_tools/src/edit_agent.rs index d8e0ddfd3d..edff6cd70a 100644 --- a/crates/assistant_tools/src/edit_agent.rs +++ b/crates/assistant_tools/src/edit_agent.rs @@ -2,9 +2,9 @@ mod create_file_parser; mod edit_parser; #[cfg(test)] mod evals; +mod streaming_fuzzy_matcher; use crate::{Template, Templates}; -use aho_corasick::AhoCorasick; use anyhow::Result; use assistant_tool::ActionLog; use create_file_parser::{CreateFileParser, CreateFileParserEvent}; @@ -15,8 +15,8 @@ use futures::{ pin_mut, stream::BoxStream, }; -use gpui::{AppContext, AsyncApp, Entity, SharedString, Task}; -use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point}; +use gpui::{AppContext, AsyncApp, Entity, Task}; +use language::{Anchor, Buffer, BufferSnapshot, LineIndent, Point, TextBufferSnapshot}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelToolChoice, MessageContent, Role, @@ -24,8 +24,9 @@ use language_model::{ use project::{AgentLocation, Project}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll}; +use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll}; use streaming_diff::{CharOperation, StreamingDiff}; +use streaming_fuzzy_matcher::StreamingFuzzyMatcher; use util::debug_panic; #[derive(Serialize)] @@ -50,8 +51,9 @@ impl Template for EditFilePromptTemplate { #[derive(Clone, Debug, PartialEq, Eq)] pub enum EditAgentOutputEvent { + ResolvingEditRange(Range), + UnresolvedEditRange, Edited, - OldTextNotFound(SharedString), } #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] @@ -132,8 +134,6 @@ impl EditAgent { .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?; this.overwrite_with_chunks_internal(buffer, parse_rx, output_events_tx, cx) .await?; - this.project - .update(cx, |project, cx| project.set_agent_location(None, cx))?; parse_task.await }); (task, output_events_rx) @@ -202,18 +202,6 @@ impl EditAgent { Task>, mpsc::UnboundedReceiver, ) { - self.project - .update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: language::Anchor::MIN, - }), - cx, - ); - }) - .ok(); - let this = self.clone(); let (events_tx, events_rx) = mpsc::unbounded(); let conversation = conversation.clone(); @@ -226,139 +214,74 @@ impl EditAgent { } .render(&this.templates)?; let edit_chunks = this.request(conversation, prompt, cx).await?; - - let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx); - while let Some(event) = inner_events.next().await { - events_tx.unbounded_send(event).ok(); - } - output.await + this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx) + .await }); (output, events_rx) } - fn apply_edit_chunks( - &self, - buffer: Entity, - edit_chunks: impl 'static + Send + Stream>, - cx: &mut AsyncApp, - ) -> ( - Task>, - mpsc::UnboundedReceiver, - ) { - let (output_events_tx, output_events_rx) = mpsc::unbounded(); - let this = self.clone(); - let task = cx.spawn(async move |mut cx| { - this.action_log - .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?; - let output = this - .apply_edit_chunks_internal(buffer, edit_chunks, output_events_tx, &mut cx) - .await; - this.project - .update(cx, |project, cx| project.set_agent_location(None, cx))?; - output - }); - (task, output_events_rx) - } - - async fn apply_edit_chunks_internal( + async fn apply_edit_chunks( &self, buffer: Entity, edit_chunks: impl 'static + Send + Stream>, output_events: mpsc::UnboundedSender, cx: &mut AsyncApp, ) -> Result { - let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx); - while let Some(edit_event) = edit_events.next().await { - let EditParserEvent::OldText(old_text_query) = edit_event? else { + self.action_log + .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?; + + let (output, edit_events) = Self::parse_edit_chunks(edit_chunks, cx); + let mut edit_events = edit_events.peekable(); + while let Some(edit_event) = Pin::new(&mut edit_events).peek().await { + // Skip events until we're at the start of a new edit. + let Ok(EditParserEvent::OldTextChunk { .. }) = edit_event else { + edit_events.next().await.unwrap()?; continue; }; - // Skip edits with an empty old text. - if old_text_query.is_empty() { - continue; + let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; + + // Resolve the old text in the background, updating the agent + // location as we keep refining which range it corresponds to. + let (resolve_old_text, mut old_range) = + Self::resolve_old_text(snapshot.text.clone(), edit_events, cx); + while let Ok(old_range) = old_range.recv().await { + if let Some(old_range) = old_range { + let old_range = snapshot.anchor_before(old_range.start) + ..snapshot.anchor_before(old_range.end); + self.project.update(cx, |project, cx| { + project.set_agent_location( + Some(AgentLocation { + buffer: buffer.downgrade(), + position: old_range.end, + }), + cx, + ); + })?; + output_events + .unbounded_send(EditAgentOutputEvent::ResolvingEditRange(old_range)) + .ok(); + } } - let old_text_query = SharedString::from(old_text_query); + let (edit_events_, resolved_old_text) = resolve_old_text.await?; + edit_events = edit_events_; - let (edits_tx, edits_rx) = mpsc::unbounded(); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let old_range = cx - .background_spawn({ - let snapshot = snapshot.clone(); - let old_text_query = old_text_query.clone(); - async move { Self::resolve_location(&snapshot, &old_text_query) } - }) - .await; - let Some(old_range) = old_range else { - // We couldn't find the old text in the buffer. Report the error. + // If we can't resolve the old text, restart the loop waiting for a + // new edit (or for the stream to end). + let Some(resolved_old_text) = resolved_old_text else { output_events - .unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query)) + .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange) .ok(); continue; }; - let compute_edits = cx.background_spawn(async move { - let buffer_start_indent = - snapshot.line_indent_for_row(snapshot.offset_to_point(old_range.start).row); - let old_text_start_indent = old_text_query - .lines() - .next() - .map_or(buffer_start_indent, |line| { - LineIndent::from_iter(line.chars()) - }); - let indent_delta = if buffer_start_indent.tabs > 0 { - IndentDelta::Tabs( - buffer_start_indent.tabs as isize - old_text_start_indent.tabs as isize, - ) - } else { - IndentDelta::Spaces( - buffer_start_indent.spaces as isize - old_text_start_indent.spaces as isize, - ) - }; - - let old_text = snapshot - .text_for_range(old_range.clone()) - .collect::(); - let mut diff = StreamingDiff::new(old_text); - let mut edit_start = old_range.start; - let mut new_text_chunks = - Self::reindent_new_text_chunks(indent_delta, &mut edit_events); - let mut done = false; - while !done { - let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await - { - diff.push_new(&new_text_chunk?) - } else { - done = true; - mem::take(&mut diff).finish() - }; - - for op in char_operations { - match op { - CharOperation::Insert { text } => { - let edit_start = snapshot.anchor_after(edit_start); - edits_tx - .unbounded_send((edit_start..edit_start, Arc::from(text)))?; - } - CharOperation::Delete { bytes } => { - let edit_end = edit_start + bytes; - let edit_range = snapshot.anchor_after(edit_start) - ..snapshot.anchor_before(edit_end); - edit_start = edit_end; - edits_tx.unbounded_send((edit_range, Arc::from("")))?; - } - CharOperation::Keep { bytes } => edit_start += bytes, - } - } - } - - drop(new_text_chunks); - anyhow::Ok(edit_events) - }); - - // TODO: group all edits into one transaction - let mut edits_rx = edits_rx.ready_chunks(32); - while let Some(edits) = edits_rx.next().await { + // Compute edits in the background and apply them as they become + // available. + let (compute_edits, edits) = + Self::compute_edits(snapshot, resolved_old_text, edit_events, cx); + let mut edits = edits.ready_chunks(32); + while let Some(edits) = edits.next().await { if edits.is_empty() { continue; } @@ -472,6 +395,118 @@ impl EditAgent { (output, rx) } + fn resolve_old_text( + snapshot: TextBufferSnapshot, + mut edit_events: T, + cx: &mut AsyncApp, + ) -> ( + Task)>>, + async_watch::Receiver>>, + ) + where + T: 'static + Send + Unpin + Stream>, + { + let (old_range_tx, old_range_rx) = async_watch::channel(None); + let task = cx.background_spawn(async move { + let mut matcher = StreamingFuzzyMatcher::new(snapshot); + while let Some(edit_event) = edit_events.next().await { + let EditParserEvent::OldTextChunk { chunk, done } = edit_event? else { + break; + }; + + old_range_tx.send(matcher.push(&chunk))?; + if done { + break; + } + } + + let old_range = matcher.finish(); + old_range_tx.send(old_range.clone())?; + if let Some(old_range) = old_range { + let line_indent = + LineIndent::from_iter(matcher.query_lines().first().unwrap().chars()); + Ok(( + edit_events, + Some(ResolvedOldText { + range: old_range, + indent: line_indent, + }), + )) + } else { + Ok((edit_events, None)) + } + }); + + (task, old_range_rx) + } + + fn compute_edits( + snapshot: BufferSnapshot, + resolved_old_text: ResolvedOldText, + mut edit_events: T, + cx: &mut AsyncApp, + ) -> ( + Task>, + UnboundedReceiver<(Range, Arc)>, + ) + where + T: 'static + Send + Unpin + Stream>, + { + let (edits_tx, edits_rx) = mpsc::unbounded(); + let compute_edits = cx.background_spawn(async move { + let buffer_start_indent = snapshot + .line_indent_for_row(snapshot.offset_to_point(resolved_old_text.range.start).row); + let indent_delta = if buffer_start_indent.tabs > 0 { + IndentDelta::Tabs( + buffer_start_indent.tabs as isize - resolved_old_text.indent.tabs as isize, + ) + } else { + IndentDelta::Spaces( + buffer_start_indent.spaces as isize - resolved_old_text.indent.spaces as isize, + ) + }; + + let old_text = snapshot + .text_for_range(resolved_old_text.range.clone()) + .collect::(); + let mut diff = StreamingDiff::new(old_text); + let mut edit_start = resolved_old_text.range.start; + let mut new_text_chunks = + Self::reindent_new_text_chunks(indent_delta, &mut edit_events); + let mut done = false; + while !done { + let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await { + diff.push_new(&new_text_chunk?) + } else { + done = true; + mem::take(&mut diff).finish() + }; + + for op in char_operations { + match op { + CharOperation::Insert { text } => { + let edit_start = snapshot.anchor_after(edit_start); + edits_tx.unbounded_send((edit_start..edit_start, Arc::from(text)))?; + } + CharOperation::Delete { bytes } => { + let edit_end = edit_start + bytes; + let edit_range = + snapshot.anchor_after(edit_start)..snapshot.anchor_before(edit_end); + edit_start = edit_end; + edits_tx.unbounded_send((edit_range, Arc::from("")))?; + } + CharOperation::Keep { bytes } => edit_start += bytes, + } + } + } + + drop(new_text_chunks); + anyhow::Ok(edit_events) + }); + + (compute_edits, edits_rx) + } + fn reindent_new_text_chunks( delta: IndentDelta, mut stream: impl Unpin + Stream>, @@ -621,134 +656,11 @@ impl EditAgent { Ok(self.model.stream_completion_text(request, cx).await?.stream) } - - fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Option> { - let range = Self::resolve_location_exact(buffer, search_query) - .or_else(|| Self::resolve_location_fuzzy(buffer, search_query))?; - - // Expand the range to include entire lines. - let mut start = buffer.offset_to_point(buffer.clip_offset(range.start, Bias::Left)); - start.column = 0; - let mut end = buffer.offset_to_point(buffer.clip_offset(range.end, Bias::Right)); - if end.column > 0 { - end.column = buffer.line_len(end.row); - } - - Some(buffer.point_to_offset(start)..buffer.point_to_offset(end)) - } - - fn resolve_location_exact(buffer: &BufferSnapshot, search_query: &str) -> Option> { - let search = AhoCorasick::new([search_query]).ok()?; - let mat = search - .stream_find_iter(buffer.bytes_in_range(0..buffer.len())) - .next()? - .expect("buffer can't error"); - Some(mat.range()) - } - - fn resolve_location_fuzzy(buffer: &BufferSnapshot, search_query: &str) -> Option> { - const INSERTION_COST: u32 = 3; - const DELETION_COST: u32 = 10; - - let buffer_line_count = buffer.max_point().row as usize + 1; - let query_line_count = search_query.lines().count(); - let mut matrix = SearchMatrix::new(query_line_count + 1, buffer_line_count + 1); - let mut leading_deletion_cost = 0_u32; - for (row, query_line) in search_query.lines().enumerate() { - let query_line = query_line.trim(); - leading_deletion_cost = leading_deletion_cost.saturating_add(DELETION_COST); - matrix.set( - row + 1, - 0, - SearchState::new(leading_deletion_cost, SearchDirection::Diagonal), - ); - - let mut buffer_lines = buffer.as_rope().chunks().lines(); - let mut col = 0; - while let Some(buffer_line) = buffer_lines.next() { - let buffer_line = buffer_line.trim(); - let up = SearchState::new( - matrix.get(row, col + 1).cost.saturating_add(DELETION_COST), - SearchDirection::Up, - ); - let left = SearchState::new( - matrix.get(row + 1, col).cost.saturating_add(INSERTION_COST), - SearchDirection::Left, - ); - let diagonal = SearchState::new( - if fuzzy_eq(query_line, buffer_line) { - matrix.get(row, col).cost - } else { - matrix - .get(row, col) - .cost - .saturating_add(DELETION_COST + INSERTION_COST) - }, - SearchDirection::Diagonal, - ); - matrix.set(row + 1, col + 1, up.min(left).min(diagonal)); - col += 1; - } - } - - // Traceback to find the best match - let mut buffer_row_end = buffer_line_count as u32; - let mut best_cost = u32::MAX; - for col in 1..=buffer_line_count { - let cost = matrix.get(query_line_count, col).cost; - if cost < best_cost { - best_cost = cost; - buffer_row_end = col as u32; - } - } - - let mut matched_lines = 0; - let mut query_row = query_line_count; - let mut buffer_row_start = buffer_row_end; - while query_row > 0 && buffer_row_start > 0 { - let current = matrix.get(query_row, buffer_row_start as usize); - match current.direction { - SearchDirection::Diagonal => { - query_row -= 1; - buffer_row_start -= 1; - matched_lines += 1; - } - SearchDirection::Up => { - query_row -= 1; - } - SearchDirection::Left => { - buffer_row_start -= 1; - } - } - } - - let matched_buffer_row_count = buffer_row_end - buffer_row_start; - let matched_ratio = - matched_lines as f32 / (matched_buffer_row_count as f32).max(query_line_count as f32); - if matched_ratio >= 0.8 { - let buffer_start_ix = buffer.point_to_offset(Point::new(buffer_row_start, 0)); - let buffer_end_ix = buffer.point_to_offset(Point::new( - buffer_row_end - 1, - buffer.line_len(buffer_row_end - 1), - )); - Some(buffer_start_ix..buffer_end_ix) - } else { - None - } - } } -fn fuzzy_eq(left: &str, right: &str) -> bool { - const THRESHOLD: f64 = 0.8; - - let min_levenshtein = left.len().abs_diff(right.len()); - let min_normalized_levenshtein = - 1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64); - if min_normalized_levenshtein < THRESHOLD { - return false; - } - - strsim::normalized_levenshtein(left, right) >= THRESHOLD +struct ResolvedOldText { + range: Range, + indent: LineIndent, } #[derive(Copy, Clone, Debug)] @@ -773,61 +685,18 @@ impl IndentDelta { } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -enum SearchDirection { - Up, - Left, - Diagonal, -} - -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -struct SearchState { - cost: u32, - direction: SearchDirection, -} - -impl SearchState { - fn new(cost: u32, direction: SearchDirection) -> Self { - Self { cost, direction } - } -} - -struct SearchMatrix { - cols: usize, - data: Vec, -} - -impl SearchMatrix { - fn new(rows: usize, cols: usize) -> Self { - SearchMatrix { - cols, - data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols], - } - } - - fn get(&self, row: usize, col: usize) -> SearchState { - self.data[row * self.cols + col] - } - - fn set(&mut self, row: usize, col: usize, cost: SearchState) { - self.data[row * self.cols + col] = cost; - } -} - #[cfg(test)] mod tests { use super::*; use fs::FakeFs; use futures::stream; - use gpui::{App, AppContext, TestAppContext}; + use gpui::{AppContext, TestAppContext}; use indoc::indoc; use language_model::fake_provider::FakeLanguageModel; use project::{AgentLocation, Project}; use rand::prelude::*; use rand::rngs::StdRng; use std::cmp; - use unindent::Unindent; - use util::test::{generate_marked_text, marked_text_ranges}; #[gpui::test(iterations = 100)] async fn test_empty_old_text(cx: &mut TestAppContext, mut rng: StdRng) { @@ -842,7 +711,16 @@ mod tests { cx, ) }); - let raw_edits = simulate_llm_output( + let (apply, _events) = agent.edit( + buffer.clone(), + String::new(), + &LanguageModelRequest::default(), + &mut cx.to_async(), + ); + cx.run_until_parked(); + + simulate_llm_output( + &agent, indoc! {" jkl @@ -852,9 +730,8 @@ mod tests { &mut rng, cx, ); - let (apply, _events) = - agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async()); apply.await.unwrap(); + pretty_assertions::assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), indoc! {" @@ -879,7 +756,16 @@ mod tests { cx, ) }); - let raw_edits = simulate_llm_output( + let (apply, _events) = agent.edit( + buffer.clone(), + String::new(), + &LanguageModelRequest::default(), + &mut cx.to_async(), + ); + cx.run_until_parked(); + + simulate_llm_output( + &agent, indoc! {" ipsum @@ -896,9 +782,8 @@ mod tests { &mut rng, cx, ); - let (apply, _events) = - agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async()); apply.await.unwrap(); + pretty_assertions::assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), indoc! {" @@ -915,7 +800,16 @@ mod tests { async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) { let agent = init_test(cx).await; let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx)); - let raw_edits = simulate_llm_output( + let (apply, _events) = agent.edit( + buffer.clone(), + String::new(), + &LanguageModelRequest::default(), + &mut cx.to_async(), + ); + cx.run_until_parked(); + + simulate_llm_output( + &agent, indoc! {" def @@ -934,9 +828,8 @@ mod tests { &mut rng, cx, ); - let (apply, _events) = - agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async()); apply.await.unwrap(); + assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), "abc\nDeF\nghi" @@ -947,7 +840,16 @@ mod tests { async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) { let agent = init_test(cx).await; let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx)); - let raw_edits = simulate_llm_output( + let (apply, _events) = agent.edit( + buffer.clone(), + String::new(), + &LanguageModelRequest::default(), + &mut cx.to_async(), + ); + cx.run_until_parked(); + + simulate_llm_output( + &agent, indoc! {" jkl @@ -966,9 +868,8 @@ mod tests { &mut rng, cx, ); - let (apply, _events) = - agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async()); apply.await.unwrap(); + assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), "ABC\ndef\nghi" @@ -978,47 +879,46 @@ mod tests { #[gpui::test] async fn test_edit_events(cx: &mut TestAppContext) { let agent = init_test(cx).await; + let model = agent.model.as_fake(); let project = agent .action_log .read_with(cx, |log, _| log.project().clone()); - let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx)); - let (chunks_tx, chunks_rx) = mpsc::unbounded(); - let (apply, mut events) = agent.apply_edit_chunks( + let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl", cx)); + + let mut async_cx = cx.to_async(); + let (apply, mut events) = agent.edit( buffer.clone(), - chunks_rx.map(|chunk: &str| Ok(chunk.to_string())), - &mut cx.to_async(), + String::new(), + &LanguageModelRequest::default(), + &mut async_cx, ); + cx.run_until_parked(); - chunks_tx.unbounded_send("a").unwrap(); + model.stream_last_completion_response("a"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), - "abc\ndef\nghi" + "abc\ndef\nghi\njkl" ); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), None ); - chunks_tx.unbounded_send("bc").unwrap(); + model.stream_last_completion_response("bc"); cx.run_until_parked(); - assert_eq!(drain_events(&mut events), vec![]); assert_eq!( - buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), - "abc\ndef\nghi" + drain_events(&mut events), + vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with( + cx, + |buffer, _| buffer.anchor_before(Point::new(0, 0)) + ..buffer.anchor_before(Point::new(0, 3)) + ))] ); - assert_eq!( - project.read_with(cx, |project, _| project.agent_location()), - None - ); - - chunks_tx.unbounded_send("abX").unwrap(); - cx.run_until_parked(); - assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), - "abXc\ndef\nghi" + "abc\ndef\nghi\njkl" ); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), @@ -1028,12 +928,27 @@ mod tests { }) ); - chunks_tx.unbounded_send("cY").unwrap(); + model.stream_last_completion_response("abX"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), - "abXcY\ndef\nghi" + "abXc\ndef\nghi\njkl" + ); + assert_eq!( + project.read_with(cx, |project, _| project.agent_location()), + Some(AgentLocation { + buffer: buffer.downgrade(), + position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3))) + }) + ); + + model.stream_last_completion_response("cY"); + cx.run_until_parked(); + assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), + "abXcY\ndef\nghi\njkl" ); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), @@ -1043,13 +958,13 @@ mod tests { }) ); - chunks_tx.unbounded_send("").unwrap(); - chunks_tx.unbounded_send("hall").unwrap(); + model.stream_last_completion_response(""); + model.stream_last_completion_response("hall"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), - "abXcY\ndef\nghi" + "abXcY\ndef\nghi\njkl" ); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), @@ -1059,18 +974,16 @@ mod tests { }) ); - chunks_tx.unbounded_send("ucinated old").unwrap(); - chunks_tx.unbounded_send("").unwrap(); + model.stream_last_completion_response("ucinated old"); + model.stream_last_completion_response(""); cx.run_until_parked(); assert_eq!( drain_events(&mut events), - vec![EditAgentOutputEvent::OldTextNotFound( - "hallucinated old".into() - )] + vec![EditAgentOutputEvent::UnresolvedEditRange] ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), - "abXcY\ndef\nghi" + "abXcY\ndef\nghi\njkl" ); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), @@ -1080,13 +993,13 @@ mod tests { }) ); - chunks_tx.unbounded_send("hallucinated new").unwrap(); + model.stream_last_completion_response("hallucinated new"); cx.run_until_parked(); assert_eq!(drain_events(&mut events), vec![]); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), - "abXcY\ndef\nghi" + "abXcY\ndef\nghi\njkl" ); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), @@ -1096,24 +1009,52 @@ mod tests { }) ); - chunks_tx.unbounded_send("gh").unwrap(); - chunks_tx.unbounded_send("i").unwrap(); - chunks_tx.unbounded_send("").unwrap(); + model.stream_last_completion_response("\nghi\nj"); cx.run_until_parked(); - assert_eq!(drain_events(&mut events), vec![]); + assert_eq!( + drain_events(&mut events), + vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with( + cx, + |buffer, _| buffer.anchor_before(Point::new(2, 0)) + ..buffer.anchor_before(Point::new(2, 3)) + ))] + ); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), - "abXcY\ndef\nghi" + "abXcY\ndef\nghi\njkl" ); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), Some(AgentLocation { buffer: buffer.downgrade(), - position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5))) + position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3))) }) ); - chunks_tx.unbounded_send("GHI").unwrap(); + model.stream_last_completion_response("kl"); + model.stream_last_completion_response(""); + cx.run_until_parked(); + assert_eq!( + drain_events(&mut events), + vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with( + cx, + |buffer, _| buffer.anchor_before(Point::new(2, 0)) + ..buffer.anchor_before(Point::new(3, 3)) + ))] + ); + assert_eq!( + buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), + "abXcY\ndef\nghi\njkl" + ); + assert_eq!( + project.read_with(cx, |project, _| project.agent_location()), + Some(AgentLocation { + buffer: buffer.downgrade(), + position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3))) + }) + ); + + model.stream_last_completion_response("GHI"); cx.run_until_parked(); assert_eq!( drain_events(&mut events), @@ -1131,7 +1072,7 @@ mod tests { }) ); - drop(chunks_tx); + model.end_last_completion_stream(); apply.await.unwrap(); assert_eq!( buffer.read_with(cx, |buffer, _| buffer.snapshot().text()), @@ -1140,7 +1081,10 @@ mod tests { assert_eq!(drain_events(&mut events), vec![]); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), - None + Some(AgentLocation { + buffer: buffer.downgrade(), + position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3))) + }) ); } @@ -1238,162 +1182,10 @@ mod tests { assert_eq!(drain_events(&mut events), vec![]); assert_eq!( project.read_with(cx, |project, _| project.agent_location()), - None - ); - } - - #[gpui::test] - fn test_resolve_location(cx: &mut App) { - assert_location_resolution( - concat!( - " Lorem\n", - "« ipsum»\n", - " dolor sit amet\n", - " consecteur", - ), - "ipsum", - cx, - ); - - assert_location_resolution( - concat!( - " Lorem\n", - "« ipsum\n", - " dolor sit amet»\n", - " consecteur", - ), - "ipsum\ndolor sit amet", - cx, - ); - - assert_location_resolution( - &" - «fn foo1(a: usize) -> usize { - 40 - }» - - fn foo2(b: usize) -> usize { - 42 - } - " - .unindent(), - "fn foo1(a: usize) -> u32 {\n40\n}", - cx, - ); - - assert_location_resolution( - &" - class Something { - one() { return 1; } - « two() { return 2222; } - three() { return 333; } - four() { return 4444; } - five() { return 5555; } - six() { return 6666; }» - seven() { return 7; } - eight() { return 8; } - } - " - .unindent(), - &" - two() { return 2222; } - four() { return 4444; } - five() { return 5555; } - six() { return 6666; } - " - .unindent(), - cx, - ); - - assert_location_resolution( - &" - use std::ops::Range; - use std::sync::Mutex; - use std::{ - collections::HashMap, - env, - ffi::{OsStr, OsString}, - fs, - io::{BufRead, BufReader}, - mem, - path::{Path, PathBuf}, - process::Command, - sync::LazyLock, - time::SystemTime, - }; - " - .unindent(), - &" - use std::collections::{HashMap, HashSet}; - use std::ffi::{OsStr, OsString}; - use std::fmt::Write as _; - use std::fs; - use std::io::{BufReader, Read, Write}; - use std::mem; - use std::path::{Path, PathBuf}; - use std::process::Command; - use std::sync::Arc; - " - .unindent(), - cx, - ); - - assert_location_resolution( - indoc! {" - impl Foo { - fn new() -> Self { - Self { - subscriptions: vec![ - cx.observe_window_activation(window, |editor, window, cx| { - let active = window.is_window_active(); - editor.blink_manager.update(cx, |blink_manager, cx| { - if active { - blink_manager.enable(cx); - } else { - blink_manager.disable(cx); - } - }); - }), - ]; - } - } - } - "}, - concat!( - " editor.blink_manager.update(cx, |blink_manager, cx| {\n", - " blink_manager.enable(cx);\n", - " });", - ), - cx, - ); - - assert_location_resolution( - indoc! {r#" - let tool = cx - .update(|cx| working_set.tool(&tool_name, cx)) - .map_err(|err| { - anyhow!("Failed to look up tool '{}': {}", tool_name, err) - })?; - - let Some(tool) = tool else { - return Err(anyhow!("Tool '{}' not found", tool_name)); - }; - - let project = project.clone(); - let action_log = action_log.clone(); - let messages = messages.clone(); - let tool_result = cx - .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx)) - .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?; - - tasks.push(tool_result.output); - "#}, - concat!( - "let tool_result = cx\n", - " .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n", - " .output;", - ), - cx, + Some(AgentLocation { + buffer: buffer.downgrade(), + position: language::Anchor::MAX + }) ); } @@ -1480,17 +1272,6 @@ mod tests { assert_eq!(actual_reindented_text, expected_reindented_text); } - #[track_caller] - fn assert_location_resolution(text_with_expected_range: &str, query: &str, cx: &mut App) { - let (text, _) = marked_text_ranges(text_with_expected_range, false); - let buffer = cx.new(|cx| Buffer::local(text.clone(), cx)); - let snapshot = buffer.read(cx).snapshot(); - let mut ranges = Vec::new(); - ranges.extend(EditAgent::resolve_location(&snapshot, query)); - let text_with_actual_range = generate_marked_text(&text, &ranges, false); - pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range); - } - fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec { let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50)); let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count); @@ -1507,18 +1288,22 @@ mod tests { } fn simulate_llm_output( + agent: &EditAgent, output: &str, rng: &mut StdRng, cx: &mut TestAppContext, - ) -> impl 'static + Send + Stream> { + ) { let executor = cx.executor(); - stream::iter(to_random_chunks(rng, output).into_iter().map(Ok)).then(move |chunk| { - let executor = executor.clone(); - async move { + let chunks = to_random_chunks(rng, output); + let model = agent.model.clone(); + cx.background_spawn(async move { + for chunk in chunks { executor.simulate_random_delay().await; - chunk + model.as_fake().stream_last_completion_response(chunk); } + model.as_fake().end_last_completion_stream(); }) + .detach(); } async fn init_test(cx: &mut TestAppContext) -> EditAgent { diff --git a/crates/assistant_tools/src/edit_agent/edit_parser.rs b/crates/assistant_tools/src/edit_agent/edit_parser.rs index ac6a40f9c9..3829491560 100644 --- a/crates/assistant_tools/src/edit_agent/edit_parser.rs +++ b/crates/assistant_tools/src/edit_agent/edit_parser.rs @@ -11,7 +11,7 @@ const END_TAGS: [&str; 3] = [OLD_TEXT_END_TAG, NEW_TEXT_END_TAG, EDITS_END_TAG]; #[derive(Debug)] pub enum EditParserEvent { - OldText(String), + OldTextChunk { chunk: String, done: bool }, NewTextChunk { chunk: String, done: bool }, } @@ -33,7 +33,7 @@ pub struct EditParser { #[derive(Debug, PartialEq)] enum EditParserState { Pending, - WithinOldText, + WithinOldText { start: bool }, AfterOldText, WithinNewText { start: bool }, } @@ -56,20 +56,23 @@ impl EditParser { EditParserState::Pending => { if let Some(start) = self.buffer.find("") { self.buffer.drain(..start + "".len()); - self.state = EditParserState::WithinOldText; + self.state = EditParserState::WithinOldText { start: true }; } else { break; } } - EditParserState::WithinOldText => { - if let Some(tag_range) = self.find_end_tag() { - let mut start = 0; - if self.buffer.starts_with('\n') { - start = 1; + EditParserState::WithinOldText { start } => { + if !self.buffer.is_empty() { + if *start && self.buffer.starts_with('\n') { + self.buffer.remove(0); } - let mut old_text = self.buffer[start..tag_range.start].to_string(); - if old_text.ends_with('\n') { - old_text.pop(); + *start = false; + } + + if let Some(tag_range) = self.find_end_tag() { + let mut chunk = self.buffer[..tag_range.start].to_string(); + if chunk.ends_with('\n') { + chunk.pop(); } self.metrics.tags += 1; @@ -79,8 +82,14 @@ impl EditParser { self.buffer.drain(..tag_range.end); self.state = EditParserState::AfterOldText; - edit_events.push(EditParserEvent::OldText(old_text)); + edit_events.push(EditParserEvent::OldTextChunk { chunk, done: true }); } else { + if !self.ends_with_tag_prefix() { + edit_events.push(EditParserEvent::OldTextChunk { + chunk: mem::take(&mut self.buffer), + done: false, + }); + } break; } } @@ -115,11 +124,7 @@ impl EditParser { self.state = EditParserState::Pending; edit_events.push(EditParserEvent::NewTextChunk { chunk, done: true }); } else { - let mut end_prefixes = END_TAGS - .iter() - .flat_map(|tag| (1..tag.len()).map(move |i| &tag[..i])) - .chain(["\n"]); - if end_prefixes.all(|prefix| !self.buffer.ends_with(&prefix)) { + if !self.ends_with_tag_prefix() { edit_events.push(EditParserEvent::NewTextChunk { chunk: mem::take(&mut self.buffer), done: false, @@ -141,6 +146,14 @@ impl EditParser { Some(start_ix..start_ix + tag.len()) } + fn ends_with_tag_prefix(&self) -> bool { + let mut end_prefixes = END_TAGS + .iter() + .flat_map(|tag| (1..tag.len()).map(move |i| &tag[..i])) + .chain(["\n"]); + end_prefixes.any(|prefix| self.buffer.ends_with(&prefix)) + } + pub fn finish(self) -> EditParserMetrics { self.metrics } @@ -412,20 +425,28 @@ mod tests { chunk_indices.sort(); chunk_indices.push(input.len()); + let mut old_text = Some(String::new()); + let mut new_text = None; let mut pending_edit = Edit::default(); let mut edits = Vec::new(); let mut last_ix = 0; for chunk_ix in chunk_indices { for event in parser.push(&input[last_ix..chunk_ix]) { match event { - EditParserEvent::OldText(old_text) => { - pending_edit.old_text = old_text; + EditParserEvent::OldTextChunk { chunk, done } => { + old_text.as_mut().unwrap().push_str(&chunk); + if done { + pending_edit.old_text = old_text.take().unwrap(); + new_text = Some(String::new()); + } } EditParserEvent::NewTextChunk { chunk, done } => { - pending_edit.new_text.push_str(&chunk); + new_text.as_mut().unwrap().push_str(&chunk); if done { + pending_edit.new_text = new_text.take().unwrap(); edits.push(pending_edit); pending_edit = Edit::default(); + old_text = Some(String::new()); } } } @@ -433,8 +454,6 @@ mod tests { last_ix = chunk_ix; } - assert_eq!(pending_edit, Edit::default(), "unfinished edit"); - edits } } diff --git a/crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs b/crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs new file mode 100644 index 0000000000..f0a23d28c0 --- /dev/null +++ b/crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs @@ -0,0 +1,694 @@ +use language::{Point, TextBufferSnapshot}; +use std::{cmp, ops::Range}; + +const REPLACEMENT_COST: u32 = 1; +const INSERTION_COST: u32 = 3; +const DELETION_COST: u32 = 10; + +/// A streaming fuzzy matcher that can process text chunks incrementally +/// and return the best match found so far at each step. +pub struct StreamingFuzzyMatcher { + snapshot: TextBufferSnapshot, + query_lines: Vec, + incomplete_line: String, + best_match: Option>, + matrix: SearchMatrix, +} + +impl StreamingFuzzyMatcher { + pub fn new(snapshot: TextBufferSnapshot) -> Self { + let buffer_line_count = snapshot.max_point().row as usize + 1; + Self { + snapshot, + query_lines: Vec::new(), + incomplete_line: String::new(), + best_match: None, + matrix: SearchMatrix::new(buffer_line_count + 1), + } + } + + /// Returns the query lines. + pub fn query_lines(&self) -> &[String] { + &self.query_lines + } + + /// Push a new chunk of text and get the best match found so far. + /// + /// This method accumulates text chunks and processes complete lines. + /// Partial lines are buffered internally until a newline is received. + /// + /// # Returns + /// + /// Returns `Some(range)` if a match has been found with the accumulated + /// query so far, or `None` if no suitable match exists yet. + pub fn push(&mut self, chunk: &str) -> Option> { + // Add the chunk to our incomplete line buffer + self.incomplete_line.push_str(chunk); + + if let Some((last_pos, _)) = self.incomplete_line.match_indices('\n').next_back() { + let complete_part = &self.incomplete_line[..=last_pos]; + + // Split into lines and add to query_lines + for line in complete_part.lines() { + self.query_lines.push(line.to_string()); + } + + self.incomplete_line.replace_range(..last_pos + 1, ""); + + self.best_match = self.resolve_location_fuzzy(); + } + + self.best_match.clone() + } + + /// Finish processing and return the final best match. + /// + /// This processes any remaining incomplete line before returning the final + /// match result. + pub fn finish(&mut self) -> Option> { + // Process any remaining incomplete line + if !self.incomplete_line.is_empty() { + self.query_lines.push(self.incomplete_line.clone()); + self.best_match = self.resolve_location_fuzzy(); + } + + self.best_match.clone() + } + + fn resolve_location_fuzzy(&mut self) -> Option> { + let new_query_line_count = self.query_lines.len(); + let old_query_line_count = self.matrix.rows.saturating_sub(1); + if new_query_line_count == old_query_line_count { + return None; + } + + self.matrix.resize_rows(new_query_line_count + 1); + + // Process only the new query lines + for row in old_query_line_count..new_query_line_count { + let query_line = self.query_lines[row].trim(); + let leading_deletion_cost = (row + 1) as u32 * DELETION_COST; + + self.matrix.set( + row + 1, + 0, + SearchState::new(leading_deletion_cost, SearchDirection::Up), + ); + + let mut buffer_lines = self.snapshot.as_rope().chunks().lines(); + let mut col = 0; + while let Some(buffer_line) = buffer_lines.next() { + let buffer_line = buffer_line.trim(); + let up = SearchState::new( + self.matrix + .get(row, col + 1) + .cost + .saturating_add(DELETION_COST), + SearchDirection::Up, + ); + let left = SearchState::new( + self.matrix + .get(row + 1, col) + .cost + .saturating_add(INSERTION_COST), + SearchDirection::Left, + ); + let diagonal = SearchState::new( + if query_line == buffer_line { + self.matrix.get(row, col).cost + } else if fuzzy_eq(query_line, buffer_line) { + self.matrix.get(row, col).cost + REPLACEMENT_COST + } else { + self.matrix + .get(row, col) + .cost + .saturating_add(DELETION_COST + INSERTION_COST) + }, + SearchDirection::Diagonal, + ); + self.matrix + .set(row + 1, col + 1, up.min(left).min(diagonal)); + col += 1; + } + } + + // Traceback to find the best match + let buffer_line_count = self.snapshot.max_point().row as usize + 1; + let mut buffer_row_end = buffer_line_count as u32; + let mut best_cost = u32::MAX; + for col in 1..=buffer_line_count { + let cost = self.matrix.get(new_query_line_count, col).cost; + if cost < best_cost { + best_cost = cost; + buffer_row_end = col as u32; + } + } + + let mut matched_lines = 0; + let mut query_row = new_query_line_count; + let mut buffer_row_start = buffer_row_end; + while query_row > 0 && buffer_row_start > 0 { + let current = self.matrix.get(query_row, buffer_row_start as usize); + match current.direction { + SearchDirection::Diagonal => { + query_row -= 1; + buffer_row_start -= 1; + matched_lines += 1; + } + SearchDirection::Up => { + query_row -= 1; + } + SearchDirection::Left => { + buffer_row_start -= 1; + } + } + } + + let matched_buffer_row_count = buffer_row_end - buffer_row_start; + let matched_ratio = matched_lines as f32 + / (matched_buffer_row_count as f32).max(new_query_line_count as f32); + if matched_ratio >= 0.8 { + let buffer_start_ix = self + .snapshot + .point_to_offset(Point::new(buffer_row_start, 0)); + let buffer_end_ix = self.snapshot.point_to_offset(Point::new( + buffer_row_end - 1, + self.snapshot.line_len(buffer_row_end - 1), + )); + Some(buffer_start_ix..buffer_end_ix) + } else { + None + } + } +} + +fn fuzzy_eq(left: &str, right: &str) -> bool { + const THRESHOLD: f64 = 0.8; + + let min_levenshtein = left.len().abs_diff(right.len()); + let min_normalized_levenshtein = + 1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64); + if min_normalized_levenshtein < THRESHOLD { + return false; + } + + strsim::normalized_levenshtein(left, right) >= THRESHOLD +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum SearchDirection { + Up, + Left, + Diagonal, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct SearchState { + cost: u32, + direction: SearchDirection, +} + +impl SearchState { + fn new(cost: u32, direction: SearchDirection) -> Self { + Self { cost, direction } + } +} + +struct SearchMatrix { + cols: usize, + rows: usize, + data: Vec, +} + +impl SearchMatrix { + fn new(cols: usize) -> Self { + SearchMatrix { + cols, + rows: 0, + data: Vec::new(), + } + } + + fn resize_rows(&mut self, needed_rows: usize) { + debug_assert!(needed_rows > self.rows); + self.rows = needed_rows; + self.data.resize( + self.rows * self.cols, + SearchState::new(0, SearchDirection::Diagonal), + ); + } + + fn get(&self, row: usize, col: usize) -> SearchState { + debug_assert!(row < self.rows && col < self.cols); + self.data[row * self.cols + col] + } + + fn set(&mut self, row: usize, col: usize, state: SearchState) { + debug_assert!(row < self.rows && col < self.cols); + self.data[row * self.cols + col] = state; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + use language::{BufferId, TextBuffer}; + use rand::prelude::*; + use util::test::{generate_marked_text, marked_text_ranges}; + + #[test] + fn test_empty_query() { + let buffer = TextBuffer::new( + 0, + BufferId::new(1).unwrap(), + "Hello world\nThis is a test\nFoo bar baz", + ); + let snapshot = buffer.snapshot(); + + let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + assert_eq!(push(&mut finder, ""), None); + assert_eq!(finish(finder), None); + } + + #[test] + fn test_streaming_exact_match() { + let buffer = TextBuffer::new( + 0, + BufferId::new(1).unwrap(), + "Hello world\nThis is a test\nFoo bar baz", + ); + let snapshot = buffer.snapshot(); + + let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + + // Push partial query + assert_eq!(push(&mut finder, "This"), None); + + // Complete the line + assert_eq!( + push(&mut finder, " is a test\n"), + Some("This is a test".to_string()) + ); + + // Finish should return the same result + assert_eq!(finish(finder), Some("This is a test".to_string())); + } + + #[test] + fn test_streaming_fuzzy_match() { + let buffer = TextBuffer::new( + 0, + BufferId::new(1).unwrap(), + indoc! {" + function foo(a, b) { + return a + b; + } + + function bar(x, y) { + return x * y; + } + "}, + ); + let snapshot = buffer.snapshot(); + + let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + + // Push a fuzzy query that should match the first function + assert_eq!( + push(&mut finder, "function foo(a, c) {\n").as_deref(), + Some("function foo(a, b) {") + ); + assert_eq!( + push(&mut finder, " return a + c;\n}\n").as_deref(), + Some(concat!( + "function foo(a, b) {\n", + " return a + b;\n", + "}" + )) + ); + } + + #[test] + fn test_incremental_improvement() { + let buffer = TextBuffer::new( + 0, + BufferId::new(1).unwrap(), + "Line 1\nLine 2\nLine 3\nLine 4\nLine 5", + ); + let snapshot = buffer.snapshot(); + + let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + + // No match initially + assert_eq!(push(&mut finder, "Lin"), None); + + // Get a match when we complete a line + assert_eq!(push(&mut finder, "e 3\n"), Some("Line 3".to_string())); + + // The match might change if we add more specific content + assert_eq!( + push(&mut finder, "Line 4\n"), + Some("Line 3\nLine 4".to_string()) + ); + assert_eq!(finish(finder), Some("Line 3\nLine 4".to_string())); + } + + #[test] + fn test_incomplete_lines_buffering() { + let buffer = TextBuffer::new( + 0, + BufferId::new(1).unwrap(), + indoc! {" + The quick brown fox + jumps over the lazy dog + Pack my box with five dozen liquor jugs + "}, + ); + let snapshot = buffer.snapshot(); + + let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + + // Push text in small chunks across line boundaries + assert_eq!(push(&mut finder, "jumps "), None); // No newline yet + assert_eq!(push(&mut finder, "over the"), None); // Still no newline + assert_eq!(push(&mut finder, " lazy"), None); // Still incomplete + + // Complete the line + assert_eq!( + push(&mut finder, " dog\n"), + Some("jumps over the lazy dog".to_string()) + ); + } + + #[test] + fn test_multiline_fuzzy_match() { + let buffer = TextBuffer::new( + 0, + BufferId::new(1).unwrap(), + indoc! {r#" + impl Display for User { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "User: {} ({})", self.name, self.email) + } + } + + impl Debug for User { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("User") + .field("name", &self.name) + .field("email", &self.email) + .finish() + } + } + "#}, + ); + let snapshot = buffer.snapshot(); + + let mut finder = StreamingFuzzyMatcher::new(snapshot.clone()); + + assert_eq!( + push(&mut finder, "impl Debug for User {\n"), + Some("impl Debug for User {".to_string()) + ); + assert_eq!( + push( + &mut finder, + " fn fmt(&self, f: &mut Formatter) -> Result {\n" + ) + .as_deref(), + Some(concat!( + "impl Debug for User {\n", + " fn fmt(&self, f: &mut Formatter) -> fmt::Result {" + )) + ); + assert_eq!( + push(&mut finder, " f.debug_struct(\"User\")\n").as_deref(), + Some(concat!( + "impl Debug for User {\n", + " fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n", + " f.debug_struct(\"User\")" + )) + ); + assert_eq!( + push( + &mut finder, + " .field(\"name\", &self.username)\n" + ) + .as_deref(), + Some(concat!( + "impl Debug for User {\n", + " fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n", + " f.debug_struct(\"User\")\n", + " .field(\"name\", &self.name)" + )) + ); + assert_eq!( + finish(finder).as_deref(), + Some(concat!( + "impl Debug for User {\n", + " fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n", + " f.debug_struct(\"User\")\n", + " .field(\"name\", &self.name)" + )) + ); + } + + #[gpui::test(iterations = 100)] + fn test_resolve_location_single_line(mut rng: StdRng) { + assert_location_resolution( + concat!( + " Lorem\n", + "« ipsum»\n", + " dolor sit amet\n", + " consecteur", + ), + "ipsum", + &mut rng, + ); + } + + #[gpui::test(iterations = 100)] + fn test_resolve_location_multiline(mut rng: StdRng) { + assert_location_resolution( + concat!( + " Lorem\n", + "« ipsum\n", + " dolor sit amet»\n", + " consecteur", + ), + "ipsum\ndolor sit amet", + &mut rng, + ); + } + + #[gpui::test(iterations = 100)] + fn test_resolve_location_function_with_typo(mut rng: StdRng) { + assert_location_resolution( + indoc! {" + «fn foo1(a: usize) -> usize { + 40 + }» + + fn foo2(b: usize) -> usize { + 42 + } + "}, + "fn foo1(a: usize) -> u32 {\n40\n}", + &mut rng, + ); + } + + #[gpui::test(iterations = 100)] + fn test_resolve_location_class_methods(mut rng: StdRng) { + assert_location_resolution( + indoc! {" + class Something { + one() { return 1; } + « two() { return 2222; } + three() { return 333; } + four() { return 4444; } + five() { return 5555; } + six() { return 6666; }» + seven() { return 7; } + eight() { return 8; } + } + "}, + indoc! {" + two() { return 2222; } + four() { return 4444; } + five() { return 5555; } + six() { return 6666; } + "}, + &mut rng, + ); + } + + #[gpui::test(iterations = 100)] + fn test_resolve_location_imports_no_match(mut rng: StdRng) { + assert_location_resolution( + indoc! {" + use std::ops::Range; + use std::sync::Mutex; + use std::{ + collections::HashMap, + env, + ffi::{OsStr, OsString}, + fs, + io::{BufRead, BufReader}, + mem, + path::{Path, PathBuf}, + process::Command, + sync::LazyLock, + time::SystemTime, + }; + "}, + indoc! {" + use std::collections::{HashMap, HashSet}; + use std::ffi::{OsStr, OsString}; + use std::fmt::Write as _; + use std::fs; + use std::io::{BufReader, Read, Write}; + use std::mem; + use std::path::{Path, PathBuf}; + use std::process::Command; + use std::sync::Arc; + "}, + &mut rng, + ); + } + + #[gpui::test(iterations = 100)] + fn test_resolve_location_nested_closure(mut rng: StdRng) { + assert_location_resolution( + indoc! {" + impl Foo { + fn new() -> Self { + Self { + subscriptions: vec![ + cx.observe_window_activation(window, |editor, window, cx| { + let active = window.is_window_active(); + editor.blink_manager.update(cx, |blink_manager, cx| { + if active { + blink_manager.enable(cx); + } else { + blink_manager.disable(cx); + } + }); + }), + ]; + } + } + } + "}, + concat!( + " editor.blink_manager.update(cx, |blink_manager, cx| {\n", + " blink_manager.enable(cx);\n", + " });", + ), + &mut rng, + ); + } + + #[gpui::test(iterations = 100)] + fn test_resolve_location_tool_invocation(mut rng: StdRng) { + assert_location_resolution( + indoc! {r#" + let tool = cx + .update(|cx| working_set.tool(&tool_name, cx)) + .map_err(|err| { + anyhow!("Failed to look up tool '{}': {}", tool_name, err) + })?; + + let Some(tool) = tool else { + return Err(anyhow!("Tool '{}' not found", tool_name)); + }; + + let project = project.clone(); + let action_log = action_log.clone(); + let messages = messages.clone(); + let tool_result = cx + .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx)) + .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?; + + tasks.push(tool_result.output); + "#}, + concat!( + "let tool_result = cx\n", + " .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n", + " .output;", + ), + &mut rng, + ); + } + + #[track_caller] + fn assert_location_resolution(text_with_expected_range: &str, query: &str, rng: &mut StdRng) { + let (text, expected_ranges) = marked_text_ranges(text_with_expected_range, false); + let buffer = TextBuffer::new(0, BufferId::new(1).unwrap(), text.clone()); + let snapshot = buffer.snapshot(); + + let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone()); + + // Split query into random chunks + let chunks = to_random_chunks(rng, query); + + // Push chunks incrementally + for chunk in &chunks { + matcher.push(chunk); + } + + let result = matcher.finish(); + + // If no expected ranges, we expect no match + if expected_ranges.is_empty() { + assert_eq!( + result, None, + "Expected no match for query: {:?}, but found: {:?}", + query, result + ); + } else { + let mut actual_ranges = Vec::new(); + if let Some(range) = result { + actual_ranges.push(range); + } + + let text_with_actual_range = generate_marked_text(&text, &actual_ranges, false); + pretty_assertions::assert_eq!( + text_with_actual_range, + text_with_expected_range, + "Query: {:?}, Chunks: {:?}", + query, + chunks + ); + } + } + + fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec { + let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50)); + let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count); + chunk_indices.sort(); + chunk_indices.push(input.len()); + + let mut chunks = Vec::new(); + let mut last_ix = 0; + for chunk_ix in chunk_indices { + chunks.push(input[last_ix..chunk_ix].to_string()); + last_ix = chunk_ix; + } + chunks + } + + fn push(finder: &mut StreamingFuzzyMatcher, chunk: &str) -> Option { + finder + .push(chunk) + .map(|range| finder.snapshot.text_for_range(range).collect::()) + } + + fn finish(mut finder: StreamingFuzzyMatcher) -> Option { + let snapshot = finder.snapshot.clone(); + finder + .finish() + .map(|range| snapshot.text_for_range(range).collect::()) + } +} diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 51f63317ad..11ae95396a 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -12,13 +12,13 @@ use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey}; use futures::StreamExt; use gpui::{ - Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Entity, EntityId, Task, + Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task, TextStyleRefinement, WeakEntity, pulsating_between, }; use indoc::formatdoc; use language::{ - Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer, - language_settings::SoftWrap, + Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope, + TextBuffer, language_settings::SoftWrap, }; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use markdown::{Markdown, MarkdownElement, MarkdownStyle}; @@ -27,6 +27,8 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; use std::{ + cmp::Reverse, + ops::Range, path::{Path, PathBuf}, sync::Arc, time::Duration, @@ -98,7 +100,7 @@ pub enum EditFileMode { pub struct EditFileToolOutput { pub original_path: PathBuf, pub new_text: String, - pub old_text: String, + pub old_text: Arc, pub raw_output: Option, } @@ -200,10 +202,14 @@ impl Tool for EditFileTool { let old_text = cx .background_spawn({ let old_snapshot = old_snapshot.clone(); - async move { old_snapshot.text() } + async move { Arc::new(old_snapshot.text()) } }) .await; + if let Some(card) = card_clone.as_ref() { + card.update(cx, |card, cx| card.initialize(buffer.clone(), cx))?; + } + let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) { edit_agent.edit( buffer.clone(), @@ -225,26 +231,15 @@ impl Tool for EditFileTool { match event { EditAgentOutputEvent::Edited => { if let Some(card) = card_clone.as_ref() { - let new_snapshot = - buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let new_text = cx - .background_spawn({ - let new_snapshot = new_snapshot.clone(); - async move { new_snapshot.text() } - }) - .await; - card.update(cx, |card, cx| { - card.set_diff( - project_path.path.clone(), - old_text.clone(), - new_text, - cx, - ); - }) - .log_err(); + card.update(cx, |card, cx| card.update_diff(cx))?; + } + } + EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true, + EditAgentOutputEvent::ResolvingEditRange(range) => { + if let Some(card) = card_clone.as_ref() { + card.update(cx, |card, cx| card.reveal_range(range, cx))?; } } - EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true, } } let agent_output = output.await?; @@ -266,13 +261,14 @@ impl Tool for EditFileTool { let output = EditFileToolOutput { original_path: project_path.path.to_path_buf(), new_text: new_text.clone(), - old_text: old_text.clone(), + old_text, raw_output: Some(agent_output), }; if let Some(card) = card_clone { card.update(cx, |card, cx| { - card.set_diff(project_path.path.clone(), old_text, new_text, cx); + card.update_diff(cx); + card.finalize(cx) }) .log_err(); } @@ -282,12 +278,15 @@ impl Tool for EditFileTool { anyhow::ensure!( !hallucinated_old_text, formatdoc! {" - Some edits were produced but none of them could be applied. - Read the relevant sections of {input_path} again so that - I can perform the requested edits. - "} + Some edits were produced but none of them could be applied. + Read the relevant sections of {input_path} again so that + I can perform the requested edits. + "} ); - Ok("No edits were made.".to_string().into()) + Ok(ToolResultOutput { + content: ToolResultContent::Text("No edits were made.".into()), + output: serde_json::to_value(output).ok(), + }) } else { Ok(ToolResultOutput { content: ToolResultContent::Text(format!( @@ -318,16 +317,48 @@ impl Tool for EditFileTool { }; let card = cx.new(|cx| { - let mut card = EditFileToolCard::new(output.original_path.clone(), project, window, cx); - card.set_diff( - output.original_path.into(), - output.old_text, - output.new_text, - cx, - ); - card + EditFileToolCard::new(output.original_path.clone(), project.clone(), window, cx) }); + cx.spawn({ + let path: Arc = output.original_path.into(); + let language_registry = project.read(cx).languages().clone(); + let card = card.clone(); + async move |cx| { + let buffer = + build_buffer(output.new_text, path.clone(), &language_registry, cx).await?; + let buffer_diff = + build_buffer_diff(output.old_text.clone(), &buffer, &language_registry, cx) + .await?; + card.update(cx, |card, cx| { + card.multibuffer.update(cx, |multibuffer, cx| { + let snapshot = buffer.read(cx).snapshot(); + let diff = buffer_diff.read(cx); + let diff_hunk_ranges = diff + .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot)) + .collect::>(); + + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(&buffer, cx), + buffer, + diff_hunk_ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + multibuffer.add_diff(buffer_diff, cx); + let end = multibuffer.len(cx); + card.total_lines = + Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1); + }); + + cx.notify(); + })?; + anyhow::Ok(()) + } + }) + .detach_and_log_err(cx); + Some(card.into()) } } @@ -402,12 +433,15 @@ pub struct EditFileToolCard { editor: Entity, multibuffer: Entity, project: Entity, + buffer: Option>, + base_text: Option>, + buffer_diff: Option>, + revealed_ranges: Vec>, diff_task: Option>>, preview_expanded: bool, error_expanded: Option>, full_height_expanded: bool, total_lines: Option, - editor_unique_id: EntityId, } impl EditFileToolCard { @@ -442,11 +476,14 @@ impl EditFileToolCard { editor }); Self { - editor_unique_id: editor.entity_id(), path, project, editor, multibuffer, + buffer: None, + base_text: None, + buffer_diff: None, + revealed_ranges: Vec::new(), diff_task: None, preview_expanded: true, error_expanded: None, @@ -455,46 +492,184 @@ impl EditFileToolCard { } } - pub fn has_diff(&self) -> bool { - self.total_lines.is_some() + pub fn initialize(&mut self, buffer: Entity, cx: &mut App) { + let buffer_snapshot = buffer.read(cx).snapshot(); + let base_text = buffer_snapshot.text(); + let language_registry = buffer.read(cx).language_registry(); + let text_snapshot = buffer.read(cx).text_snapshot(); + + // Create a buffer diff with the current text as the base + let buffer_diff = cx.new(|cx| { + let mut diff = BufferDiff::new(&text_snapshot, cx); + let _ = diff.set_base_text( + buffer_snapshot.clone(), + language_registry, + text_snapshot, + cx, + ); + diff + }); + + self.buffer = Some(buffer.clone()); + self.base_text = Some(base_text.into()); + self.buffer_diff = Some(buffer_diff.clone()); + + // Add the diff to the multibuffer + self.multibuffer + .update(cx, |multibuffer, cx| multibuffer.add_diff(buffer_diff, cx)); } - pub fn set_diff( - &mut self, - path: Arc, - old_text: String, - new_text: String, - cx: &mut Context, - ) { - let language_registry = self.project.read(cx).languages().clone(); - self.diff_task = Some(cx.spawn(async move |this, cx| { - let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?; - let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?; + pub fn is_loading(&self) -> bool { + self.total_lines.is_none() + } + pub fn update_diff(&mut self, cx: &mut Context) { + let Some(buffer) = self.buffer.as_ref() else { + return; + }; + let Some(buffer_diff) = self.buffer_diff.as_ref() else { + return; + }; + + let buffer = buffer.clone(); + let buffer_diff = buffer_diff.clone(); + let base_text = self.base_text.clone(); + self.diff_task = Some(cx.spawn(async move |this, cx| { + let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?; + let diff_snapshot = BufferDiff::update_diff( + buffer_diff.clone(), + text_snapshot.clone(), + base_text, + false, + false, + None, + None, + cx, + ) + .await?; + buffer_diff.update(cx, |diff, cx| { + diff.set_snapshot(diff_snapshot, &text_snapshot, cx) + })?; + this.update(cx, |this, cx| this.update_visible_ranges(cx)) + })); + } + + pub fn reveal_range(&mut self, range: Range, cx: &mut Context) { + self.revealed_ranges.push(range); + self.update_visible_ranges(cx); + } + + fn update_visible_ranges(&mut self, cx: &mut Context) { + let Some(buffer) = self.buffer.as_ref() else { + return; + }; + + let ranges = self.excerpt_ranges(cx); + self.total_lines = self.multibuffer.update(cx, |multibuffer, cx| { + multibuffer.set_excerpts_for_path( + PathKey::for_buffer(buffer, cx), + buffer.clone(), + ranges, + editor::DEFAULT_MULTIBUFFER_CONTEXT, + cx, + ); + let end = multibuffer.len(cx); + Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1) + }); + cx.notify(); + } + + fn excerpt_ranges(&self, cx: &App) -> Vec> { + let Some(buffer) = self.buffer.as_ref() else { + return Vec::new(); + }; + let Some(diff) = self.buffer_diff.as_ref() else { + return Vec::new(); + }; + + let buffer = buffer.read(cx); + let diff = diff.read(cx); + let mut ranges = diff + .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer)) + .collect::>(); + ranges.extend( + self.revealed_ranges + .iter() + .map(|range| range.to_point(&buffer)), + ); + ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end))); + + // Merge adjacent ranges + let mut ranges = ranges.into_iter().peekable(); + let mut merged_ranges = Vec::new(); + while let Some(mut range) = ranges.next() { + while let Some(next_range) = ranges.peek() { + if range.end >= next_range.start { + range.end = range.end.max(next_range.end); + ranges.next(); + } else { + break; + } + } + + merged_ranges.push(range); + } + merged_ranges + } + + pub fn finalize(&mut self, cx: &mut Context) -> Result<()> { + let ranges = self.excerpt_ranges(cx); + let buffer = self.buffer.take().context("card was already finalized")?; + let base_text = self + .base_text + .take() + .context("card was already finalized")?; + let language_registry = self.project.read(cx).languages().clone(); + + // Replace the buffer in the multibuffer with the snapshot + let buffer = cx.new(|cx| { + let language = buffer.read(cx).language().cloned(); + let buffer = TextBuffer::new_normalized( + 0, + cx.entity_id().as_non_zero_u64().into(), + buffer.read(cx).line_ending(), + buffer.read(cx).as_rope().clone(), + ); + let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite); + buffer.set_language(language, cx); + buffer + }); + + let buffer_diff = cx.spawn({ + let buffer = buffer.clone(); + let language_registry = language_registry.clone(); + async move |_this, cx| { + build_buffer_diff(base_text, &buffer, &language_registry, cx).await + } + }); + + cx.spawn(async move |this, cx| { + let buffer_diff = buffer_diff.await?; this.update(cx, |this, cx| { - this.total_lines = this.multibuffer.update(cx, |multibuffer, cx| { - let snapshot = buffer.read(cx).snapshot(); - let diff = buffer_diff.read(cx); - let diff_hunk_ranges = diff - .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot)) - .collect::>(); + this.multibuffer.update(cx, |multibuffer, cx| { + let path_key = PathKey::for_buffer(&buffer, cx); multibuffer.clear(cx); multibuffer.set_excerpts_for_path( - PathKey::for_buffer(&buffer, cx), + path_key, buffer, - diff_hunk_ranges, + ranges, editor::DEFAULT_MULTIBUFFER_CONTEXT, cx, ); - multibuffer.add_diff(buffer_diff, cx); - let end = multibuffer.len(cx); - Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1) + multibuffer.add_diff(buffer_diff.clone(), cx); }); cx.notify(); }) - })); + }) + .detach_and_log_err(cx); + Ok(()) } } @@ -512,7 +687,7 @@ impl ToolCard for EditFileToolCard { }; let path_label_button = h_flex() - .id(("edit-tool-path-label-button", self.editor_unique_id)) + .id(("edit-tool-path-label-button", self.editor.entity_id())) .w_full() .max_w_full() .px_1() @@ -611,7 +786,7 @@ impl ToolCard for EditFileToolCard { ) .child( Disclosure::new( - ("edit-file-error-disclosure", self.editor_unique_id), + ("edit-file-error-disclosure", self.editor.entity_id()), self.error_expanded.is_some(), ) .opened_icon(IconName::ChevronUp) @@ -633,10 +808,10 @@ impl ToolCard for EditFileToolCard { ), ) }) - .when(error_message.is_none() && self.has_diff(), |header| { + .when(error_message.is_none() && !self.is_loading(), |header| { header.child( Disclosure::new( - ("edit-file-disclosure", self.editor_unique_id), + ("edit-file-disclosure", self.editor.entity_id()), self.preview_expanded, ) .opened_icon(IconName::ChevronUp) @@ -772,10 +947,10 @@ impl ToolCard for EditFileToolCard { ), ) }) - .when(!self.has_diff() && error_message.is_none(), |card| { + .when(self.is_loading() && error_message.is_none(), |card| { card.child(waiting_for_diff) }) - .when(self.preview_expanded && self.has_diff(), |card| { + .when(self.preview_expanded && !self.is_loading(), |card| { card.child( v_flex() .relative() @@ -797,7 +972,7 @@ impl ToolCard for EditFileToolCard { .when(is_collapsible, |card| { card.child( h_flex() - .id(("expand-button", self.editor_unique_id)) + .id(("expand-button", self.editor.entity_id())) .flex_none() .cursor_pointer() .h_5() @@ -871,19 +1046,23 @@ async fn build_buffer( } async fn build_buffer_diff( - mut old_text: String, + old_text: Arc, buffer: &Entity, language_registry: &Arc, cx: &mut AsyncApp, ) -> Result> { - LineEnding::normalize(&mut old_text); - let buffer = cx.update(|cx| buffer.read(cx).snapshot())?; + let old_text_rope = cx + .background_spawn({ + let old_text = old_text.clone(); + async move { Rope::from(old_text.as_str()) } + }) + .await; let base_buffer = cx .update(|cx| { Buffer::build_snapshot( - old_text.clone().into(), + old_text_rope, buffer.language().cloned(), Some(language_registry.clone()), cx, @@ -895,7 +1074,7 @@ async fn build_buffer_diff( .update(|cx| { BufferDiffSnapshot::new_with_base_buffer( buffer.text.clone(), - Some(old_text.into()), + Some(old_text), base_buffer, cx, ) diff --git a/crates/language/src/syntax_map/syntax_map_tests.rs b/crates/language/src/syntax_map/syntax_map_tests.rs index 211edad87c..f9b950c8f4 100644 --- a/crates/language/src/syntax_map/syntax_map_tests.rs +++ b/crates/language/src/syntax_map/syntax_map_tests.rs @@ -1076,7 +1076,7 @@ fn test_edit_sequence(language_name: &str, steps: &[&str], cx: &mut App) -> (Buf .now_or_never() .unwrap() .unwrap(); - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), Default::default()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), ""); let mut mutated_syntax_map = SyntaxMap::new(&buffer); mutated_syntax_map.set_language_registry(registry.clone()); diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index e94322608c..b4ba0c057f 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -107,14 +107,18 @@ impl FakeLanguageModel { self.current_completion_txs.lock().len() } - pub fn stream_completion_response(&self, request: &LanguageModelRequest, chunk: String) { + pub fn stream_completion_response( + &self, + request: &LanguageModelRequest, + chunk: impl Into, + ) { let current_completion_txs = self.current_completion_txs.lock(); let tx = current_completion_txs .iter() .find(|(req, _)| req == request) .map(|(_, tx)| tx) .unwrap(); - tx.unbounded_send(chunk).unwrap(); + tx.unbounded_send(chunk.into()).unwrap(); } pub fn end_completion_stream(&self, request: &LanguageModelRequest) { @@ -123,7 +127,7 @@ impl FakeLanguageModel { .retain(|(req, _)| req != request); } - pub fn stream_last_completion_response(&self, chunk: String) { + pub fn stream_last_completion_response(&self, chunk: impl Into) { self.stream_completion_response(self.pending_completions().last().unwrap(), chunk); } diff --git a/crates/project/src/buffer_store.rs b/crates/project/src/buffer_store.rs index 6c5ec8157e..8d54cd046e 100644 --- a/crates/project/src/buffer_store.rs +++ b/crates/project/src/buffer_store.rs @@ -622,7 +622,7 @@ impl LocalBufferStore { Ok(buffer) => Ok(buffer), Err(error) if is_not_found_error(&error) => cx.new(|cx| { let buffer_id = BufferId::from(cx.entity_id().as_non_zero_u64()); - let text_buffer = text::Buffer::new(0, buffer_id, "".into()); + let text_buffer = text::Buffer::new(0, buffer_id, ""); Buffer::build( text_buffer, Some(Arc::new(File { diff --git a/crates/text/src/tests.rs b/crates/text/src/tests.rs index f2a14d64b4..a096f1281f 100644 --- a/crates/text/src/tests.rs +++ b/crates/text/src/tests.rs @@ -16,7 +16,7 @@ fn init_logger() { #[test] fn test_edit() { - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "abc".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "abc"); assert_eq!(buffer.text(), "abc"); buffer.edit([(3..3, "def")]); assert_eq!(buffer.text(), "abcdef"); @@ -175,7 +175,7 @@ fn test_line_endings() { LineEnding::Windows ); - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "one\r\ntwo\rthree".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "one\r\ntwo\rthree"); assert_eq!(buffer.text(), "one\ntwo\nthree"); assert_eq!(buffer.line_ending(), LineEnding::Windows); buffer.check_invariants(); @@ -189,7 +189,7 @@ fn test_line_endings() { #[test] fn test_line_len() { - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), ""); buffer.edit([(0..0, "abcd\nefg\nhij")]); buffer.edit([(12..12, "kl\nmno")]); buffer.edit([(18..18, "\npqrs\n")]); @@ -206,7 +206,7 @@ fn test_line_len() { #[test] fn test_common_prefix_at_position() { let text = "a = str; b = δα"; - let buffer = Buffer::new(0, BufferId::new(1).unwrap(), text.into()); + let buffer = Buffer::new(0, BufferId::new(1).unwrap(), text); let offset1 = offset_after(text, "str"); let offset2 = offset_after(text, "δα"); @@ -257,7 +257,7 @@ fn test_text_summary_for_range() { let buffer = Buffer::new( 0, BufferId::new(1).unwrap(), - "ab\nefg\nhklm\nnopqrs\ntuvwxyz".into(), + "ab\nefg\nhklm\nnopqrs\ntuvwxyz", ); assert_eq!( buffer.text_summary_for_range::(0..2), @@ -347,7 +347,7 @@ fn test_text_summary_for_range() { #[test] fn test_chars_at() { - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), ""); buffer.edit([(0..0, "abcd\nefgh\nij")]); buffer.edit([(12..12, "kl\nmno")]); buffer.edit([(18..18, "\npqrs")]); @@ -369,7 +369,7 @@ fn test_chars_at() { assert_eq!(chars.collect::(), "PQrs"); // Regression test: - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), ""); buffer.edit([(0..0, "[workspace]\nmembers = [\n \"xray_core\",\n \"xray_server\",\n \"xray_cli\",\n \"xray_wasm\",\n]\n")]); buffer.edit([(60..60, "\n")]); @@ -379,7 +379,7 @@ fn test_chars_at() { #[test] fn test_anchors() { - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), ""); buffer.edit([(0..0, "abc")]); let left_anchor = buffer.anchor_before(2); let right_anchor = buffer.anchor_after(2); @@ -497,7 +497,7 @@ fn test_anchors() { #[test] fn test_anchors_at_start_and_end() { - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), ""); let before_start_anchor = buffer.anchor_before(0); let after_end_anchor = buffer.anchor_after(0); @@ -520,7 +520,7 @@ fn test_anchors_at_start_and_end() { #[test] fn test_undo_redo() { - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "1234".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "1234"); // Set group interval to zero so as to not group edits in the undo stack. buffer.set_group_interval(Duration::from_secs(0)); @@ -557,7 +557,7 @@ fn test_undo_redo() { #[test] fn test_history() { let mut now = Instant::now(); - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "123456".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "123456"); buffer.set_group_interval(Duration::from_millis(300)); let transaction_1 = buffer.start_transaction_at(now).unwrap(); @@ -624,7 +624,7 @@ fn test_history() { #[test] fn test_finalize_last_transaction() { let now = Instant::now(); - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "123456".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "123456"); buffer.history.group_interval = Duration::from_millis(1); buffer.start_transaction_at(now); @@ -660,7 +660,7 @@ fn test_finalize_last_transaction() { #[test] fn test_edited_ranges_for_transaction() { let now = Instant::now(); - let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "1234567".into()); + let mut buffer = Buffer::new(0, BufferId::new(1).unwrap(), "1234567"); buffer.start_transaction_at(now); buffer.edit([(2..4, "cd")]); @@ -699,9 +699,9 @@ fn test_edited_ranges_for_transaction() { fn test_concurrent_edits() { let text = "abcdef"; - let mut buffer1 = Buffer::new(1, BufferId::new(1).unwrap(), text.into()); - let mut buffer2 = Buffer::new(2, BufferId::new(1).unwrap(), text.into()); - let mut buffer3 = Buffer::new(3, BufferId::new(1).unwrap(), text.into()); + let mut buffer1 = Buffer::new(1, BufferId::new(1).unwrap(), text); + let mut buffer2 = Buffer::new(2, BufferId::new(1).unwrap(), text); + let mut buffer3 = Buffer::new(3, BufferId::new(1).unwrap(), text); let buf1_op = buffer1.edit([(1..2, "12")]); assert_eq!(buffer1.text(), "a12cdef"); diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index fc7fbfb8f4..b18a7598be 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -677,7 +677,8 @@ impl FromIterator for LineIndent { } impl Buffer { - pub fn new(replica_id: u16, remote_id: BufferId, mut base_text: String) -> Buffer { + pub fn new(replica_id: u16, remote_id: BufferId, base_text: impl Into) -> Buffer { + let mut base_text = base_text.into(); let line_ending = LineEnding::detect(&base_text); LineEnding::normalize(&mut base_text); Self::new_normalized(replica_id, remote_id, line_ending, Rope::from(base_text))