diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 1d56a6308c..6d4fce2f6d 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,6 +1,6 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel}, - codegen::{self, Codegen, OpenAICompletionProvider}, + codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider}, stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, }; @@ -270,24 +270,28 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - let selection = editor.read(cx).selections.newest_anchor().clone(); - let range = selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot); let provider = Arc::new(OpenAICompletionProvider::new( api_key, cx.background().clone(), )); - let codegen = - cx.add_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), range, provider, cx)); - let assist_kind = if editor.read(cx).selections.newest::(cx).is_empty() { - InlineAssistKind::Generate + let selection = editor.read(cx).selections.newest_anchor().clone(); + let codegen_kind = if editor.read(cx).selections.newest::(cx).is_empty() { + CodegenKind::Generate { + position: selection.start, + } } else { - InlineAssistKind::Transform + CodegenKind::Transform { + range: selection.start..selection.end, + } }; + let codegen = cx.add_model(|cx| { + Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) + }); + let measurements = Rc::new(Cell::new(BlockMeasurements::default())); let inline_assistant = cx.add_view(|cx| { let assistant = InlineAssistant::new( inline_assist_id, - assist_kind, measurements.clone(), self.include_conversation_in_next_inline_assist, self.inline_prompt_history.clone(), @@ -330,7 +334,6 @@ impl AssistantPanel { self.pending_inline_assists.insert( inline_assist_id, PendingInlineAssist { - kind: assist_kind, editor: editor.downgrade(), inline_assistant: Some((block_id, inline_assistant.clone())), codegen: codegen.clone(), @@ -348,6 +351,14 @@ impl AssistantPanel { } } }), + cx.observe(&codegen, { + let editor = editor.downgrade(); + move |this, _, cx| { + if let Some(editor) = editor.upgrade(cx) { + this.update_highlights_for_editor(&editor, cx); + } + } + }), cx.subscribe(&codegen, move |this, codegen, event, cx| match event { codegen::Event::Undone => { this.finish_inline_assist(inline_assist_id, false, cx) @@ -542,8 +553,8 @@ impl AssistantPanel { if let Some(language_name) = language_name { writeln!(prompt, "You're an expert {language_name} engineer.").unwrap(); } - match pending_assist.kind { - InlineAssistKind::Transform => { + match pending_assist.codegen.read(cx).kind() { + CodegenKind::Transform { .. } => { writeln!( prompt, "You're currently working inside an editor on this file:" @@ -583,7 +594,7 @@ impl AssistantPanel { ) .unwrap(); } - InlineAssistKind::Generate => { + CodegenKind::Generate { .. } => { writeln!( prompt, "You're currently working inside an editor on this file:" @@ -2649,12 +2660,6 @@ enum InlineAssistantEvent { }, } -#[derive(Copy, Clone)] -enum InlineAssistKind { - Transform, - Generate, -} - struct InlineAssistant { id: usize, prompt_editor: ViewHandle, @@ -2769,7 +2774,6 @@ impl View for InlineAssistant { impl InlineAssistant { fn new( id: usize, - kind: InlineAssistKind, measurements: Rc>, include_conversation: bool, prompt_history: VecDeque, @@ -2781,9 +2785,9 @@ impl InlineAssistant { Some(Arc::new(|theme| theme.assistant.inline.editor.clone())), cx, ); - let placeholder = match kind { - InlineAssistKind::Transform => "Enter transformation prompt…", - InlineAssistKind::Generate => "Enter generation prompt…", + let placeholder = match codegen.read(cx).kind() { + CodegenKind::Transform { .. } => "Enter transformation prompt…", + CodegenKind::Generate { .. } => "Enter generation prompt…", }; editor.set_placeholder_text(placeholder, cx); editor @@ -2929,7 +2933,6 @@ struct BlockMeasurements { } struct PendingInlineAssist { - kind: InlineAssistKind, editor: WeakViewHandle, inline_assistant: Option<(BlockId, ViewHandle)>, codegen: ModelHandle, diff --git a/crates/ai/src/codegen.rs b/crates/ai/src/codegen.rs index 9657d9a492..bc13e294fa 100644 --- a/crates/ai/src/codegen.rs +++ b/crates/ai/src/codegen.rs @@ -4,12 +4,14 @@ use crate::{ OpenAIRequest, }; use anyhow::Result; -use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint}; +use editor::{ + multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint, +}; use futures::{ channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt, }; use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task}; -use language::{IndentSize, Point, Rope, TransactionId}; +use language::{Rope, TransactionId}; use std::{cmp, future, ops::Range, sync::Arc}; pub trait CompletionProvider { @@ -57,10 +59,17 @@ pub enum Event { Undone, } +#[derive(Clone)] +pub enum CodegenKind { + Transform { range: Range }, + Generate { position: Anchor }, +} + pub struct Codegen { provider: Arc, buffer: ModelHandle, - range: Range, + snapshot: MultiBufferSnapshot, + kind: CodegenKind, last_equal_ranges: Vec>, transaction_id: Option, error: Option, @@ -76,14 +85,31 @@ impl Entity for Codegen { impl Codegen { pub fn new( buffer: ModelHandle, - range: Range, + mut kind: CodegenKind, provider: Arc, cx: &mut ModelContext, ) -> Self { + let snapshot = buffer.read(cx).snapshot(cx); + match &mut kind { + CodegenKind::Transform { range } => { + let mut point_range = range.to_point(&snapshot); + point_range.start.column = 0; + if point_range.end.column > 0 || point_range.start.row == point_range.end.row { + point_range.end.column = snapshot.line_len(point_range.end.row); + } + range.start = snapshot.anchor_before(point_range.start); + range.end = snapshot.anchor_after(point_range.end); + } + CodegenKind::Generate { position } => { + *position = position.bias_right(&snapshot); + } + } + Self { provider, buffer: buffer.clone(), - range, + snapshot, + kind, last_equal_ranges: Default::default(), transaction_id: Default::default(), error: Default::default(), @@ -109,7 +135,14 @@ impl Codegen { } pub fn range(&self) -> Range { - self.range.clone() + match &self.kind { + CodegenKind::Transform { range } => range.clone(), + CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position, + } + } + + pub fn kind(&self) -> &CodegenKind { + &self.kind } pub fn last_equal_ranges(&self) -> &[Range] { @@ -125,56 +158,18 @@ impl Codegen { } pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext) { - let range = self.range.clone(); - let snapshot = self.buffer.read(cx).snapshot(cx); + let range = self.range(); + let snapshot = self.snapshot.clone(); let selected_text = snapshot .text_for_range(range.start..range.end) .collect::(); let selection_start = range.start.to_point(&snapshot); - let selection_end = range.end.to_point(&snapshot); - - let mut base_indent: Option = None; - let mut start_row = selection_start.row; - if snapshot.is_line_blank(start_row) { - if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) { - start_row = prev_non_blank_row; - } - } - for row in start_row..=selection_end.row { - if snapshot.is_line_blank(row) { - continue; - } - - let line_indent = snapshot.indent_size_for_line(row); - if let Some(base_indent) = base_indent.as_mut() { - if line_indent.len < base_indent.len { - *base_indent = line_indent; - } - } else { - base_indent = Some(line_indent); - } - } - - let mut normalized_selected_text = selected_text.clone(); - if let Some(base_indent) = base_indent { - for row in selection_start.row..=selection_end.row { - let selection_row = row - selection_start.row; - let line_start = - normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); - let indent_len = if row == selection_start.row { - base_indent.len.saturating_sub(selection_start.column) - } else { - let line_len = normalized_selected_text.line_len(selection_row); - cmp::min(line_len, base_indent.len) - }; - let indent_end = cmp::min( - line_start + indent_len as usize, - normalized_selected_text.len(), - ); - normalized_selected_text.replace(line_start..indent_end, ""); - } - } + let suggested_line_indent = snapshot + .suggested_indents(selection_start.row..selection_start.row + 1, cx) + .into_values() + .next() + .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row)); let response = self.provider.complete(prompt); self.generation = cx.spawn_weak(|this, mut cx| { @@ -188,66 +183,58 @@ impl Codegen { futures::pin_mut!(chunks); let mut diff = StreamingDiff::new(selected_text.to_string()); - let mut indent_len; - let indent_text; - if let Some(base_indent) = base_indent { - indent_len = base_indent.len; - indent_text = match base_indent.kind { - language::IndentKind::Space => " ", - language::IndentKind::Tab => "\t", - }; - } else { - indent_len = 0; - indent_text = ""; - }; - - let mut first_line_len = 0; - let mut first_line_non_whitespace_char_ix = None; - let mut first_line = true; let mut new_text = String::new(); + let mut base_indent = None; + let mut line_indent = None; + let mut first_line = true; while let Some(chunk) = chunks.next().await { let chunk = chunk?; - let mut lines = chunk.split('\n'); - if let Some(mut line) = lines.next() { - if first_line { - if first_line_non_whitespace_char_ix.is_none() { - if let Some(mut char_ix) = - line.find(|ch: char| !ch.is_whitespace()) - { - line = &line[char_ix..]; - char_ix += first_line_len; - first_line_non_whitespace_char_ix = Some(char_ix); - let first_line_indent = char_ix - .saturating_sub(selection_start.column as usize) - as usize; - new_text - .push_str(&indent_text.repeat(first_line_indent)); - indent_len = indent_len.saturating_sub(char_ix as u32); - } - } - first_line_len += line.len(); - } - - if first_line_non_whitespace_char_ix.is_some() { - new_text.push_str(line); - } - } - - for line in lines { - first_line = false; - new_text.push('\n'); - if !line.is_empty() { - new_text.push_str(&indent_text.repeat(indent_len as usize)); - } + let mut lines = chunk.split('\n').peekable(); + while let Some(line) = lines.next() { new_text.push_str(line); - } + if line_indent.is_none() { + if let Some(non_whitespace_ch_ix) = + new_text.find(|ch: char| !ch.is_whitespace()) + { + line_indent = Some(non_whitespace_ch_ix); + base_indent = base_indent.or(line_indent); - let hunks = diff.push_new(&new_text); - hunks_tx.send(hunks).await?; - new_text.clear(); + let line_indent = line_indent.unwrap(); + let base_indent = base_indent.unwrap(); + let indent_delta = line_indent as i32 - base_indent as i32; + let mut corrected_indent_len = cmp::max( + 0, + suggested_line_indent.len as i32 + indent_delta, + ) + as usize; + if first_line { + corrected_indent_len = corrected_indent_len + .saturating_sub(selection_start.column as usize); + } + + let indent_char = suggested_line_indent.char(); + let mut indent_buffer = [0; 4]; + let indent_str = + indent_char.encode_utf8(&mut indent_buffer); + new_text.replace_range( + ..line_indent, + &indent_str.repeat(corrected_indent_len), + ); + } + } + + if lines.peek().is_some() { + hunks_tx.send(diff.push_new(&new_text)).await?; + hunks_tx.send(diff.push_new("\n")).await?; + new_text.clear(); + line_indent = None; + first_line = false; + } + } } + hunks_tx.send(diff.push_new(&new_text)).await?; hunks_tx.send(diff.finish()).await?; anyhow::Ok(()) @@ -285,7 +272,7 @@ impl Codegen { let edit_end = edit_start + len; let edit_range = snapshot.anchor_after(edit_start) ..snapshot.anchor_before(edit_end); - edit_start += len; + edit_start = edit_end; this.last_equal_ranges.push(edit_range); None } @@ -410,16 +397,20 @@ mod tests { use futures::stream; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; - use language::{tree_sitter_rust, Buffer, Language, LanguageConfig}; + use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; use parking_lot::Mutex; use rand::prelude::*; + use settings::SettingsStore; #[gpui::test(iterations = 10)] - async fn test_autoindent( + async fn test_transform_autoindent( cx: &mut TestAppContext, mut rng: StdRng, deterministic: Arc, ) { + cx.set_global(cx.read(SettingsStore::test)); + cx.update(language_settings::init); + let text = indoc! {" fn main() { let x = 0; @@ -436,15 +427,146 @@ mod tests { snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4)) }); let provider = Arc::new(TestCompletionProvider::new()); - let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx)); + let codegen = cx.add_model(|cx| { + Codegen::new( + buffer.clone(), + CodegenKind::Transform { range }, + provider.clone(), + cx, + ) + }); codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); - let mut new_text = indoc! {" - let mut x = 0; - while x < 10 { - x += 1; - } + let mut new_text = concat!( + " let mut x = 0;\n", + " while x < 10 {\n", + " x += 1;\n", + " }", + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk); + new_text = suffix; + deterministic.run_until_parked(); + } + provider.finish_completion(); + deterministic.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_past_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + deterministic: Arc, + ) { + cx.set_global(cx.read(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + le + } "}; + let buffer = + cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); + let position = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 6)) + }); + let provider = Arc::new(TestCompletionProvider::new()); + let codegen = cx.add_model(|cx| { + Codegen::new( + buffer.clone(), + CodegenKind::Generate { position }, + provider.clone(), + cx, + ) + }); + codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let mut new_text = concat!( + "t mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + provider.send_completion(chunk); + new_text = suffix; + deterministic.run_until_parked(); + } + provider.finish_completion(); + deterministic.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_before_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + deterministic: Arc, + ) { + cx.set_global(cx.read(SettingsStore::test)); + cx.update(language_settings::init); + + let text = concat!( + "fn main() {\n", + " \n", + "}\n" // + ); + let buffer = + cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx)); + let position = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 2)) + }); + let provider = Arc::new(TestCompletionProvider::new()); + let codegen = cx.add_model(|cx| { + Codegen::new( + buffer.clone(), + CodegenKind::Generate { position }, + provider.clone(), + cx, + ) + }); + codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let mut new_text = concat!( + "let mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); while !new_text.is_empty() { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len);