This commit is contained in:
Antonio Scandurra 2023-08-22 08:16:22 +02:00
parent 5b9d48d723
commit 5453553cfa
3 changed files with 234 additions and 109 deletions

1
Cargo.lock generated
View file

@ -106,6 +106,7 @@ dependencies = [
"fs",
"futures 0.3.28",
"gpui",
"indoc",
"isahc",
"language",
"menu",

View file

@ -24,6 +24,7 @@ workspace = { path = "../workspace" }
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
indoc.workspace = true
isahc.workspace = true
regex.workspace = true
schemars.workspace = true

View file

@ -1,14 +1,13 @@
use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset};
use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt};
use collections::HashMap;
use editor::{Editor, ToOffset};
use futures::StreamExt;
use gpui::{
actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
WeakViewHandle,
};
use menu::Confirm;
use serde::Deserialize;
use similar::ChangeTag;
use similar::{Change, ChangeTag, TextDiff};
use std::{env, iter, ops::Range, sync::Arc};
use util::TryFutureExt;
use workspace::{Modal, Workspace};
@ -33,12 +32,12 @@ impl RefactoringAssistant {
}
fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let selection = editor.read(cx).selections.newest_anchor().clone();
let selected_text = buffer
let selected_text = snapshot
.text_for_range(selection.start..selection.end)
.collect::<String>();
let language_name = buffer
let language_name = snapshot
.language_at(selection.start)
.map(|language| language.name());
let language_name = language_name.as_deref().unwrap_or("");
@ -48,7 +47,7 @@ impl RefactoringAssistant {
RequestMessage {
role: Role::User,
content: format!(
"Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code."
"Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code. Preserve indentation."
),
}],
stream: true,
@ -60,86 +59,149 @@ impl RefactoringAssistant {
editor.id(),
cx.spawn(|mut cx| {
async move {
let selection_start = selection.start.to_offset(&buffer);
// Find unique words in the selected text to use as diff boundaries.
let mut duplicate_words = HashSet::default();
let mut unique_old_words = HashMap::default();
for (range, word) in words(&selected_text) {
if !duplicate_words.contains(word) {
if unique_old_words.insert(word, range.end).is_some() {
unique_old_words.remove(word);
duplicate_words.insert(word);
}
}
}
let selection_start = selection.start.to_offset(&snapshot);
let mut new_text = String::new();
let mut messages = response.await?;
let mut new_word_search_start_ix = 0;
let mut last_old_word_end_ix = 0;
'outer: loop {
const MIN_DIFF_LEN: usize = 50;
let mut transaction = None;
let start = new_word_search_start_ix;
let mut words = words(&new_text[start..]);
while let Some((range, new_word)) = words.next() {
// We found a word in the new text that was unique in the old text. We can use
// it as a diff boundary, and start applying edits.
if let Some(old_word_end_ix) = unique_old_words.get(new_word).copied() {
if old_word_end_ix.saturating_sub(last_old_word_end_ix)
> MIN_DIFF_LEN
while let Some(message) = messages.next().await {
smol::future::yield_now().await;
let mut message = message?;
if let Some(choice) = message.choices.pop() {
if let Some(text) = choice.delta.content {
new_text.push_str(&text);
println!("-------------------------------------");
println!(
"{}",
similar::TextDiff::from_words(&selected_text, &new_text)
.unified_diff()
);
let mut changes =
similar::TextDiff::from_words(&selected_text, &new_text)
.iter_all_changes()
.collect::<Vec<_>>();
let mut ix = 0;
while ix < changes.len() {
let deletion_start_ix = ix;
let mut deletion_end_ix = ix;
while changes
.get(ix)
.map_or(false, |change| change.tag() == ChangeTag::Delete)
{
ix += 1;
deletion_end_ix += 1;
}
let insertion_start_ix = ix;
let mut insertion_end_ix = ix;
while changes
.get(ix)
.map_or(false, |change| change.tag() == ChangeTag::Insert)
{
ix += 1;
insertion_end_ix += 1;
}
if deletion_end_ix > deletion_start_ix
&& insertion_end_ix > insertion_start_ix
{
for _ in deletion_start_ix..deletion_end_ix {
let deletion = changes.remove(deletion_end_ix);
changes.insert(insertion_end_ix - 1, deletion);
}
}
ix += 1;
}
while changes
.last()
.map_or(false, |change| change.tag() != ChangeTag::Insert)
{
drop(words);
let remainder = new_text.split_off(start + range.end);
let edits = diff(
selection_start + last_old_word_end_ix,
&selected_text[last_old_word_end_ix..old_word_end_ix],
&new_text,
&buffer,
);
editor.update(&mut cx, |editor, cx| {
editor
.buffer()
.update(cx, |buffer, cx| buffer.edit(edits, None, cx))
})?;
new_text = remainder;
new_word_search_start_ix = 0;
last_old_word_end_ix = old_word_end_ix;
continue 'outer;
changes.pop();
}
}
new_word_search_start_ix = start + range.end;
}
drop(words);
editor.update(&mut cx, |editor, cx| {
editor.buffer().update(cx, |buffer, cx| {
if let Some(transaction) = transaction.take() {
buffer.undo(cx); // TODO: Undo the transaction instead
}
// Buffer incoming text, stopping if the stream was exhausted.
if let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
if let Some(text) = choice.delta.content {
new_text.push_str(&text);
}
buffer.start_transaction(cx);
let mut edit_start = selection_start;
dbg!(&changes);
for change in changes {
let value = change.value();
let edit_end = edit_start + value.len();
match change.tag() {
ChangeTag::Equal => {
edit_start = edit_end;
}
ChangeTag::Delete => {
let range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
buffer.edit([(range, "")], None, cx);
edit_start = edit_end;
}
ChangeTag::Insert => {
let insertion_start =
snapshot.anchor_after(edit_start);
buffer.edit(
[(insertion_start..insertion_start, value)],
None,
cx,
);
}
}
}
transaction = buffer.end_transaction(cx);
})
})?;
}
} else {
break;
}
}
let edits = diff(
selection_start + last_old_word_end_ix,
&selected_text[last_old_word_end_ix..],
&new_text,
&buffer,
);
editor.update(&mut cx, |editor, cx| {
editor
.buffer()
.update(cx, |buffer, cx| buffer.edit(edits, None, cx))
editor.buffer().update(cx, |buffer, cx| {
if let Some(transaction) = transaction.take() {
buffer.undo(cx); // TODO: Undo the transaction instead
}
buffer.start_transaction(cx);
let mut edit_start = selection_start;
for change in similar::TextDiff::from_words(&selected_text, &new_text)
.iter_all_changes()
{
let value = change.value();
let edit_end = edit_start + value.len();
match change.tag() {
ChangeTag::Equal => {
edit_start = edit_end;
}
ChangeTag::Delete => {
let range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
buffer.edit([(range, "")], None, cx);
edit_start = edit_end;
}
ChangeTag::Insert => {
let insertion_start = snapshot.anchor_after(edit_start);
buffer.edit(
[(insertion_start..insertion_start, value)],
None,
cx,
);
}
}
}
buffer.end_transaction(cx);
})
})?;
anyhow::Ok(())
@ -197,11 +259,13 @@ impl RefactoringModal {
{
workspace.toggle_modal(cx, |_, cx| {
let prompt_editor = cx.add_view(|cx| {
Editor::auto_height(
let mut editor = Editor::auto_height(
4,
Some(Arc::new(|theme| theme.search.editor.input.clone())),
cx,
)
);
editor.set_text("Replace with match statement.", cx);
editor
});
cx.add_view(|_| RefactoringModal {
editor,
@ -242,38 +306,97 @@ fn words(text: &str) -> impl Iterator<Item = (Range<usize>, &str)> {
})
}
fn diff<'a>(
start_ix: usize,
old_text: &'a str,
new_text: &'a str,
old_buffer_snapshot: &MultiBufferSnapshot,
) -> Vec<(Range<Anchor>, &'a str)> {
let mut edit_start = start_ix;
let mut edits = Vec::new();
let diff = similar::TextDiff::from_words(old_text, &new_text);
for change in diff.iter_all_changes() {
let value = change.value();
let edit_end = edit_start + value.len();
match change.tag() {
ChangeTag::Equal => {
edit_start = edit_end;
}
ChangeTag::Delete => {
edits.push((
old_buffer_snapshot.anchor_after(edit_start)
..old_buffer_snapshot.anchor_before(edit_end),
"",
));
edit_start = edit_end;
}
ChangeTag::Insert => {
edits.push((
old_buffer_snapshot.anchor_after(edit_start)
..old_buffer_snapshot.anchor_after(edit_start),
value,
));
}
fn streaming_diff<'a>(old_text: &'a str, new_text: &'a str) -> Vec<Change<'a, str>> {
let changes = TextDiff::configure()
.algorithm(similar::Algorithm::Patience)
.diff_words(old_text, new_text);
let mut changes = changes.iter_all_changes().peekable();
let mut result = vec![];
loop {
let mut deletions = vec![];
let mut insertions = vec![];
while changes
.peek()
.map_or(false, |change| change.tag() == ChangeTag::Delete)
{
deletions.push(changes.next().unwrap());
}
while changes
.peek()
.map_or(false, |change| change.tag() == ChangeTag::Insert)
{
insertions.push(changes.next().unwrap());
}
if !deletions.is_empty() && !insertions.is_empty() {
result.append(&mut insertions);
result.append(&mut deletions);
} else {
result.append(&mut deletions);
result.append(&mut insertions);
}
if let Some(change) = changes.next() {
result.push(change);
} else {
break;
}
}
edits
// Remove all non-inserts at the end.
while result
.last()
.map_or(false, |change| change.tag() != ChangeTag::Insert)
{
result.pop();
}
result
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
#[test]
fn test_streaming_diff() {
let old_text = indoc! {"
match (self.format, src_format) {
(Format::A8, Format::A8)
| (Format::Rgb24, Format::Rgb24)
| (Format::Rgba32, Format::Rgba32) => {
return self
.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format);
}
(Format::A8, Format::Rgb24) => {
return self
.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format);
}
(Format::Rgb24, Format::A8) => {
return self
.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format);
}
(Format::Rgb24, Format::Rgba32) => {
return self.blit_from_with::<BlitRgba32ToRgb24>(
dst_rect, src_bytes, src_stride, src_format,
);
}
(Format::Rgba32, Format::Rgb24)
| (Format::Rgba32, Format::A8)
| (Format::A8, Format::Rgba32) => {
unimplemented!()
}
_ => {}
}
"};
let new_text = indoc! {"
if self.format == src_format
"};
dbg!(streaming_diff(old_text, new_text));
}
}