WIP
This commit is contained in:
parent
09db455db2
commit
ede86d9187
16 changed files with 5405 additions and 41 deletions
113
crates/assistant2/src/assistant.rs
Normal file
113
crates/assistant2/src/assistant.rs
Normal file
|
@ -0,0 +1,113 @@
|
|||
pub mod assistant_panel;
|
||||
mod assistant_settings;
|
||||
mod codegen;
|
||||
mod prompts;
|
||||
mod streaming_diff;
|
||||
|
||||
use ai::providers::open_ai::Role;
|
||||
use anyhow::Result;
|
||||
pub use assistant_panel::AssistantPanel;
|
||||
use assistant_settings::OpenAIModel;
|
||||
use chrono::{DateTime, Local};
|
||||
use collections::HashMap;
|
||||
use fs::Fs;
|
||||
use futures::StreamExt;
|
||||
use gpui::AppContext;
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
|
||||
use util::paths::CONVERSATIONS_DIR;
|
||||
|
||||
#[derive(
|
||||
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
|
||||
)]
|
||||
struct MessageId(usize);
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
struct MessageMetadata {
|
||||
role: Role,
|
||||
sent_at: DateTime<Local>,
|
||||
status: MessageStatus,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
enum MessageStatus {
|
||||
Pending,
|
||||
Done,
|
||||
Error(Arc<str>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SavedMessage {
|
||||
id: MessageId,
|
||||
start: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SavedConversation {
|
||||
id: Option<String>,
|
||||
zed: String,
|
||||
version: String,
|
||||
text: String,
|
||||
messages: Vec<SavedMessage>,
|
||||
message_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
summary: String,
|
||||
model: OpenAIModel,
|
||||
}
|
||||
|
||||
impl SavedConversation {
|
||||
const VERSION: &'static str = "0.1.0";
|
||||
}
|
||||
|
||||
struct SavedConversationMetadata {
|
||||
title: String,
|
||||
path: PathBuf,
|
||||
mtime: chrono::DateTime<chrono::Local>,
|
||||
}
|
||||
|
||||
impl SavedConversationMetadata {
|
||||
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
|
||||
fs.create_dir(&CONVERSATIONS_DIR).await?;
|
||||
|
||||
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
|
||||
let mut conversations = Vec::<SavedConversationMetadata>::new();
|
||||
while let Some(path) = paths.next().await {
|
||||
let path = path?;
|
||||
if path.extension() != Some(OsStr::new("json")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pattern = r" - \d+.zed.json$";
|
||||
let re = Regex::new(pattern).unwrap();
|
||||
|
||||
let metadata = fs.metadata(&path).await?;
|
||||
if let Some((file_name, metadata)) = path
|
||||
.file_name()
|
||||
.and_then(|name| name.to_str())
|
||||
.zip(metadata)
|
||||
{
|
||||
let title = re.replace(file_name, "");
|
||||
conversations.push(Self {
|
||||
title: title.into_owned(),
|
||||
path,
|
||||
mtime: metadata.mtime.into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
|
||||
|
||||
Ok(conversations)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut AppContext) {
|
||||
assistant_panel::init(cx);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
3660
crates/assistant2/src/assistant_panel.rs
Normal file
3660
crates/assistant2/src/assistant_panel.rs
Normal file
File diff suppressed because it is too large
Load diff
80
crates/assistant2/src/assistant_settings.rs
Normal file
80
crates/assistant2/src/assistant_settings.rs
Normal file
|
@ -0,0 +1,80 @@
|
|||
use anyhow;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
pub enum OpenAIModel {
|
||||
#[serde(rename = "gpt-3.5-turbo-0613")]
|
||||
ThreePointFiveTurbo,
|
||||
#[serde(rename = "gpt-4-0613")]
|
||||
Four,
|
||||
#[serde(rename = "gpt-4-1106-preview")]
|
||||
FourTurbo,
|
||||
}
|
||||
|
||||
impl OpenAIModel {
|
||||
pub fn full_name(&self) -> &'static str {
|
||||
match self {
|
||||
OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
|
||||
OpenAIModel::Four => "gpt-4-0613",
|
||||
OpenAIModel::FourTurbo => "gpt-4-1106-preview",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn short_name(&self) -> &'static str {
|
||||
match self {
|
||||
OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
|
||||
OpenAIModel::Four => "gpt-4",
|
||||
OpenAIModel::FourTurbo => "gpt-4-turbo",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cycle(&self) -> Self {
|
||||
match self {
|
||||
OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
|
||||
OpenAIModel::Four => OpenAIModel::FourTurbo,
|
||||
OpenAIModel::FourTurbo => OpenAIModel::ThreePointFiveTurbo,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AssistantDockPosition {
|
||||
Left,
|
||||
Right,
|
||||
Bottom,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct AssistantSettings {
|
||||
pub button: bool,
|
||||
pub dock: AssistantDockPosition,
|
||||
pub default_width: f32,
|
||||
pub default_height: f32,
|
||||
pub default_open_ai_model: OpenAIModel,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct AssistantSettingsContent {
|
||||
pub button: Option<bool>,
|
||||
pub dock: Option<AssistantDockPosition>,
|
||||
pub default_width: Option<f32>,
|
||||
pub default_height: Option<f32>,
|
||||
pub default_open_ai_model: Option<OpenAIModel>,
|
||||
}
|
||||
|
||||
impl Settings for AssistantSettings {
|
||||
const KEY: Option<&'static str> = Some("assistant");
|
||||
|
||||
type FileContent = AssistantSettingsContent;
|
||||
|
||||
fn load(
|
||||
default_value: &Self::FileContent,
|
||||
user_values: &[&Self::FileContent],
|
||||
_: &mut gpui::AppContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
Self::load_via_json_merge(default_value, user_values)
|
||||
}
|
||||
}
|
695
crates/assistant2/src/codegen.rs
Normal file
695
crates/assistant2/src/codegen.rs
Normal file
|
@ -0,0 +1,695 @@
|
|||
use crate::streaming_diff::{Hunk, StreamingDiff};
|
||||
use ai::completion::{CompletionProvider, CompletionRequest};
|
||||
use anyhow::Result;
|
||||
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
|
||||
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
|
||||
use gpui::{EventEmitter, Model, ModelContext, Task};
|
||||
use language::{Rope, TransactionId};
|
||||
use multi_buffer;
|
||||
use std::{cmp, future, ops::Range, sync::Arc};
|
||||
|
||||
pub enum Event {
|
||||
Finished,
|
||||
Undone,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum CodegenKind {
|
||||
Transform { range: Range<Anchor> },
|
||||
Generate { position: Anchor },
|
||||
}
|
||||
|
||||
pub struct Codegen {
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
buffer: Model<MultiBuffer>,
|
||||
snapshot: MultiBufferSnapshot,
|
||||
kind: CodegenKind,
|
||||
last_equal_ranges: Vec<Range<Anchor>>,
|
||||
transaction_id: Option<TransactionId>,
|
||||
error: Option<anyhow::Error>,
|
||||
generation: Task<()>,
|
||||
idle: bool,
|
||||
_subscription: gpui::Subscription,
|
||||
}
|
||||
|
||||
impl EventEmitter<Event> for Codegen {}
|
||||
|
||||
impl Codegen {
|
||||
pub fn new(
|
||||
buffer: Model<MultiBuffer>,
|
||||
kind: CodegenKind,
|
||||
provider: Arc<dyn CompletionProvider>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
let snapshot = buffer.read(cx).snapshot(cx);
|
||||
Self {
|
||||
provider,
|
||||
buffer: buffer.clone(),
|
||||
snapshot,
|
||||
kind,
|
||||
last_equal_ranges: Default::default(),
|
||||
transaction_id: Default::default(),
|
||||
error: Default::default(),
|
||||
idle: true,
|
||||
generation: Task::ready(()),
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_buffer_event(
|
||||
&mut self,
|
||||
_buffer: Model<MultiBuffer>,
|
||||
event: &multi_buffer::Event,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
|
||||
if self.transaction_id == Some(*transaction_id) {
|
||||
self.transaction_id = None;
|
||||
self.generation = Task::ready(());
|
||||
cx.emit(Event::Undone);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn range(&self) -> Range<Anchor> {
|
||||
match &self.kind {
|
||||
CodegenKind::Transform { range } => range.clone(),
|
||||
CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn kind(&self) -> &CodegenKind {
|
||||
&self.kind
|
||||
}
|
||||
|
||||
pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
|
||||
&self.last_equal_ranges
|
||||
}
|
||||
|
||||
pub fn idle(&self) -> bool {
|
||||
self.idle
|
||||
}
|
||||
|
||||
pub fn error(&self) -> Option<&anyhow::Error> {
|
||||
self.error.as_ref()
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
|
||||
let range = self.range();
|
||||
let snapshot = self.snapshot.clone();
|
||||
let selected_text = snapshot
|
||||
.text_for_range(range.start..range.end)
|
||||
.collect::<Rope>();
|
||||
|
||||
let selection_start = range.start.to_point(&snapshot);
|
||||
let suggested_line_indent = snapshot
|
||||
.suggested_indents(selection_start.row..selection_start.row + 1, cx)
|
||||
.into_values()
|
||||
.next()
|
||||
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
|
||||
|
||||
let response = self.provider.complete(prompt);
|
||||
self.generation = cx.spawn_weak(|this, mut cx| {
|
||||
async move {
|
||||
let generate = async {
|
||||
let mut edit_start = range.start.to_offset(&snapshot);
|
||||
|
||||
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
|
||||
let diff = cx.background().spawn(async move {
|
||||
let chunks = strip_invalid_spans_from_codeblock(response.await?);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
|
||||
let mut new_text = String::new();
|
||||
let mut base_indent = None;
|
||||
let mut line_indent = None;
|
||||
let mut first_line = true;
|
||||
|
||||
while let Some(chunk) = chunks.next().await {
|
||||
let chunk = chunk?;
|
||||
|
||||
let mut lines = chunk.split('\n').peekable();
|
||||
while let Some(line) = lines.next() {
|
||||
new_text.push_str(line);
|
||||
if line_indent.is_none() {
|
||||
if let Some(non_whitespace_ch_ix) =
|
||||
new_text.find(|ch: char| !ch.is_whitespace())
|
||||
{
|
||||
line_indent = Some(non_whitespace_ch_ix);
|
||||
base_indent = base_indent.or(line_indent);
|
||||
|
||||
let line_indent = line_indent.unwrap();
|
||||
let base_indent = base_indent.unwrap();
|
||||
let indent_delta = line_indent as i32 - base_indent as i32;
|
||||
let mut corrected_indent_len = cmp::max(
|
||||
0,
|
||||
suggested_line_indent.len as i32 + indent_delta,
|
||||
)
|
||||
as usize;
|
||||
if first_line {
|
||||
corrected_indent_len = corrected_indent_len
|
||||
.saturating_sub(selection_start.column as usize);
|
||||
}
|
||||
|
||||
let indent_char = suggested_line_indent.char();
|
||||
let mut indent_buffer = [0; 4];
|
||||
let indent_str =
|
||||
indent_char.encode_utf8(&mut indent_buffer);
|
||||
new_text.replace_range(
|
||||
..line_indent,
|
||||
&indent_str.repeat(corrected_indent_len),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if line_indent.is_some() {
|
||||
hunks_tx.send(diff.push_new(&new_text)).await?;
|
||||
new_text.clear();
|
||||
}
|
||||
|
||||
if lines.peek().is_some() {
|
||||
hunks_tx.send(diff.push_new("\n")).await?;
|
||||
line_indent = None;
|
||||
first_line = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
hunks_tx.send(diff.push_new(&new_text)).await?;
|
||||
hunks_tx.send(diff.finish()).await?;
|
||||
|
||||
anyhow::Ok(())
|
||||
});
|
||||
|
||||
while let Some(hunks) = hunks_rx.next().await {
|
||||
let this = if let Some(this) = this.upgrade(&cx) {
|
||||
this
|
||||
} else {
|
||||
break;
|
||||
};
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.last_equal_ranges.clear();
|
||||
|
||||
let transaction = this.buffer.update(cx, |buffer, cx| {
|
||||
// Avoid grouping assistant edits with user edits.
|
||||
buffer.finalize_last_transaction(cx);
|
||||
|
||||
buffer.start_transaction(cx);
|
||||
buffer.edit(
|
||||
hunks.into_iter().filter_map(|hunk| match hunk {
|
||||
Hunk::Insert { text } => {
|
||||
let edit_start = snapshot.anchor_after(edit_start);
|
||||
Some((edit_start..edit_start, text))
|
||||
}
|
||||
Hunk::Remove { len } => {
|
||||
let edit_end = edit_start + len;
|
||||
let edit_range = snapshot.anchor_after(edit_start)
|
||||
..snapshot.anchor_before(edit_end);
|
||||
edit_start = edit_end;
|
||||
Some((edit_range, String::new()))
|
||||
}
|
||||
Hunk::Keep { len } => {
|
||||
let edit_end = edit_start + len;
|
||||
let edit_range = snapshot.anchor_after(edit_start)
|
||||
..snapshot.anchor_before(edit_end);
|
||||
edit_start = edit_end;
|
||||
this.last_equal_ranges.push(edit_range);
|
||||
None
|
||||
}
|
||||
}),
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
|
||||
buffer.end_transaction(cx)
|
||||
});
|
||||
|
||||
if let Some(transaction) = transaction {
|
||||
if let Some(first_transaction) = this.transaction_id {
|
||||
// Group all assistant edits into the first transaction.
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
buffer.merge_transactions(
|
||||
transaction,
|
||||
first_transaction,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
} else {
|
||||
this.transaction_id = Some(transaction);
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
buffer.finalize_last_transaction(cx)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
|
||||
diff.await?;
|
||||
anyhow::Ok(())
|
||||
};
|
||||
|
||||
let result = generate.await;
|
||||
if let Some(this) = this.upgrade(&cx) {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.last_equal_ranges.clear();
|
||||
this.idle = true;
|
||||
if let Err(error) = result {
|
||||
this.error = Some(error);
|
||||
}
|
||||
cx.emit(Event::Finished);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
self.error.take();
|
||||
self.idle = false;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
|
||||
if let Some(transaction_id) = self.transaction_id {
|
||||
self.buffer
|
||||
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn strip_invalid_spans_from_codeblock(
|
||||
stream: impl Stream<Item = Result<String>>,
|
||||
) -> impl Stream<Item = Result<String>> {
|
||||
let mut first_line = true;
|
||||
let mut buffer = String::new();
|
||||
let mut starts_with_markdown_codeblock = false;
|
||||
let mut includes_start_or_end_span = false;
|
||||
stream.filter_map(move |chunk| {
|
||||
let chunk = match chunk {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => return future::ready(Some(Err(err))),
|
||||
};
|
||||
buffer.push_str(&chunk);
|
||||
|
||||
if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
|
||||
includes_start_or_end_span = true;
|
||||
|
||||
buffer = buffer
|
||||
.strip_prefix("<|S|>")
|
||||
.or_else(|| buffer.strip_prefix("<|S|"))
|
||||
.unwrap_or(&buffer)
|
||||
.to_string();
|
||||
} else if buffer.ends_with("|E|>") {
|
||||
includes_start_or_end_span = true;
|
||||
} else if buffer.starts_with("<|")
|
||||
|| buffer.starts_with("<|S")
|
||||
|| buffer.starts_with("<|S|")
|
||||
|| buffer.ends_with("|")
|
||||
|| buffer.ends_with("|E")
|
||||
|| buffer.ends_with("|E|")
|
||||
{
|
||||
return future::ready(None);
|
||||
}
|
||||
|
||||
if first_line {
|
||||
if buffer == "" || buffer == "`" || buffer == "``" {
|
||||
return future::ready(None);
|
||||
} else if buffer.starts_with("```") {
|
||||
starts_with_markdown_codeblock = true;
|
||||
if let Some(newline_ix) = buffer.find('\n') {
|
||||
buffer.replace_range(..newline_ix + 1, "");
|
||||
first_line = false;
|
||||
} else {
|
||||
return future::ready(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut text = buffer.to_string();
|
||||
if starts_with_markdown_codeblock {
|
||||
text = text
|
||||
.strip_suffix("\n```\n")
|
||||
.or_else(|| text.strip_suffix("\n```"))
|
||||
.or_else(|| text.strip_suffix("\n``"))
|
||||
.or_else(|| text.strip_suffix("\n`"))
|
||||
.or_else(|| text.strip_suffix('\n'))
|
||||
.unwrap_or(&text)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
if includes_start_or_end_span {
|
||||
text = text
|
||||
.strip_suffix("|E|>")
|
||||
.or_else(|| text.strip_suffix("E|>"))
|
||||
.or_else(|| text.strip_prefix("|>"))
|
||||
.or_else(|| text.strip_prefix(">"))
|
||||
.unwrap_or(&text)
|
||||
.to_string();
|
||||
};
|
||||
|
||||
if text.contains('\n') {
|
||||
first_line = false;
|
||||
}
|
||||
|
||||
let remainder = buffer.split_off(text.len());
|
||||
let result = if buffer.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Ok(buffer.clone()))
|
||||
};
|
||||
|
||||
buffer = remainder;
|
||||
future::ready(result)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use ai::test::FakeCompletionProvider;
|
||||
use futures::stream::{self};
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||
use rand::prelude::*;
|
||||
use serde::Serialize;
|
||||
use settings::SettingsStore;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct DummyCompletionRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl CompletionRequest for DummyCompletionRequest {
|
||||
fn data(&self) -> serde_json::Result<String> {
|
||||
serde_json::to_string(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
let x = 0;
|
||||
for _ in 0..10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let range = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Transform { range },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
" let mut x = 0;\n",
|
||||
" while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
" }",
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
println!("CHUNK: {:?}", &chunk);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_when_generating_past_indentation(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = indoc! {"
|
||||
fn main() {
|
||||
le
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let position = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 6))
|
||||
});
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
"t mut x = 0;\n",
|
||||
"while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
"}", //
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_autoindent_when_generating_before_indentation(
|
||||
cx: &mut TestAppContext,
|
||||
mut rng: StdRng,
|
||||
) {
|
||||
cx.set_global(cx.read(SettingsStore::test));
|
||||
cx.update(language_settings::init);
|
||||
|
||||
let text = concat!(
|
||||
"fn main() {\n",
|
||||
" \n",
|
||||
"}\n" //
|
||||
);
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let position = buffer.read_with(cx, |buffer, cx| {
|
||||
let snapshot = buffer.snapshot(cx);
|
||||
snapshot.anchor_before(Point::new(1, 2))
|
||||
});
|
||||
let provider = Arc::new(FakeCompletionProvider::new());
|
||||
let codegen = cx.add_model(|cx| {
|
||||
Codegen::new(
|
||||
buffer.clone(),
|
||||
CodegenKind::Generate { position },
|
||||
provider.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let request = Box::new(DummyCompletionRequest {
|
||||
name: "test".to_string(),
|
||||
});
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
|
||||
let mut new_text = concat!(
|
||||
"let mut x = 0;\n",
|
||||
"while x < 10 {\n",
|
||||
" x += 1;\n",
|
||||
"}", //
|
||||
);
|
||||
while !new_text.is_empty() {
|
||||
let max_len = cmp::min(new_text.len(), 10);
|
||||
let len = rng.gen_range(1..=max_len);
|
||||
let (chunk, suffix) = new_text.split_at(len);
|
||||
println!("{:?}", &chunk);
|
||||
provider.send_completion(chunk);
|
||||
new_text = suffix;
|
||||
cx.background_executor.run_until_parked();
|
||||
}
|
||||
provider.finish_completion();
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
|
||||
indoc! {"
|
||||
fn main() {
|
||||
let mut x = 0;
|
||||
while x < 10 {
|
||||
x += 1;
|
||||
}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_strip_invalid_spans_from_codeblock() {
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks(
|
||||
"```html\n```js\nLorem ipsum dolor\n```\n```",
|
||||
2
|
||||
))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"```js\nLorem ipsum dolor\n```"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"``\nLorem ipsum dolor\n```"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
|
||||
.map(|chunk| chunk.unwrap())
|
||||
.collect::<String>()
|
||||
.await,
|
||||
"Lorem ipsum"
|
||||
);
|
||||
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
|
||||
stream::iter(
|
||||
text.chars()
|
||||
.collect::<Vec<_>>()
|
||||
.chunks(size)
|
||||
.map(|chunk| Ok(chunk.iter().collect::<String>()))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::language()),
|
||||
)
|
||||
.with_indents_query(
|
||||
r#"
|
||||
(call_expression) @indent
|
||||
(field_expression) @indent
|
||||
(_ "(" ")" @end) @indent
|
||||
(_ "{" "}" @end) @indent
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
388
crates/assistant2/src/prompts.rs
Normal file
388
crates/assistant2/src/prompts.rs
Normal file
|
@ -0,0 +1,388 @@
|
|||
use ai::models::LanguageModel;
|
||||
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
|
||||
use ai::prompts::file_context::FileContext;
|
||||
use ai::prompts::generate::GenerateInlineContent;
|
||||
use ai::prompts::preamble::EngineerPreamble;
|
||||
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
|
||||
use ai::providers::open_ai::OpenAILanguageModel;
|
||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||
use std::cmp::{self, Reverse};
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
|
||||
#[derive(Debug)]
|
||||
struct Match {
|
||||
collapse: Range<usize>,
|
||||
keep: Vec<Range<usize>>,
|
||||
}
|
||||
|
||||
let selected_range = selected_range.to_offset(buffer);
|
||||
let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
|
||||
Some(&grammar.embedding_config.as_ref()?.query)
|
||||
});
|
||||
let configs = ts_matches
|
||||
.grammars()
|
||||
.iter()
|
||||
.map(|g| g.embedding_config.as_ref().unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
let mut matches = Vec::new();
|
||||
while let Some(mat) = ts_matches.peek() {
|
||||
let config = &configs[mat.grammar_index];
|
||||
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
|
||||
if Some(cap.index) == config.collapse_capture_ix {
|
||||
Some(cap.node.byte_range())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}) {
|
||||
let mut keep = Vec::new();
|
||||
for capture in mat.captures.iter() {
|
||||
if Some(capture.index) == config.keep_capture_ix {
|
||||
keep.push(capture.node.byte_range());
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
ts_matches.advance();
|
||||
matches.push(Match { collapse, keep });
|
||||
} else {
|
||||
ts_matches.advance();
|
||||
}
|
||||
}
|
||||
matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
|
||||
let mut matches = matches.into_iter().peekable();
|
||||
|
||||
let mut summary = String::new();
|
||||
let mut offset = 0;
|
||||
let mut flushed_selection = false;
|
||||
while let Some(mat) = matches.next() {
|
||||
// Keep extending the collapsed range if the next match surrounds
|
||||
// the current one.
|
||||
while let Some(next_mat) = matches.peek() {
|
||||
if mat.collapse.start <= next_mat.collapse.start
|
||||
&& mat.collapse.end >= next_mat.collapse.end
|
||||
{
|
||||
matches.next().unwrap();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if offset > mat.collapse.start {
|
||||
// Skip collapsed nodes that have already been summarized.
|
||||
offset = cmp::max(offset, mat.collapse.end);
|
||||
continue;
|
||||
}
|
||||
|
||||
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
|
||||
if !flushed_selection {
|
||||
// The collapsed node ends after the selection starts, so we'll flush the selection first.
|
||||
summary.extend(buffer.text_for_range(offset..selected_range.start));
|
||||
summary.push_str("<|S|");
|
||||
if selected_range.end == selected_range.start {
|
||||
summary.push_str(">");
|
||||
} else {
|
||||
summary.extend(buffer.text_for_range(selected_range.clone()));
|
||||
summary.push_str("|E|>");
|
||||
}
|
||||
offset = selected_range.end;
|
||||
flushed_selection = true;
|
||||
}
|
||||
|
||||
// If the selection intersects the collapsed node, we won't collapse it.
|
||||
if selected_range.end >= mat.collapse.start {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
|
||||
for keep in mat.keep {
|
||||
summary.extend(buffer.text_for_range(keep));
|
||||
}
|
||||
offset = mat.collapse.end;
|
||||
}
|
||||
|
||||
// Flush selection if we haven't already done so.
|
||||
if !flushed_selection && offset <= selected_range.start {
|
||||
summary.extend(buffer.text_for_range(offset..selected_range.start));
|
||||
summary.push_str("<|S|");
|
||||
if selected_range.end == selected_range.start {
|
||||
summary.push_str(">");
|
||||
} else {
|
||||
summary.extend(buffer.text_for_range(selected_range.clone()));
|
||||
summary.push_str("|E|>");
|
||||
}
|
||||
offset = selected_range.end;
|
||||
}
|
||||
|
||||
summary.extend(buffer.text_for_range(offset..buffer.len()));
|
||||
summary
|
||||
}
|
||||
|
||||
pub fn generate_content_prompt(
|
||||
user_prompt: String,
|
||||
language_name: Option<&str>,
|
||||
buffer: BufferSnapshot,
|
||||
range: Range<usize>,
|
||||
search_results: Vec<PromptCodeSnippet>,
|
||||
model: &str,
|
||||
project_name: Option<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
// Using new Prompt Templates
|
||||
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
|
||||
let lang_name = if let Some(language_name) = language_name {
|
||||
Some(language_name.to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let args = PromptArguments {
|
||||
model: openai_model,
|
||||
language_name: lang_name.clone(),
|
||||
project_name,
|
||||
snippets: search_results.clone(),
|
||||
reserved_tokens: 1000,
|
||||
buffer: Some(buffer),
|
||||
selected_range: Some(range),
|
||||
user_prompt: Some(user_prompt.clone()),
|
||||
};
|
||||
|
||||
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
|
||||
(PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
|
||||
(
|
||||
PromptPriority::Ordered { order: 1 },
|
||||
Box::new(RepositoryContext {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Ordered { order: 0 },
|
||||
Box::new(FileContext {}),
|
||||
),
|
||||
(
|
||||
PromptPriority::Mandatory,
|
||||
Box::new(GenerateInlineContent {}),
|
||||
),
|
||||
];
|
||||
let chain = PromptChain::new(args, templates);
|
||||
let (prompt, _) = chain.generate(true)?;
|
||||
|
||||
anyhow::Ok(prompt)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
use gpui::AppContext;
|
||||
use indoc::indoc;
|
||||
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
|
||||
use settings::SettingsStore;
|
||||
|
||||
pub(crate) fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::language()),
|
||||
)
|
||||
.with_embedding_query(
|
||||
r#"
|
||||
(
|
||||
[(line_comment) (attribute_item)]* @context
|
||||
.
|
||||
[
|
||||
(struct_item
|
||||
name: (_) @name)
|
||||
|
||||
(enum_item
|
||||
name: (_) @name)
|
||||
|
||||
(impl_item
|
||||
trait: (_)? @name
|
||||
"for"? @name
|
||||
type: (_) @name)
|
||||
|
||||
(trait_item
|
||||
name: (_) @name)
|
||||
|
||||
(function_item
|
||||
name: (_) @name
|
||||
body: (block
|
||||
"{" @keep
|
||||
"}" @keep) @collapse)
|
||||
|
||||
(macro_definition
|
||||
name: (_) @name)
|
||||
] @item
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_outline_for_prompt(cx: &mut AppContext) {
|
||||
cx.set_global(SettingsStore::test(cx));
|
||||
language_settings::init(cx);
|
||||
let text = indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let a = 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {
|
||||
self.a
|
||||
}
|
||||
|
||||
pub fn b(&self) -> usize {
|
||||
self.b
|
||||
}
|
||||
}
|
||||
"};
|
||||
let buffer =
|
||||
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
<|S|>a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let <|S|a |E|>= 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
<|S|>
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
|
||||
indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
<|S|>"}
|
||||
);
|
||||
|
||||
// Ensure nested functions get collapsed properly.
|
||||
let text = indoc! {"
|
||||
struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {
|
||||
let a = 1;
|
||||
let b = 2;
|
||||
Self { a, b }
|
||||
}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {
|
||||
let a = 30;
|
||||
fn nested() -> usize {
|
||||
3
|
||||
}
|
||||
self.a + nested()
|
||||
}
|
||||
|
||||
pub fn b(&self) -> usize {
|
||||
self.b
|
||||
}
|
||||
}
|
||||
"};
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
assert_eq!(
|
||||
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
|
||||
indoc! {"
|
||||
<|S|>struct X {
|
||||
a: usize,
|
||||
b: usize,
|
||||
}
|
||||
|
||||
impl X {
|
||||
|
||||
fn new() -> Self {}
|
||||
|
||||
pub fn a(&self, param: bool) -> usize {}
|
||||
|
||||
pub fn b(&self) -> usize {}
|
||||
}
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
293
crates/assistant2/src/streaming_diff.rs
Normal file
293
crates/assistant2/src/streaming_diff.rs
Normal file
|
@ -0,0 +1,293 @@
|
|||
use collections::HashMap;
|
||||
use ordered_float::OrderedFloat;
|
||||
use std::{
|
||||
cmp,
|
||||
fmt::{self, Debug},
|
||||
ops::Range,
|
||||
};
|
||||
|
||||
struct Matrix {
|
||||
cells: Vec<f64>,
|
||||
rows: usize,
|
||||
cols: usize,
|
||||
}
|
||||
|
||||
impl Matrix {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
cells: Vec::new(),
|
||||
rows: 0,
|
||||
cols: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn resize(&mut self, rows: usize, cols: usize) {
|
||||
self.cells.resize(rows * cols, 0.);
|
||||
self.rows = rows;
|
||||
self.cols = cols;
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> f64 {
|
||||
if row >= self.rows {
|
||||
panic!("row out of bounds")
|
||||
}
|
||||
|
||||
if col >= self.cols {
|
||||
panic!("col out of bounds")
|
||||
}
|
||||
self.cells[col * self.rows + row]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, value: f64) {
|
||||
if row >= self.rows {
|
||||
panic!("row out of bounds")
|
||||
}
|
||||
|
||||
if col >= self.cols {
|
||||
panic!("col out of bounds")
|
||||
}
|
||||
|
||||
self.cells[col * self.rows + row] = value;
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Matrix {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
writeln!(f)?;
|
||||
for i in 0..self.rows {
|
||||
for j in 0..self.cols {
|
||||
write!(f, "{:5}", self.get(i, j))?;
|
||||
}
|
||||
writeln!(f)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Hunk {
|
||||
Insert { text: String },
|
||||
Remove { len: usize },
|
||||
Keep { len: usize },
|
||||
}
|
||||
|
||||
pub struct StreamingDiff {
|
||||
old: Vec<char>,
|
||||
new: Vec<char>,
|
||||
scores: Matrix,
|
||||
old_text_ix: usize,
|
||||
new_text_ix: usize,
|
||||
equal_runs: HashMap<(usize, usize), u32>,
|
||||
}
|
||||
|
||||
impl StreamingDiff {
|
||||
const INSERTION_SCORE: f64 = -1.;
|
||||
const DELETION_SCORE: f64 = -20.;
|
||||
const EQUALITY_BASE: f64 = 1.8;
|
||||
const MAX_EQUALITY_EXPONENT: i32 = 16;
|
||||
|
||||
pub fn new(old: String) -> Self {
|
||||
let old = old.chars().collect::<Vec<_>>();
|
||||
let mut scores = Matrix::new();
|
||||
scores.resize(old.len() + 1, 1);
|
||||
for i in 0..=old.len() {
|
||||
scores.set(i, 0, i as f64 * Self::DELETION_SCORE);
|
||||
}
|
||||
Self {
|
||||
old,
|
||||
new: Vec::new(),
|
||||
scores,
|
||||
old_text_ix: 0,
|
||||
new_text_ix: 0,
|
||||
equal_runs: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_new(&mut self, text: &str) -> Vec<Hunk> {
|
||||
self.new.extend(text.chars());
|
||||
self.scores.resize(self.old.len() + 1, self.new.len() + 1);
|
||||
|
||||
for j in self.new_text_ix + 1..=self.new.len() {
|
||||
self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE);
|
||||
for i in 1..=self.old.len() {
|
||||
let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
|
||||
let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
|
||||
let equality_score = if self.old[i - 1] == self.new[j - 1] {
|
||||
let mut equal_run = self.equal_runs.get(&(i - 1, j - 1)).copied().unwrap_or(0);
|
||||
equal_run += 1;
|
||||
self.equal_runs.insert((i, j), equal_run);
|
||||
|
||||
let exponent = cmp::min(equal_run as i32 / 4, Self::MAX_EQUALITY_EXPONENT);
|
||||
self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
|
||||
} else {
|
||||
f64::NEG_INFINITY
|
||||
};
|
||||
|
||||
let score = insertion_score.max(deletion_score).max(equality_score);
|
||||
self.scores.set(i, j, score);
|
||||
}
|
||||
}
|
||||
|
||||
let mut max_score = f64::NEG_INFINITY;
|
||||
let mut next_old_text_ix = self.old_text_ix;
|
||||
let next_new_text_ix = self.new.len();
|
||||
for i in self.old_text_ix..=self.old.len() {
|
||||
let score = self.scores.get(i, next_new_text_ix);
|
||||
if score > max_score {
|
||||
max_score = score;
|
||||
next_old_text_ix = i;
|
||||
}
|
||||
}
|
||||
|
||||
let hunks = self.backtrack(next_old_text_ix, next_new_text_ix);
|
||||
self.old_text_ix = next_old_text_ix;
|
||||
self.new_text_ix = next_new_text_ix;
|
||||
hunks
|
||||
}
|
||||
|
||||
fn backtrack(&self, old_text_ix: usize, new_text_ix: usize) -> Vec<Hunk> {
|
||||
let mut pending_insert: Option<Range<usize>> = None;
|
||||
let mut hunks = Vec::new();
|
||||
let mut i = old_text_ix;
|
||||
let mut j = new_text_ix;
|
||||
while (i, j) != (self.old_text_ix, self.new_text_ix) {
|
||||
let insertion_score = if j > self.new_text_ix {
|
||||
Some((i, j - 1))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let deletion_score = if i > self.old_text_ix {
|
||||
Some((i - 1, j))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let equality_score = if i > self.old_text_ix && j > self.new_text_ix {
|
||||
if self.old[i - 1] == self.new[j - 1] {
|
||||
Some((i - 1, j - 1))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score]
|
||||
.iter()
|
||||
.max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j))))
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
if prev_i == i && prev_j == j - 1 {
|
||||
if let Some(pending_insert) = pending_insert.as_mut() {
|
||||
pending_insert.start = prev_j;
|
||||
} else {
|
||||
pending_insert = Some(prev_j..j);
|
||||
}
|
||||
} else {
|
||||
if let Some(range) = pending_insert.take() {
|
||||
hunks.push(Hunk::Insert {
|
||||
text: self.new[range].iter().collect(),
|
||||
});
|
||||
}
|
||||
|
||||
let char_len = self.old[i - 1].len_utf8();
|
||||
if prev_i == i - 1 && prev_j == j {
|
||||
if let Some(Hunk::Remove { len }) = hunks.last_mut() {
|
||||
*len += char_len;
|
||||
} else {
|
||||
hunks.push(Hunk::Remove { len: char_len })
|
||||
}
|
||||
} else {
|
||||
if let Some(Hunk::Keep { len }) = hunks.last_mut() {
|
||||
*len += char_len;
|
||||
} else {
|
||||
hunks.push(Hunk::Keep { len: char_len })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i = prev_i;
|
||||
j = prev_j;
|
||||
}
|
||||
|
||||
if let Some(range) = pending_insert.take() {
|
||||
hunks.push(Hunk::Insert {
|
||||
text: self.new[range].iter().collect(),
|
||||
});
|
||||
}
|
||||
|
||||
hunks.reverse();
|
||||
hunks
|
||||
}
|
||||
|
||||
pub fn finish(self) -> Vec<Hunk> {
|
||||
self.backtrack(self.old.len(), self.new.len())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::env;
|
||||
|
||||
use super::*;
|
||||
use rand::prelude::*;
|
||||
|
||||
#[gpui::test(iterations = 100)]
|
||||
fn test_random_diffs(mut rng: StdRng) {
|
||||
let old_text_len = env::var("OLD_TEXT_LEN")
|
||||
.map(|i| i.parse().expect("invalid `OLD_TEXT_LEN` variable"))
|
||||
.unwrap_or(10);
|
||||
let new_text_len = env::var("NEW_TEXT_LEN")
|
||||
.map(|i| i.parse().expect("invalid `NEW_TEXT_LEN` variable"))
|
||||
.unwrap_or(10);
|
||||
|
||||
let old = util::RandomCharIter::new(&mut rng)
|
||||
.take(old_text_len)
|
||||
.collect::<String>();
|
||||
log::info!("old text: {:?}", old);
|
||||
|
||||
let mut diff = StreamingDiff::new(old.clone());
|
||||
let mut hunks = Vec::new();
|
||||
let mut new_len = 0;
|
||||
let mut new = String::new();
|
||||
while new_len < new_text_len {
|
||||
let new_chunk_len = rng.gen_range(1..=new_text_len - new_len);
|
||||
let new_chunk = util::RandomCharIter::new(&mut rng)
|
||||
.take(new_len)
|
||||
.collect::<String>();
|
||||
log::info!("new chunk: {:?}", new_chunk);
|
||||
new_len += new_chunk_len;
|
||||
new.push_str(&new_chunk);
|
||||
let new_hunks = diff.push_new(&new_chunk);
|
||||
log::info!("hunks: {:?}", new_hunks);
|
||||
hunks.extend(new_hunks);
|
||||
}
|
||||
let final_hunks = diff.finish();
|
||||
log::info!("final hunks: {:?}", final_hunks);
|
||||
hunks.extend(final_hunks);
|
||||
|
||||
log::info!("new text: {:?}", new);
|
||||
let mut old_ix = 0;
|
||||
let mut new_ix = 0;
|
||||
let mut patched = String::new();
|
||||
for hunk in hunks {
|
||||
match hunk {
|
||||
Hunk::Keep { len } => {
|
||||
assert_eq!(&old[old_ix..old_ix + len], &new[new_ix..new_ix + len]);
|
||||
patched.push_str(&old[old_ix..old_ix + len]);
|
||||
old_ix += len;
|
||||
new_ix += len;
|
||||
}
|
||||
Hunk::Remove { len } => {
|
||||
old_ix += len;
|
||||
}
|
||||
Hunk::Insert { text } => {
|
||||
assert_eq!(text, &new[new_ix..new_ix + text.len()]);
|
||||
patched.push_str(&text);
|
||||
new_ix += text.len();
|
||||
}
|
||||
}
|
||||
}
|
||||
assert_eq!(patched, new);
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue