Never use the indentation that comes from OpenAI

This commit is contained in:
Antonio Scandurra 2023-09-11 16:33:25 +02:00
parent 6d9333dc3b
commit b8c437529c
2 changed files with 261 additions and 136 deletions

View file

@ -4,12 +4,14 @@ use crate::{
OpenAIRequest,
};
use anyhow::Result;
use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint};
use editor::{
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
use futures::{
channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
};
use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
use language::{IndentSize, Point, Rope, TransactionId};
use language::{Rope, TransactionId};
use std::{cmp, future, ops::Range, sync::Arc};
pub trait CompletionProvider {
@ -57,10 +59,17 @@ pub enum Event {
Undone,
}
#[derive(Clone)]
pub enum CodegenKind {
Transform { range: Range<Anchor> },
Generate { position: Anchor },
}
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
buffer: ModelHandle<MultiBuffer>,
range: Range<Anchor>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
last_equal_ranges: Vec<Range<Anchor>>,
transaction_id: Option<TransactionId>,
error: Option<anyhow::Error>,
@ -76,14 +85,31 @@ impl Entity for Codegen {
impl Codegen {
pub fn new(
buffer: ModelHandle<MultiBuffer>,
range: Range<Anchor>,
mut kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
match &mut kind {
CodegenKind::Transform { range } => {
let mut point_range = range.to_point(&snapshot);
point_range.start.column = 0;
if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
point_range.end.column = snapshot.line_len(point_range.end.row);
}
range.start = snapshot.anchor_before(point_range.start);
range.end = snapshot.anchor_after(point_range.end);
}
CodegenKind::Generate { position } => {
*position = position.bias_right(&snapshot);
}
}
Self {
provider,
buffer: buffer.clone(),
range,
snapshot,
kind,
last_equal_ranges: Default::default(),
transaction_id: Default::default(),
error: Default::default(),
@ -109,7 +135,14 @@ impl Codegen {
}
pub fn range(&self) -> Range<Anchor> {
self.range.clone()
match &self.kind {
CodegenKind::Transform { range } => range.clone(),
CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
}
}
pub fn kind(&self) -> &CodegenKind {
&self.kind
}
pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
@ -125,56 +158,18 @@ impl Codegen {
}
pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
let range = self.range.clone();
let snapshot = self.buffer.read(cx).snapshot(cx);
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
.text_for_range(range.start..range.end)
.collect::<Rope>();
let selection_start = range.start.to_point(&snapshot);
let selection_end = range.end.to_point(&snapshot);
let mut base_indent: Option<IndentSize> = None;
let mut start_row = selection_start.row;
if snapshot.is_line_blank(start_row) {
if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
start_row = prev_non_blank_row;
}
}
for row in start_row..=selection_end.row {
if snapshot.is_line_blank(row) {
continue;
}
let line_indent = snapshot.indent_size_for_line(row);
if let Some(base_indent) = base_indent.as_mut() {
if line_indent.len < base_indent.len {
*base_indent = line_indent;
}
} else {
base_indent = Some(line_indent);
}
}
let mut normalized_selected_text = selected_text.clone();
if let Some(base_indent) = base_indent {
for row in selection_start.row..=selection_end.row {
let selection_row = row - selection_start.row;
let line_start =
normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
let indent_len = if row == selection_start.row {
base_indent.len.saturating_sub(selection_start.column)
} else {
let line_len = normalized_selected_text.line_len(selection_row);
cmp::min(line_len, base_indent.len)
};
let indent_end = cmp::min(
line_start + indent_len as usize,
normalized_selected_text.len(),
);
normalized_selected_text.replace(line_start..indent_end, "");
}
}
let suggested_line_indent = snapshot
.suggested_indents(selection_start.row..selection_start.row + 1, cx)
.into_values()
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
self.generation = cx.spawn_weak(|this, mut cx| {
@ -188,66 +183,58 @@ impl Codegen {
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut indent_len;
let indent_text;
if let Some(base_indent) = base_indent {
indent_len = base_indent.len;
indent_text = match base_indent.kind {
language::IndentKind::Space => " ",
language::IndentKind::Tab => "\t",
};
} else {
indent_len = 0;
indent_text = "";
};
let mut first_line_len = 0;
let mut first_line_non_whitespace_char_ix = None;
let mut first_line = true;
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
let mut lines = chunk.split('\n');
if let Some(mut line) = lines.next() {
if first_line {
if first_line_non_whitespace_char_ix.is_none() {
if let Some(mut char_ix) =
line.find(|ch: char| !ch.is_whitespace())
{
line = &line[char_ix..];
char_ix += first_line_len;
first_line_non_whitespace_char_ix = Some(char_ix);
let first_line_indent = char_ix
.saturating_sub(selection_start.column as usize)
as usize;
new_text
.push_str(&indent_text.repeat(first_line_indent));
indent_len = indent_len.saturating_sub(char_ix as u32);
}
}
first_line_len += line.len();
}
if first_line_non_whitespace_char_ix.is_some() {
new_text.push_str(line);
}
}
for line in lines {
first_line = false;
new_text.push('\n');
if !line.is_empty() {
new_text.push_str(&indent_text.repeat(indent_len as usize));
}
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
}
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let hunks = diff.push_new(&new_text);
hunks_tx.send(hunks).await?;
new_text.clear();
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta = line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(selection_start.column as usize);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
if lines.peek().is_some() {
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.push_new("\n")).await?;
new_text.clear();
line_indent = None;
first_line = false;
}
}
}
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(())
@ -285,7 +272,7 @@ impl Codegen {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start += len;
edit_start = edit_end;
this.last_equal_ranges.push(edit_range);
None
}
@ -410,16 +397,20 @@ mod tests {
use futures::stream;
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{tree_sitter_rust, Buffer, Language, LanguageConfig};
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
use settings::SettingsStore;
#[gpui::test(iterations = 10)]
async fn test_autoindent(
async fn test_transform_autoindent(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
@ -436,15 +427,146 @@ mod tests {
snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx));
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = indoc! {"
let mut x = 0;
while x < 10 {
x += 1;
}
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
deterministic.run_until_parked();
}
provider.finish_completion();
deterministic.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
deterministic: Arc<Deterministic>,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(TestCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);