From 42f02eb4e7adafe27d444b8b9ffbe68ddce9e714 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 21 Aug 2023 15:11:06 +0200 Subject: [PATCH] Incrementally diff input coming from GPT --- Cargo.lock | 1 + crates/ai/Cargo.toml | 1 + crates/ai/src/ai.rs | 107 +++++++++++++++-- crates/ai/src/assistant.rs | 99 +--------------- crates/ai/src/refactor.rs | 233 +++++++++++++++++++++++++++++++++---- prompt.md | 11 -- 6 files changed, 315 insertions(+), 137 deletions(-) delete mode 100644 prompt.md diff --git a/Cargo.lock b/Cargo.lock index 69285a1abf..f802d90739 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,6 +116,7 @@ dependencies = [ "serde", "serde_json", "settings", + "similar", "smol", "theme", "tiktoken-rs 0.4.5", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 013565e14f..bae20f7537 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -29,6 +29,7 @@ regex.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +similar = "1.3" smol.workspace = true tiktoken-rs = "0.4" diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 7874bb46a5..511e7fddd7 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -2,27 +2,31 @@ pub mod assistant; mod assistant_settings; mod refactor; -use anyhow::Result; +use anyhow::{anyhow, Result}; pub use assistant::AssistantPanel; use chrono::{DateTime, Local}; use collections::HashMap; use fs::Fs; -use futures::StreamExt; -use gpui::AppContext; +use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; +use gpui::{executor::Background, AppContext}; +use isahc::{http::StatusCode, Request, RequestExt}; use regex::Regex; use serde::{Deserialize, Serialize}; use std::{ cmp::Reverse, ffi::OsStr, fmt::{self, Display}, + io, path::PathBuf, sync::Arc, }; use util::paths::CONVERSATIONS_DIR; +const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; + // Data types for chat completion requests #[derive(Debug, Serialize)] -struct OpenAIRequest { +pub struct OpenAIRequest { model: String, messages: Vec, stream: bool, @@ -116,7 +120,7 @@ struct RequestMessage { } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -struct ResponseMessage { +pub struct ResponseMessage { role: Option, content: Option, } @@ -150,7 +154,7 @@ impl Display for Role { } #[derive(Deserialize, Debug)] -struct OpenAIResponseStreamEvent { +pub struct OpenAIResponseStreamEvent { pub id: Option, pub object: String, pub created: u32, @@ -160,14 +164,14 @@ struct OpenAIResponseStreamEvent { } #[derive(Deserialize, Debug)] -struct Usage { +pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } #[derive(Deserialize, Debug)] -struct ChatChoiceDelta { +pub struct ChatChoiceDelta { pub index: u32, pub delta: ResponseMessage, pub finish_reason: Option, @@ -190,4 +194,91 @@ struct OpenAIChoice { pub fn init(cx: &mut AppContext) { assistant::init(cx); + refactor::init(cx); +} + +pub async fn stream_completion( + api_key: String, + executor: Arc, + mut request: OpenAIRequest, +) -> Result>> { + request.stream = true; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = serde_json::to_string(&request)?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? + .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } } diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index e5026182ed..f134eeeeb6 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1,7 +1,7 @@ use crate::{ assistant_settings::{AssistantDockPosition, AssistantSettings}, - MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent, - RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage, + stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage, + Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL, }; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; @@ -12,26 +12,23 @@ use editor::{ Anchor, Editor, ToOffset, }; use fs::Fs; -use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; +use futures::StreamExt; use gpui::{ actions, elements::*, - executor::Background, geometry::vector::{vec2f, Vector2F}, platform::{CursorStyle, MouseButton}, Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle, Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; -use isahc::{http::StatusCode, Request, RequestExt}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; use search::BufferSearchBar; -use serde::Deserialize; use settings::SettingsStore; use std::{ cell::RefCell, cmp, env, fmt::Write, - io, iter, + iter, ops::Range, path::{Path, PathBuf}, rc::Rc, @@ -46,8 +43,6 @@ use workspace::{ Save, ToggleZoom, Toolbar, Workspace, }; -const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; - actions!( assistant, [ @@ -2144,92 +2139,6 @@ impl Message { } } -async fn stream_completion( - api_key: String, - executor: Arc, - mut request: OpenAIRequest, -) -> Result>> { - request.stream = true; - - let (tx, rx) = futures::channel::mpsc::unbounded::>(); - - let json_data = serde_json::to_string(&request)?; - let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body(json_data)? - .send_async() - .await?; - - let status = response.status(); - if status == StatusCode::OK { - executor - .spawn(async move { - let mut lines = BufReader::new(response.body_mut()).lines(); - - fn parse_line( - line: Result, - ) -> Result> { - if let Some(data) = line?.strip_prefix("data: ") { - let event = serde_json::from_str(&data)?; - Ok(Some(event)) - } else { - Ok(None) - } - } - - while let Some(line) = lines.next().await { - if let Some(event) = parse_line(line).transpose() { - let done = event.as_ref().map_or(false, |event| { - event - .choices - .last() - .map_or(false, |choice| choice.finish_reason.is_some()) - }); - if tx.unbounded_send(event).is_err() { - break; - } - - if done { - break; - } - } - } - - anyhow::Ok(()) - }) - .detach(); - - Ok(rx) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct OpenAIResponse { - error: OpenAIError, - } - - #[derive(Deserialize)] - struct OpenAIError { - message: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => Err(anyhow!( - "Failed to connect to OpenAI API: {}", - response.error.message, - )), - - _ => Err(anyhow!( - "Failed to connect to OpenAI API: {} {}", - response.status(), - body, - )), - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/ai/src/refactor.rs b/crates/ai/src/refactor.rs index e1b57680ee..fc6cbdb8c4 100644 --- a/crates/ai/src/refactor.rs +++ b/crates/ai/src/refactor.rs @@ -1,16 +1,24 @@ -use collections::HashMap; -use editor::Editor; +use crate::{stream_completion, OpenAIRequest, RequestMessage, Role}; +use collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset}; +use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt}; use gpui::{ actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle, + WeakViewHandle, }; -use std::sync::Arc; +use menu::Confirm; +use serde::Deserialize; +use similar::ChangeTag; +use std::{env, iter, ops::Range, sync::Arc}; +use util::TryFutureExt; use workspace::{Modal, Workspace}; actions!(assistant, [Refactor]); -fn init(cx: &mut AppContext) { +pub fn init(cx: &mut AppContext) { cx.set_global(RefactoringAssistant::new()); cx.add_action(RefactoringModal::deploy); + cx.add_action(RefactoringModal::confirm); } pub struct RefactoringAssistant { @@ -24,10 +32,122 @@ impl RefactoringAssistant { } } - fn refactor(&mut self, editor: &ViewHandle, prompt: &str, cx: &mut AppContext) {} + fn refactor(&mut self, editor: &ViewHandle, prompt: &str, cx: &mut AppContext) { + let buffer = editor.read(cx).buffer().read(cx).snapshot(cx); + let selection = editor.read(cx).selections.newest_anchor().clone(); + let selected_text = buffer + .text_for_range(selection.start..selection.end) + .collect::(); + let language_name = buffer + .language_at(selection.start) + .map(|language| language.name()); + let language_name = language_name.as_deref().unwrap_or(""); + let request = OpenAIRequest { + model: "gpt-4".into(), + messages: vec![ + RequestMessage { + role: Role::User, + content: format!( + "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code." + ), + }], + stream: true, + }; + let api_key = env::var("OPENAI_API_KEY").unwrap(); + let response = stream_completion(api_key, cx.background().clone(), request); + let editor = editor.downgrade(); + self.pending_edits_by_editor.insert( + editor.id(), + cx.spawn(|mut cx| { + async move { + let selection_start = selection.start.to_offset(&buffer); + + // Find unique words in the selected text to use as diff boundaries. + let mut duplicate_words = HashSet::default(); + let mut unique_old_words = HashMap::default(); + for (range, word) in words(&selected_text) { + if !duplicate_words.contains(word) { + if unique_old_words.insert(word, range.end).is_some() { + unique_old_words.remove(word); + duplicate_words.insert(word); + } + } + } + + let mut new_text = String::new(); + let mut messages = response.await?; + let mut new_word_search_start_ix = 0; + let mut last_old_word_end_ix = 0; + + 'outer: loop { + let start = new_word_search_start_ix; + let mut words = words(&new_text[start..]); + while let Some((range, new_word)) = words.next() { + // We found a word in the new text that was unique in the old text. We can use + // it as a diff boundary, and start applying edits. + if let Some(old_word_end_ix) = unique_old_words.remove(new_word) { + if old_word_end_ix > last_old_word_end_ix { + drop(words); + + let remainder = new_text.split_off(start + range.end); + let edits = diff( + selection_start + last_old_word_end_ix, + &selected_text[last_old_word_end_ix..old_word_end_ix], + &new_text, + &buffer, + ); + editor.update(&mut cx, |editor, cx| { + editor + .buffer() + .update(cx, |buffer, cx| buffer.edit(edits, None, cx)) + })?; + + new_text = remainder; + new_word_search_start_ix = 0; + last_old_word_end_ix = old_word_end_ix; + continue 'outer; + } + } + + new_word_search_start_ix = start + range.end; + } + drop(words); + + // Buffer incoming text, stopping if the stream was exhausted. + if let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + if let Some(text) = choice.delta.content { + new_text.push_str(&text); + } + } + } else { + break; + } + } + + let edits = diff( + selection_start + last_old_word_end_ix, + &selected_text[last_old_word_end_ix..], + &new_text, + &buffer, + ); + editor.update(&mut cx, |editor, cx| { + editor + .buffer() + .update(cx, |buffer, cx| buffer.edit(edits, None, cx)) + })?; + + anyhow::Ok(()) + } + .log_err() + }), + ); + } } struct RefactoringModal { + editor: WeakViewHandle, prompt_editor: ViewHandle, has_focus: bool, } @@ -42,7 +162,7 @@ impl View for RefactoringModal { } fn render(&mut self, cx: &mut ViewContext) -> AnyElement { - todo!() + ChildView::new(&self.prompt_editor, cx).into_any() } fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext) { @@ -60,29 +180,96 @@ impl Modal for RefactoringModal { } fn dismiss_on_event(event: &Self::Event) -> bool { - todo!() + // TODO + false } } impl RefactoringModal { fn deploy(workspace: &mut Workspace, _: &Refactor, cx: &mut ViewContext) { - workspace.toggle_modal(cx, |_, cx| { - let prompt_editor = cx.add_view(|cx| { - Editor::auto_height( - 4, - Some(Arc::new(|theme| theme.search.editor.input.clone())), - cx, - ) + if let Some(editor) = workspace + .active_item(cx) + .and_then(|item| Some(item.downcast::()?.downgrade())) + { + workspace.toggle_modal(cx, |_, cx| { + let prompt_editor = cx.add_view(|cx| { + Editor::auto_height( + 4, + Some(Arc::new(|theme| theme.search.editor.input.clone())), + cx, + ) + }); + cx.add_view(|_| RefactoringModal { + editor, + prompt_editor, + has_focus: false, + }) }); - cx.add_view(|_| RefactoringModal { - prompt_editor, - has_focus: false, - }) - }); + } + } + + fn confirm(&mut self, _: &Confirm, cx: &mut ViewContext) { + if let Some(editor) = self.editor.upgrade(cx) { + let prompt = self.prompt_editor.read(cx).text(cx); + cx.update_global(|assistant: &mut RefactoringAssistant, cx| { + assistant.refactor(&editor, &prompt, cx); + }); + } } } +fn words(text: &str) -> impl Iterator, &str)> { + let mut word_start_ix = None; + let mut chars = text.char_indices(); + iter::from_fn(move || { + while let Some((ix, ch)) = chars.next() { + if let Some(start_ix) = word_start_ix { + if !ch.is_alphanumeric() { + let word = &text[start_ix..ix]; + word_start_ix.take(); + return Some((start_ix..ix, word)); + } + } else { + if ch.is_alphanumeric() { + word_start_ix = Some(ix); + } + } + } + None + }) +} -// ABCDEFG -// XCDEFG -// -// +fn diff<'a>( + start_ix: usize, + old_text: &'a str, + new_text: &'a str, + old_buffer_snapshot: &MultiBufferSnapshot, +) -> Vec<(Range, &'a str)> { + let mut edit_start = start_ix; + let mut edits = Vec::new(); + let diff = similar::TextDiff::from_words(old_text, &new_text); + for change in diff.iter_all_changes() { + let value = change.value(); + let edit_end = edit_start + value.len(); + match change.tag() { + ChangeTag::Equal => { + edit_start = edit_end; + } + ChangeTag::Delete => { + edits.push(( + old_buffer_snapshot.anchor_after(edit_start) + ..old_buffer_snapshot.anchor_before(edit_end), + "", + )); + edit_start = edit_end; + } + ChangeTag::Insert => { + edits.push(( + old_buffer_snapshot.anchor_after(edit_start) + ..old_buffer_snapshot.anchor_after(edit_start), + value, + )); + } + } + } + edits +} diff --git a/prompt.md b/prompt.md deleted file mode 100644 index 33213a5859..0000000000 --- a/prompt.md +++ /dev/null @@ -1,11 +0,0 @@ -Given a snippet as the input, you must produce an array of edits. An edit has the following structure: - -{ skip: "skip", delete: "delete", insert: "insert" } - -`skip` is a string in the input that should be left unchanged. `delete` is a string in the input located right after the skipped text that should be deleted. `insert` is a new string that should be inserted after the end of the text in `skip`. It's crucial that a string in the input can only be skipped or deleted once and only once. - -Your task is to produce an array of edits. `delete` and `insert` can be empty if nothing changed. When `skip`, `delete` or `insert` are longer than 20 characters, split them into multiple edits. - -Check your reasoning by concatenating all the strings in `skip` and `delete`. If the text is the same as the input snippet then the edits are valid. - -It's crucial that you reply only with edits. No prose or remarks.