This commit is contained in:
Antonio Scandurra 2023-12-05 16:49:36 +01:00
parent 09db455db2
commit ede86d9187
16 changed files with 5405 additions and 41 deletions

View file

@ -0,0 +1,113 @@
pub mod assistant_panel;
mod assistant_settings;
mod codegen;
mod prompts;
mod streaming_diff;
use ai::providers::open_ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
use futures::StreamExt;
use gpui::AppContext;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
use util::paths::CONVERSATIONS_DIR;
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
sent_at: DateTime<Local>,
status: MessageStatus,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
Pending,
Done,
Error(Arc<str>),
}
#[derive(Serialize, Deserialize)]
struct SavedMessage {
id: MessageId,
start: usize,
}
#[derive(Serialize, Deserialize)]
struct SavedConversation {
id: Option<String>,
zed: String,
version: String,
text: String,
messages: Vec<SavedMessage>,
message_metadata: HashMap<MessageId, MessageMetadata>,
summary: String,
model: OpenAIModel,
}
impl SavedConversation {
const VERSION: &'static str = "0.1.0";
}
struct SavedConversationMetadata {
title: String,
path: PathBuf,
mtime: chrono::DateTime<chrono::Local>,
}
impl SavedConversationMetadata {
pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
fs.create_dir(&CONVERSATIONS_DIR).await?;
let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
let mut conversations = Vec::<SavedConversationMetadata>::new();
while let Some(path) = paths.next().await {
let path = path?;
if path.extension() != Some(OsStr::new("json")) {
continue;
}
let pattern = r" - \d+.zed.json$";
let re = Regex::new(pattern).unwrap();
let metadata = fs.metadata(&path).await?;
if let Some((file_name, metadata)) = path
.file_name()
.and_then(|name| name.to_str())
.zip(metadata)
{
let title = re.replace(file_name, "");
conversations.push(Self {
title: title.into_owned(),
path,
mtime: metadata.mtime.into(),
});
}
}
conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
Ok(conversations)
}
}
pub fn init(cx: &mut AppContext) {
assistant_panel::init(cx);
}
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,80 @@
use anyhow;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub enum OpenAIModel {
#[serde(rename = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo,
#[serde(rename = "gpt-4-0613")]
Four,
#[serde(rename = "gpt-4-1106-preview")]
FourTurbo,
}
impl OpenAIModel {
pub fn full_name(&self) -> &'static str {
match self {
OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
OpenAIModel::Four => "gpt-4-0613",
OpenAIModel::FourTurbo => "gpt-4-1106-preview",
}
}
pub fn short_name(&self) -> &'static str {
match self {
OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
OpenAIModel::Four => "gpt-4",
OpenAIModel::FourTurbo => "gpt-4-turbo",
}
}
pub fn cycle(&self) -> Self {
match self {
OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
OpenAIModel::Four => OpenAIModel::FourTurbo,
OpenAIModel::FourTurbo => OpenAIModel::ThreePointFiveTurbo,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
Left,
Right,
Bottom,
}
#[derive(Deserialize, Debug)]
pub struct AssistantSettings {
pub button: bool,
pub dock: AssistantDockPosition,
pub default_width: f32,
pub default_height: f32,
pub default_open_ai_model: OpenAIModel,
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContent {
pub button: Option<bool>,
pub dock: Option<AssistantDockPosition>,
pub default_width: Option<f32>,
pub default_height: Option<f32>,
pub default_open_ai_model: Option<OpenAIModel>,
}
impl Settings for AssistantSettings {
const KEY: Option<&'static str> = Some("assistant");
type FileContent = AssistantSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
_: &mut gpui::AppContext,
) -> anyhow::Result<Self> {
Self::load_via_json_merge(default_value, user_values)
}
}

View file

@ -0,0 +1,695 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{EventEmitter, Model, ModelContext, Task};
use language::{Rope, TransactionId};
use multi_buffer;
use std::{cmp, future, ops::Range, sync::Arc};
pub enum Event {
Finished,
Undone,
}
#[derive(Clone)]
pub enum CodegenKind {
Transform { range: Range<Anchor> },
Generate { position: Anchor },
}
pub struct Codegen {
provider: Arc<dyn CompletionProvider>,
buffer: Model<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
last_equal_ranges: Vec<Range<Anchor>>,
transaction_id: Option<TransactionId>,
error: Option<anyhow::Error>,
generation: Task<()>,
idle: bool,
_subscription: gpui::Subscription,
}
impl EventEmitter<Event> for Codegen {}
impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
kind: CodegenKind,
provider: Arc<dyn CompletionProvider>,
cx: &mut ModelContext<Self>,
) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
Self {
provider,
buffer: buffer.clone(),
snapshot,
kind,
last_equal_ranges: Default::default(),
transaction_id: Default::default(),
error: Default::default(),
idle: true,
generation: Task::ready(()),
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
}
}
fn handle_buffer_event(
&mut self,
_buffer: Model<MultiBuffer>,
event: &multi_buffer::Event,
cx: &mut ModelContext<Self>,
) {
if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
if self.transaction_id == Some(*transaction_id) {
self.transaction_id = None;
self.generation = Task::ready(());
cx.emit(Event::Undone);
}
}
}
pub fn range(&self) -> Range<Anchor> {
match &self.kind {
CodegenKind::Transform { range } => range.clone(),
CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
}
}
pub fn kind(&self) -> &CodegenKind {
&self.kind
}
pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
&self.last_equal_ranges
}
pub fn idle(&self) -> bool {
self.idle
}
pub fn error(&self) -> Option<&anyhow::Error> {
self.error.as_ref()
}
pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
.text_for_range(range.start..range.end)
.collect::<Rope>();
let selection_start = range.start.to_point(&snapshot);
let suggested_line_indent = snapshot
.suggested_indents(selection_start.row..selection_start.row + 1, cx)
.into_values()
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
let response = self.provider.complete(prompt);
self.generation = cx.spawn_weak(|this, mut cx| {
async move {
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
let diff = cx.background().spawn(async move {
let chunks = strip_invalid_spans_from_codeblock(response.await?);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
let chunk = chunk?;
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta = line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(selection_start.column as usize);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
if line_indent.is_some() {
hunks_tx.send(diff.push_new(&new_text)).await?;
new_text.clear();
}
if lines.peek().is_some() {
hunks_tx.send(diff.push_new("\n")).await?;
line_indent = None;
first_line = false;
}
}
}
hunks_tx.send(diff.push_new(&new_text)).await?;
hunks_tx.send(diff.finish()).await?;
anyhow::Ok(())
});
while let Some(hunks) = hunks_rx.next().await {
let this = if let Some(this) = this.upgrade(&cx) {
this
} else {
break;
};
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
let transaction = this.buffer.update(cx, |buffer, cx| {
// Avoid grouping assistant edits with user edits.
buffer.finalize_last_transaction(cx);
buffer.start_transaction(cx);
buffer.edit(
hunks.into_iter().filter_map(|hunk| match hunk {
Hunk::Insert { text } => {
let edit_start = snapshot.anchor_after(edit_start);
Some((edit_start..edit_start, text))
}
Hunk::Remove { len } => {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
Some((edit_range, String::new()))
}
Hunk::Keep { len } => {
let edit_end = edit_start + len;
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
this.last_equal_ranges.push(edit_range);
None
}
}),
None,
cx,
);
buffer.end_transaction(cx)
});
if let Some(transaction) = transaction {
if let Some(first_transaction) = this.transaction_id {
// Group all assistant edits into the first transaction.
this.buffer.update(cx, |buffer, cx| {
buffer.merge_transactions(
transaction,
first_transaction,
cx,
)
});
} else {
this.transaction_id = Some(transaction);
this.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
}
}
cx.notify();
});
}
diff.await?;
anyhow::Ok(())
};
let result = generate.await;
if let Some(this) = this.upgrade(&cx) {
this.update(&mut cx, |this, cx| {
this.last_equal_ranges.clear();
this.idle = true;
if let Err(error) = result {
this.error = Some(error);
}
cx.emit(Event::Finished);
cx.notify();
});
}
}
});
self.error.take();
self.idle = false;
cx.notify();
}
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
if let Some(transaction_id) = self.transaction_id {
self.buffer
.update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
}
}
}
fn strip_invalid_spans_from_codeblock(
stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
let mut first_line = true;
let mut buffer = String::new();
let mut starts_with_markdown_codeblock = false;
let mut includes_start_or_end_span = false;
stream.filter_map(move |chunk| {
let chunk = match chunk {
Ok(chunk) => chunk,
Err(err) => return future::ready(Some(Err(err))),
};
buffer.push_str(&chunk);
if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
includes_start_or_end_span = true;
buffer = buffer
.strip_prefix("<|S|>")
.or_else(|| buffer.strip_prefix("<|S|"))
.unwrap_or(&buffer)
.to_string();
} else if buffer.ends_with("|E|>") {
includes_start_or_end_span = true;
} else if buffer.starts_with("<|")
|| buffer.starts_with("<|S")
|| buffer.starts_with("<|S|")
|| buffer.ends_with("|")
|| buffer.ends_with("|E")
|| buffer.ends_with("|E|")
{
return future::ready(None);
}
if first_line {
if buffer == "" || buffer == "`" || buffer == "``" {
return future::ready(None);
} else if buffer.starts_with("```") {
starts_with_markdown_codeblock = true;
if let Some(newline_ix) = buffer.find('\n') {
buffer.replace_range(..newline_ix + 1, "");
first_line = false;
} else {
return future::ready(None);
}
}
}
let mut text = buffer.to_string();
if starts_with_markdown_codeblock {
text = text
.strip_suffix("\n```\n")
.or_else(|| text.strip_suffix("\n```"))
.or_else(|| text.strip_suffix("\n``"))
.or_else(|| text.strip_suffix("\n`"))
.or_else(|| text.strip_suffix('\n'))
.unwrap_or(&text)
.to_string();
}
if includes_start_or_end_span {
text = text
.strip_suffix("|E|>")
.or_else(|| text.strip_suffix("E|>"))
.or_else(|| text.strip_prefix("|>"))
.or_else(|| text.strip_prefix(">"))
.unwrap_or(&text)
.to_string();
};
if text.contains('\n') {
first_line = false;
}
let remainder = buffer.split_off(text.len());
let result = if buffer.is_empty() {
None
} else {
Some(Ok(buffer.clone()))
};
buffer = remainder;
future::ready(result)
})
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use ai::test::FakeCompletionProvider;
use futures::stream::{self};
use gpui::TestAppContext;
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use rand::prelude::*;
use serde::Serialize;
use settings::SettingsStore;
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
impl CompletionRequest for DummyCompletionRequest {
fn data(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
}
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
for _ in 0..10 {
x += 1;
}
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Transform { range },
provider.clone(),
cx,
)
});
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
println!("CHUNK: {:?}", &chunk);
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}
provider.finish_completion();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}
provider.finish_completion();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.set_global(cx.read(SettingsStore::test));
cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
let position = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
CodegenKind::Generate { position },
provider.clone(),
cx,
)
});
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
println!("{:?}", &chunk);
provider.send_completion(chunk);
new_text = suffix;
cx.background_executor.run_until_parked();
}
provider.finish_completion();
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() {
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum dolor"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks(
"```html\n```js\nLorem ipsum dolor\n```\n```",
2
))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"```js\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"``\nLorem ipsum dolor\n```"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum"
);
assert_eq!(
strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
.map(|chunk| chunk.unwrap())
.collect::<String>()
.await,
"Lorem ipsum"
);
fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
stream::iter(
text.chars()
.collect::<Vec<_>>()
.chunks(size)
.map(|chunk| Ok(chunk.iter().collect::<String>()))
.collect::<Vec<_>>(),
)
}
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}

View file

@ -0,0 +1,388 @@
use ai::models::LanguageModel;
use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai::prompts::file_context::FileContext;
use ai::prompts::generate::GenerateInlineContent;
use ai::prompts::preamble::EngineerPreamble;
use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
use ai::providers::open_ai::OpenAILanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::ops::Range;
use std::sync::Arc;
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
#[derive(Debug)]
struct Match {
collapse: Range<usize>,
keep: Vec<Range<usize>>,
}
let selected_range = selected_range.to_offset(buffer);
let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
Some(&grammar.embedding_config.as_ref()?.query)
});
let configs = ts_matches
.grammars()
.iter()
.map(|g| g.embedding_config.as_ref().unwrap())
.collect::<Vec<_>>();
let mut matches = Vec::new();
while let Some(mat) = ts_matches.peek() {
let config = &configs[mat.grammar_index];
if let Some(collapse) = mat.captures.iter().find_map(|cap| {
if Some(cap.index) == config.collapse_capture_ix {
Some(cap.node.byte_range())
} else {
None
}
}) {
let mut keep = Vec::new();
for capture in mat.captures.iter() {
if Some(capture.index) == config.keep_capture_ix {
keep.push(capture.node.byte_range());
} else {
continue;
}
}
ts_matches.advance();
matches.push(Match { collapse, keep });
} else {
ts_matches.advance();
}
}
matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
let mut matches = matches.into_iter().peekable();
let mut summary = String::new();
let mut offset = 0;
let mut flushed_selection = false;
while let Some(mat) = matches.next() {
// Keep extending the collapsed range if the next match surrounds
// the current one.
while let Some(next_mat) = matches.peek() {
if mat.collapse.start <= next_mat.collapse.start
&& mat.collapse.end >= next_mat.collapse.end
{
matches.next().unwrap();
} else {
break;
}
}
if offset > mat.collapse.start {
// Skip collapsed nodes that have already been summarized.
offset = cmp::max(offset, mat.collapse.end);
continue;
}
if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
if !flushed_selection {
// The collapsed node ends after the selection starts, so we'll flush the selection first.
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|S|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|E|>");
}
offset = selected_range.end;
flushed_selection = true;
}
// If the selection intersects the collapsed node, we won't collapse it.
if selected_range.end >= mat.collapse.start {
continue;
}
}
summary.extend(buffer.text_for_range(offset..mat.collapse.start));
for keep in mat.keep {
summary.extend(buffer.text_for_range(keep));
}
offset = mat.collapse.end;
}
// Flush selection if we haven't already done so.
if !flushed_selection && offset <= selected_range.start {
summary.extend(buffer.text_for_range(offset..selected_range.start));
summary.push_str("<|S|");
if selected_range.end == selected_range.start {
summary.push_str(">");
} else {
summary.extend(buffer.text_for_range(selected_range.clone()));
summary.push_str("|E|>");
}
offset = selected_range.end;
}
summary.extend(buffer.text_for_range(offset..buffer.len()));
summary
}
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
search_results: Vec<PromptCodeSnippet>,
model: &str,
project_name: Option<String>,
) -> anyhow::Result<String> {
// Using new Prompt Templates
let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
let lang_name = if let Some(language_name) = language_name {
Some(language_name.to_string())
} else {
None
};
let args = PromptArguments {
model: openai_model,
language_name: lang_name.clone(),
project_name,
snippets: search_results.clone(),
reserved_tokens: 1000,
buffer: Some(buffer),
selected_range: Some(range),
user_prompt: Some(user_prompt.clone()),
};
let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
(PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
(
PromptPriority::Ordered { order: 1 },
Box::new(RepositoryContext {}),
),
(
PromptPriority::Ordered { order: 0 },
Box::new(FileContext {}),
),
(
PromptPriority::Mandatory,
Box::new(GenerateInlineContent {}),
),
];
let chain = PromptChain::new(args, templates);
let (prompt, _) = chain.generate(true)?;
anyhow::Ok(prompt)
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use std::sync::Arc;
use gpui::AppContext;
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use settings::SettingsStore;
pub(crate) fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_embedding_query(
r#"
(
[(line_comment) (attribute_item)]* @context
.
[
(struct_item
name: (_) @name)
(enum_item
name: (_) @name)
(impl_item
trait: (_)? @name
"for"? @name
type: (_) @name)
(trait_item
name: (_) @name)
(function_item
name: (_) @name
body: (block
"{" @keep
"}" @keep) @collapse)
(macro_definition
name: (_) @name)
] @item
)
"#,
)
.unwrap()
}
#[gpui::test]
fn test_outline_for_prompt(cx: &mut AppContext) {
cx.set_global(SettingsStore::test(cx));
language_settings::init(cx);
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
self.a
}
pub fn b(&self) -> usize {
self.b
}
}
"};
let buffer =
cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
indoc! {"
struct X {
<|S|>a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let <|S|a |E|>= 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
<|S|>
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
assert_eq!(
summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
<|S|>"}
);
// Ensure nested functions get collapsed properly.
let text = indoc! {"
struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {
let a = 1;
let b = 2;
Self { a, b }
}
pub fn a(&self, param: bool) -> usize {
let a = 30;
fn nested() -> usize {
3
}
self.a + nested()
}
pub fn b(&self) -> usize {
self.b
}
}
"};
buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
let snapshot = buffer.read(cx).snapshot();
assert_eq!(
summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
indoc! {"
<|S|>struct X {
a: usize,
b: usize,
}
impl X {
fn new() -> Self {}
pub fn a(&self, param: bool) -> usize {}
pub fn b(&self) -> usize {}
}
"}
);
}
}

View file

@ -0,0 +1,293 @@
use collections::HashMap;
use ordered_float::OrderedFloat;
use std::{
cmp,
fmt::{self, Debug},
ops::Range,
};
struct Matrix {
cells: Vec<f64>,
rows: usize,
cols: usize,
}
impl Matrix {
fn new() -> Self {
Self {
cells: Vec::new(),
rows: 0,
cols: 0,
}
}
fn resize(&mut self, rows: usize, cols: usize) {
self.cells.resize(rows * cols, 0.);
self.rows = rows;
self.cols = cols;
}
fn get(&self, row: usize, col: usize) -> f64 {
if row >= self.rows {
panic!("row out of bounds")
}
if col >= self.cols {
panic!("col out of bounds")
}
self.cells[col * self.rows + row]
}
fn set(&mut self, row: usize, col: usize, value: f64) {
if row >= self.rows {
panic!("row out of bounds")
}
if col >= self.cols {
panic!("col out of bounds")
}
self.cells[col * self.rows + row] = value;
}
}
impl Debug for Matrix {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f)?;
for i in 0..self.rows {
for j in 0..self.cols {
write!(f, "{:5}", self.get(i, j))?;
}
writeln!(f)?;
}
Ok(())
}
}
#[derive(Debug)]
pub enum Hunk {
Insert { text: String },
Remove { len: usize },
Keep { len: usize },
}
pub struct StreamingDiff {
old: Vec<char>,
new: Vec<char>,
scores: Matrix,
old_text_ix: usize,
new_text_ix: usize,
equal_runs: HashMap<(usize, usize), u32>,
}
impl StreamingDiff {
const INSERTION_SCORE: f64 = -1.;
const DELETION_SCORE: f64 = -20.;
const EQUALITY_BASE: f64 = 1.8;
const MAX_EQUALITY_EXPONENT: i32 = 16;
pub fn new(old: String) -> Self {
let old = old.chars().collect::<Vec<_>>();
let mut scores = Matrix::new();
scores.resize(old.len() + 1, 1);
for i in 0..=old.len() {
scores.set(i, 0, i as f64 * Self::DELETION_SCORE);
}
Self {
old,
new: Vec::new(),
scores,
old_text_ix: 0,
new_text_ix: 0,
equal_runs: Default::default(),
}
}
pub fn push_new(&mut self, text: &str) -> Vec<Hunk> {
self.new.extend(text.chars());
self.scores.resize(self.old.len() + 1, self.new.len() + 1);
for j in self.new_text_ix + 1..=self.new.len() {
self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE);
for i in 1..=self.old.len() {
let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
let equality_score = if self.old[i - 1] == self.new[j - 1] {
let mut equal_run = self.equal_runs.get(&(i - 1, j - 1)).copied().unwrap_or(0);
equal_run += 1;
self.equal_runs.insert((i, j), equal_run);
let exponent = cmp::min(equal_run as i32 / 4, Self::MAX_EQUALITY_EXPONENT);
self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
} else {
f64::NEG_INFINITY
};
let score = insertion_score.max(deletion_score).max(equality_score);
self.scores.set(i, j, score);
}
}
let mut max_score = f64::NEG_INFINITY;
let mut next_old_text_ix = self.old_text_ix;
let next_new_text_ix = self.new.len();
for i in self.old_text_ix..=self.old.len() {
let score = self.scores.get(i, next_new_text_ix);
if score > max_score {
max_score = score;
next_old_text_ix = i;
}
}
let hunks = self.backtrack(next_old_text_ix, next_new_text_ix);
self.old_text_ix = next_old_text_ix;
self.new_text_ix = next_new_text_ix;
hunks
}
fn backtrack(&self, old_text_ix: usize, new_text_ix: usize) -> Vec<Hunk> {
let mut pending_insert: Option<Range<usize>> = None;
let mut hunks = Vec::new();
let mut i = old_text_ix;
let mut j = new_text_ix;
while (i, j) != (self.old_text_ix, self.new_text_ix) {
let insertion_score = if j > self.new_text_ix {
Some((i, j - 1))
} else {
None
};
let deletion_score = if i > self.old_text_ix {
Some((i - 1, j))
} else {
None
};
let equality_score = if i > self.old_text_ix && j > self.new_text_ix {
if self.old[i - 1] == self.new[j - 1] {
Some((i - 1, j - 1))
} else {
None
}
} else {
None
};
let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score]
.iter()
.max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j))))
.unwrap()
.unwrap();
if prev_i == i && prev_j == j - 1 {
if let Some(pending_insert) = pending_insert.as_mut() {
pending_insert.start = prev_j;
} else {
pending_insert = Some(prev_j..j);
}
} else {
if let Some(range) = pending_insert.take() {
hunks.push(Hunk::Insert {
text: self.new[range].iter().collect(),
});
}
let char_len = self.old[i - 1].len_utf8();
if prev_i == i - 1 && prev_j == j {
if let Some(Hunk::Remove { len }) = hunks.last_mut() {
*len += char_len;
} else {
hunks.push(Hunk::Remove { len: char_len })
}
} else {
if let Some(Hunk::Keep { len }) = hunks.last_mut() {
*len += char_len;
} else {
hunks.push(Hunk::Keep { len: char_len })
}
}
}
i = prev_i;
j = prev_j;
}
if let Some(range) = pending_insert.take() {
hunks.push(Hunk::Insert {
text: self.new[range].iter().collect(),
});
}
hunks.reverse();
hunks
}
pub fn finish(self) -> Vec<Hunk> {
self.backtrack(self.old.len(), self.new.len())
}
}
#[cfg(test)]
mod tests {
use std::env;
use super::*;
use rand::prelude::*;
#[gpui::test(iterations = 100)]
fn test_random_diffs(mut rng: StdRng) {
let old_text_len = env::var("OLD_TEXT_LEN")
.map(|i| i.parse().expect("invalid `OLD_TEXT_LEN` variable"))
.unwrap_or(10);
let new_text_len = env::var("NEW_TEXT_LEN")
.map(|i| i.parse().expect("invalid `NEW_TEXT_LEN` variable"))
.unwrap_or(10);
let old = util::RandomCharIter::new(&mut rng)
.take(old_text_len)
.collect::<String>();
log::info!("old text: {:?}", old);
let mut diff = StreamingDiff::new(old.clone());
let mut hunks = Vec::new();
let mut new_len = 0;
let mut new = String::new();
while new_len < new_text_len {
let new_chunk_len = rng.gen_range(1..=new_text_len - new_len);
let new_chunk = util::RandomCharIter::new(&mut rng)
.take(new_len)
.collect::<String>();
log::info!("new chunk: {:?}", new_chunk);
new_len += new_chunk_len;
new.push_str(&new_chunk);
let new_hunks = diff.push_new(&new_chunk);
log::info!("hunks: {:?}", new_hunks);
hunks.extend(new_hunks);
}
let final_hunks = diff.finish();
log::info!("final hunks: {:?}", final_hunks);
hunks.extend(final_hunks);
log::info!("new text: {:?}", new);
let mut old_ix = 0;
let mut new_ix = 0;
let mut patched = String::new();
for hunk in hunks {
match hunk {
Hunk::Keep { len } => {
assert_eq!(&old[old_ix..old_ix + len], &new[new_ix..new_ix + len]);
patched.push_str(&old[old_ix..old_ix + len]);
old_ix += len;
new_ix += len;
}
Hunk::Remove { len } => {
old_ix += len;
}
Hunk::Insert { text } => {
assert_eq!(text, &new[new_ix..new_ix + text.len()]);
patched.push_str(&text);
new_ix += text.len();
}
}
}
assert_eq!(patched, new);
}
}