use crate::{ stream_completion, streaming_diff::{Hunk, StreamingDiff}, OpenAIRequest, }; use anyhow::Result; use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint}; use futures::{ channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt, }; use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task}; use language::{IndentSize, Point, Rope, TransactionId}; use std::{cmp, future, ops::Range, sync::Arc}; pub trait CompletionProvider { fn complete( &self, prompt: OpenAIRequest, ) -> BoxFuture<'static, Result>>>; } pub struct OpenAICompletionProvider { api_key: String, executor: Arc, } impl OpenAICompletionProvider { pub fn new(api_key: String, executor: Arc) -> Self { Self { api_key, executor } } } impl CompletionProvider for OpenAICompletionProvider { fn complete( &self, prompt: OpenAIRequest, ) -> BoxFuture<'static, Result>>> { let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); async move { let response = request.await?; let stream = response .filter_map(|response| async move { match response { Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), Err(error) => Some(Err(error)), } }) .boxed(); Ok(stream) } .boxed() } } pub enum Event { Finished, Undone, } pub struct Codegen { provider: Arc, buffer: ModelHandle, range: Range, last_equal_ranges: Vec>, transaction_id: Option, error: Option, generation: Task<()>, idle: bool, _subscription: gpui::Subscription, } impl Entity for Codegen { type Event = Event; } impl Codegen { pub fn new( buffer: ModelHandle, range: Range, provider: Arc, cx: &mut ModelContext, ) -> Self { Self { provider, buffer: buffer.clone(), range, last_equal_ranges: Default::default(), transaction_id: Default::default(), error: Default::default(), idle: true, generation: Task::ready(()), _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), } } fn handle_buffer_event( &mut self, _buffer: ModelHandle, event: &multi_buffer::Event, cx: &mut ModelContext, ) { if let multi_buffer::Event::TransactionUndone { transaction_id } = event { if self.transaction_id == Some(*transaction_id) { self.transaction_id = None; self.generation = Task::ready(()); cx.emit(Event::Undone); } } } pub fn range(&self) -> Range { self.range.clone() } pub fn last_equal_ranges(&self) -> &[Range] { &self.last_equal_ranges } pub fn idle(&self) -> bool { self.idle } pub fn error(&self) -> Option<&anyhow::Error> { self.error.as_ref() } pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext) { let range = self.range.clone(); let snapshot = self.buffer.read(cx).snapshot(cx); let selected_text = snapshot .text_for_range(range.start..range.end) .collect::(); let selection_start = range.start.to_point(&snapshot); let selection_end = range.end.to_point(&snapshot); let mut base_indent: Option = None; let mut start_row = selection_start.row; if snapshot.is_line_blank(start_row) { if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) { start_row = prev_non_blank_row; } } for row in start_row..=selection_end.row { if snapshot.is_line_blank(row) { continue; } let line_indent = snapshot.indent_size_for_line(row); if let Some(base_indent) = base_indent.as_mut() { if line_indent.len < base_indent.len { *base_indent = line_indent; } } else { base_indent = Some(line_indent); } } let mut normalized_selected_text = selected_text.clone(); if let Some(base_indent) = base_indent { for row in selection_start.row..=selection_end.row { let selection_row = row - selection_start.row; let line_start = normalized_selected_text.point_to_offset(Point::new(selection_row, 0)); let indent_len = if row == selection_start.row { base_indent.len.saturating_sub(selection_start.column) } else { let line_len = normalized_selected_text.line_len(selection_row); cmp::min(line_len, base_indent.len) }; let indent_end = cmp::min( line_start + indent_len as usize, normalized_selected_text.len(), ); normalized_selected_text.replace(line_start..indent_end, ""); } } let response = self.provider.complete(prompt); self.generation = cx.spawn_weak(|this, mut cx| { async move { let generate = async { let mut edit_start = range.start.to_offset(&snapshot); let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); let diff = cx.background().spawn(async move { let chunks = strip_markdown_codeblock(response.await?); futures::pin_mut!(chunks); let mut diff = StreamingDiff::new(selected_text.to_string()); let mut indent_len; let indent_text; if let Some(base_indent) = base_indent { indent_len = base_indent.len; indent_text = match base_indent.kind { language::IndentKind::Space => " ", language::IndentKind::Tab => "\t", }; } else { indent_len = 0; indent_text = ""; }; let mut first_line_len = 0; let mut first_line_non_whitespace_char_ix = None; let mut first_line = true; let mut new_text = String::new(); while let Some(chunk) = chunks.next().await { let chunk = chunk?; let mut lines = chunk.split('\n'); if let Some(mut line) = lines.next() { if first_line { if first_line_non_whitespace_char_ix.is_none() { if let Some(mut char_ix) = line.find(|ch: char| !ch.is_whitespace()) { line = &line[char_ix..]; char_ix += first_line_len; first_line_non_whitespace_char_ix = Some(char_ix); let first_line_indent = char_ix .saturating_sub(selection_start.column as usize) as usize; new_text .push_str(&indent_text.repeat(first_line_indent)); indent_len = indent_len.saturating_sub(char_ix as u32); } } first_line_len += line.len(); } if first_line_non_whitespace_char_ix.is_some() { new_text.push_str(line); } } for line in lines { first_line = false; new_text.push('\n'); if !line.is_empty() { new_text.push_str(&indent_text.repeat(indent_len as usize)); } new_text.push_str(line); } let hunks = diff.push_new(&new_text); hunks_tx.send(hunks).await?; new_text.clear(); } hunks_tx.send(diff.finish()).await?; anyhow::Ok(()) }); while let Some(hunks) = hunks_rx.next().await { let this = if let Some(this) = this.upgrade(&cx) { this } else { break; }; this.update(&mut cx, |this, cx| { this.last_equal_ranges.clear(); let transaction = this.buffer.update(cx, |buffer, cx| { // Avoid grouping assistant edits with user edits. buffer.finalize_last_transaction(cx); buffer.start_transaction(cx); buffer.edit( hunks.into_iter().filter_map(|hunk| match hunk { Hunk::Insert { text } => { let edit_start = snapshot.anchor_after(edit_start); Some((edit_start..edit_start, text)) } Hunk::Remove { len } => { let edit_end = edit_start + len; let edit_range = snapshot.anchor_after(edit_start) ..snapshot.anchor_before(edit_end); edit_start = edit_end; Some((edit_range, String::new())) } Hunk::Keep { len } => { let edit_end = edit_start + len; let edit_range = snapshot.anchor_after(edit_start) ..snapshot.anchor_before(edit_end); edit_start += len; this.last_equal_ranges.push(edit_range); None } }), None, cx, ); buffer.end_transaction(cx) }); if let Some(transaction) = transaction { if let Some(first_transaction) = this.transaction_id { // Group all assistant edits into the first transaction. this.buffer.update(cx, |buffer, cx| { buffer.merge_transactions( transaction, first_transaction, cx, ) }); } else { this.transaction_id = Some(transaction); this.buffer.update(cx, |buffer, cx| { buffer.finalize_last_transaction(cx) }); } } cx.notify(); }); } diff.await?; anyhow::Ok(()) }; let result = generate.await; if let Some(this) = this.upgrade(&cx) { this.update(&mut cx, |this, cx| { this.last_equal_ranges.clear(); this.idle = true; if let Err(error) = result { this.error = Some(error); } cx.emit(Event::Finished); cx.notify(); }); } } }); self.error.take(); self.idle = false; cx.notify(); } pub fn undo(&mut self, cx: &mut ModelContext) { if let Some(transaction_id) = self.transaction_id { self.buffer .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); } } } fn strip_markdown_codeblock( stream: impl Stream>, ) -> impl Stream> { let mut first_line = true; let mut buffer = String::new(); let mut starts_with_fenced_code_block = false; stream.filter_map(move |chunk| { let chunk = match chunk { Ok(chunk) => chunk, Err(err) => return future::ready(Some(Err(err))), }; buffer.push_str(&chunk); if first_line { if buffer == "" || buffer == "`" || buffer == "``" { return future::ready(None); } else if buffer.starts_with("```") { starts_with_fenced_code_block = true; if let Some(newline_ix) = buffer.find('\n') { buffer.replace_range(..newline_ix + 1, ""); first_line = false; } else { return future::ready(None); } } } let text = if starts_with_fenced_code_block { buffer .strip_suffix("\n```\n") .or_else(|| buffer.strip_suffix("\n```")) .or_else(|| buffer.strip_suffix("\n``")) .or_else(|| buffer.strip_suffix("\n`")) .or_else(|| buffer.strip_suffix('\n')) .unwrap_or(&buffer) } else { &buffer }; if text.contains('\n') { first_line = false; } let remainder = buffer.split_off(text.len()); let result = if buffer.is_empty() { None } else { Some(Ok(buffer.clone())) }; buffer = remainder; future::ready(result) }) } #[cfg(test)] mod tests { use futures::stream; use super::*; #[gpui::test] async fn test_strip_markdown_codeblock() { assert_eq!( strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "Lorem ipsum dolor" ); assert_eq!( strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "Lorem ipsum dolor" ); assert_eq!( strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "Lorem ipsum dolor" ); assert_eq!( strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "Lorem ipsum dolor" ); assert_eq!( strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "```js\nLorem ipsum dolor\n```" ); assert_eq!( strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "``\nLorem ipsum dolor\n```" ); fn chunks(text: &str, size: usize) -> impl Stream> { stream::iter( text.chars() .collect::>() .chunks(size) .map(|chunk| Ok(chunk.iter().collect::())) .collect::>(), ) } } }