diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 0466259b24..f62c91fcb7 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -118,7 +118,7 @@ impl Codegen { let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); let diff = cx.background().spawn(async move { - let chunks = strip_markdown_codeblock(response.await?); + let chunks = strip_invalid_spans_from_codeblock(response.await?); futures::pin_mut!(chunks); let mut diff = StreamingDiff::new(selected_text.to_string()); @@ -279,12 +279,13 @@ impl Codegen { } } -fn strip_markdown_codeblock( +fn strip_invalid_spans_from_codeblock( stream: impl Stream>, ) -> impl Stream> { let mut first_line = true; let mut buffer = String::new(); - let mut starts_with_fenced_code_block = false; + let mut starts_with_markdown_codeblock = false; + let mut includes_start_or_end_span = false; stream.filter_map(move |chunk| { let chunk = match chunk { Ok(chunk) => chunk, @@ -292,11 +293,31 @@ fn strip_markdown_codeblock( }; buffer.push_str(&chunk); + if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") { + includes_start_or_end_span = true; + + buffer = buffer + .strip_prefix("<|S|>") + .or_else(|| buffer.strip_prefix("<|S|")) + .unwrap_or(&buffer) + .to_string(); + } else if buffer.ends_with("|E|>") { + includes_start_or_end_span = true; + } else if buffer.starts_with("<|") + || buffer.starts_with("<|S") + || buffer.starts_with("<|S|") + || buffer.ends_with("|") + || buffer.ends_with("|E") + || buffer.ends_with("|E|") + { + return future::ready(None); + } + if first_line { if buffer == "" || buffer == "`" || buffer == "``" { return future::ready(None); } else if buffer.starts_with("```") { - starts_with_fenced_code_block = true; + starts_with_markdown_codeblock = true; if let Some(newline_ix) = buffer.find('\n') { buffer.replace_range(..newline_ix + 1, ""); first_line = false; @@ -306,16 +327,26 @@ fn strip_markdown_codeblock( } } - let text = if starts_with_fenced_code_block { - buffer + let mut text = buffer.to_string(); + if starts_with_markdown_codeblock { + text = text .strip_suffix("\n```\n") - .or_else(|| buffer.strip_suffix("\n```")) - .or_else(|| buffer.strip_suffix("\n``")) - .or_else(|| buffer.strip_suffix("\n`")) - .or_else(|| buffer.strip_suffix('\n')) - .unwrap_or(&buffer) - } else { - &buffer + .or_else(|| text.strip_suffix("\n```")) + .or_else(|| text.strip_suffix("\n``")) + .or_else(|| text.strip_suffix("\n`")) + .or_else(|| text.strip_suffix('\n')) + .unwrap_or(&text) + .to_string(); + } + + if includes_start_or_end_span { + text = text + .strip_suffix("|E|>") + .or_else(|| text.strip_suffix("E|>")) + .or_else(|| text.strip_prefix("|>")) + .or_else(|| text.strip_prefix(">")) + .unwrap_or(&text) + .to_string(); }; if text.contains('\n') { @@ -328,6 +359,7 @@ fn strip_markdown_codeblock( } else { Some(Ok(buffer.clone())) }; + buffer = remainder; future::ready(result) }) @@ -558,50 +590,82 @@ mod tests { } #[gpui::test] - async fn test_strip_markdown_codeblock() { + async fn test_strip_invalid_spans_from_codeblock() { assert_eq!( - strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2)) + strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "Lorem ipsum dolor" ); assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2)) + strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "Lorem ipsum dolor" ); assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) + strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "Lorem ipsum dolor" ); assert_eq!( - strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) + strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "Lorem ipsum dolor" ); assert_eq!( - strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2)) - .map(|chunk| chunk.unwrap()) - .collect::() - .await, + strip_invalid_spans_from_codeblock(chunks( + "```html\n```js\nLorem ipsum dolor\n```\n```", + 2 + )) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, "```js\nLorem ipsum dolor\n```" ); assert_eq!( - strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) + strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2)) .map(|chunk| chunk.unwrap()) .collect::() .await, "``\nLorem ipsum dolor\n```" ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum" + ); + + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum" + ); + assert_eq!( + strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2)) + .map(|chunk| chunk.unwrap()) + .collect::() + .await, + "Lorem ipsum" + ); fn chunks(text: &str, size: usize) -> impl Stream> { stream::iter( text.chars() diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 25af023c40..b678c6fe3b 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -80,12 +80,12 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> S if !flushed_selection { // The collapsed node ends after the selection starts, so we'll flush the selection first. summary.extend(buffer.text_for_range(offset..selected_range.start)); - summary.push_str("<|START|"); + summary.push_str("<|S|"); if selected_range.end == selected_range.start { summary.push_str(">"); } else { summary.extend(buffer.text_for_range(selected_range.clone())); - summary.push_str("|END|>"); + summary.push_str("|E|>"); } offset = selected_range.end; flushed_selection = true; @@ -107,12 +107,12 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> S // Flush selection if we haven't already done so. if !flushed_selection && offset <= selected_range.start { summary.extend(buffer.text_for_range(offset..selected_range.start)); - summary.push_str("<|START|"); + summary.push_str("<|S|"); if selected_range.end == selected_range.start { summary.push_str(">"); } else { summary.extend(buffer.text_for_range(selected_range.clone())); - summary.push_str("|END|>"); + summary.push_str("|E|>"); } offset = selected_range.end; } @@ -260,7 +260,7 @@ pub(crate) mod tests { summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)), indoc! {" struct X { - <|START|>a: usize, + <|S|>a: usize, b: usize, } @@ -286,7 +286,7 @@ pub(crate) mod tests { impl X { fn new() -> Self { - let <|START|a |END|>= 1; + let <|S|a |E|>= 1; let b = 2; Self { a, b } } @@ -307,7 +307,7 @@ pub(crate) mod tests { } impl X { - <|START|> + <|S|> fn new() -> Self {} pub fn a(&self, param: bool) -> usize {} @@ -333,7 +333,7 @@ pub(crate) mod tests { pub fn b(&self) -> usize {} } - <|START|>"} + <|S|>"} ); // Ensure nested functions get collapsed properly. @@ -369,7 +369,7 @@ pub(crate) mod tests { assert_eq!( summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)), indoc! {" - <|START|>struct X { + <|S|>struct X { a: usize, b: usize, }