From 937aabfdfdd435807368068f6e47f7d03981919c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Aug 2023 11:24:55 +0200 Subject: [PATCH] Extract a `strip_markdown_codeblock` function --- crates/ai/src/assistant.rs | 197 +++++++++++++++++++++---------------- 1 file changed, 110 insertions(+), 87 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index f7abcdf748..0333a723e9 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -16,7 +16,7 @@ use editor::{ Anchor, Editor, MultiBufferSnapshot, ToOffset, ToPoint, }; use fs::Fs; -use futures::{channel::mpsc, SinkExt, StreamExt}; +use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{ actions, elements::*, @@ -620,7 +620,10 @@ impl AssistantPanel { let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); let diff = cx.background().spawn(async move { - let mut messages = response.await?; + let chunks = strip_markdown_codeblock(response.await?.filter_map( + |message| async move { message.ok()?.choices.pop()?.delta.content }, + )); + futures::pin_mut!(chunks); let mut diff = StreamingDiff::new(selected_text.to_string()); let indentation_len; @@ -636,93 +639,21 @@ impl AssistantPanel { indentation_text = ""; }; - let mut inside_first_line = true; - let mut starts_with_fenced_code_block = None; - let mut has_pending_newline = false; - let mut new_text = String::new(); + let mut new_text = indentation_text + .repeat(indentation_len.saturating_sub(selection_start.column) as usize); - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(mut choice) = message.choices.pop() { - if has_pending_newline { - has_pending_newline = false; - choice - .delta - .content - .get_or_insert(String::new()) - .insert(0, '\n'); - } + while let Some(message) = chunks.next().await { + let mut lines = message.split('\n'); + if let Some(first_line) = lines.next() { + new_text.push_str(first_line); + } - // Buffer a trailing codeblock fence. Note that we don't stop - // right away because this may be an inner fence that we need - // to insert into the editor. - if starts_with_fenced_code_block.is_some() - && choice.delta.content.as_deref() == Some("\n```") - { - new_text.push_str("\n```"); - continue; - } - - // If this was the last completion and we started with a codeblock - // fence and we ended with another codeblock fence, then we can - // stop right away. Otherwise, whatever text we buffered will be - // processed normally. - if choice.finish_reason.is_some() - && starts_with_fenced_code_block.unwrap_or(false) - && new_text == "\n```" - { - break; - } - - if let Some(text) = choice.delta.content { - // Never push a newline if there's nothing after it. This is - // useful to detect if the newline was pushed because of a - // trailing codeblock fence. - let text = if let Some(prefix) = text.strip_suffix('\n') { - has_pending_newline = true; - prefix - } else { - text.as_str() - }; - - if text.is_empty() { - continue; - } - - let mut lines = text.split('\n'); - if let Some(line) = lines.next() { - if starts_with_fenced_code_block.is_none() { - starts_with_fenced_code_block = - Some(line.starts_with("```")); - } - - // Avoid pushing the first line if it's the start of a fenced code block. - if !inside_first_line || !starts_with_fenced_code_block.unwrap() - { - new_text.push_str(&line); - } - } - - for line in lines { - if inside_first_line && starts_with_fenced_code_block.unwrap() { - // If we were inside the first line and that line was the - // start of a fenced code block, we just need to push the - // leading indentation of the original selection. - new_text.push_str(&indentation_text.repeat( - indentation_len.saturating_sub(selection_start.column) - as usize, - )); - } else { - // Otherwise, we need to push a newline and the base indentation. - new_text.push('\n'); - new_text.push_str( - &indentation_text.repeat(indentation_len as usize), - ); - } - - new_text.push_str(line); - inside_first_line = false; - } + for line in lines { + new_text.push('\n'); + if !line.is_empty() { + new_text + .push_str(&indentation_text.repeat(indentation_len as usize)); + new_text.push_str(line); } } @@ -2919,10 +2850,58 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { } } +fn strip_markdown_codeblock(stream: impl Stream) -> impl Stream { + let mut first_line = true; + let mut buffer = String::new(); + let mut starts_with_fenced_code_block = false; + stream.filter_map(move |chunk| { + buffer.push_str(&chunk); + + if first_line { + if buffer == "" || buffer == "`" || buffer == "``" { + return futures::future::ready(None); + } else if buffer.starts_with("```") { + starts_with_fenced_code_block = true; + if let Some(newline_ix) = buffer.find('\n') { + buffer.replace_range(..newline_ix + 1, ""); + first_line = false; + } else { + return futures::future::ready(None); + } + } + } + + let text = if starts_with_fenced_code_block { + buffer + .strip_suffix("\n```") + .or_else(|| buffer.strip_suffix("\n``")) + .or_else(|| buffer.strip_suffix("\n`")) + .or_else(|| buffer.strip_suffix('\n')) + .unwrap_or(&buffer) + } else { + &buffer + }; + + if text.contains('\n') { + first_line = false; + } + + let remainder = buffer.split_off(text.len()); + let result = if buffer.is_empty() { + None + } else { + Some(buffer.clone()) + }; + buffer = remainder; + futures::future::ready(result) + }) +} + #[cfg(test)] mod tests { use super::*; use crate::MessageId; + use futures::stream; use gpui::AppContext; #[gpui::test] @@ -3291,6 +3270,50 @@ mod tests { ); } + #[gpui::test] + async fn test_strip_markdown_codeblock() { + assert_eq!( + strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) + .collect::() + .await, + "Lorem ipsum dolor" + ); + assert_eq!( + strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) + .collect::() + .await, + "```js\nLorem ipsum dolor\n```" + ); + assert_eq!( + strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) + .collect::() + .await, + "``\nLorem ipsum dolor\n```" + ); + + fn chunks(text: &str, size: usize) -> impl Stream { + stream::iter( + text.chars() + .collect::>() + .chunks(size) + .map(|chunk| chunk.iter().collect::()) + .collect::>(), + ) + } + } + fn messages( conversation: &ModelHandle, cx: &AppContext,