add retrieve context button to inline assistant

KCaverly 2023-10-03 11:19:54 +03:00
parent e9637267ef
commit bfe76467b0
4 changed files with 131 additions and 93 deletions

Cargo.lock (generated, 21 changed lines)

@@ -108,7 +108,7 @@ dependencies = [
  "rusqlite",
  "serde",
  "serde_json",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "util",
 ]
@@ -327,7 +327,7 @@ dependencies = [
  "settings",
  "smol",
  "theme",
- "tiktoken-rs 0.4.5",
+ "tiktoken-rs",
  "util",
  "uuid 1.4.1",
  "workspace",
@@ -6798,7 +6798,7 @@ dependencies = [
  "smol",
  "tempdir",
  "theme",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
  "tree-sitter",
  "tree-sitter-cpp",
  "tree-sitter-elixir",
@@ -7875,21 +7875,6 @@ dependencies = [
  "weezl",
 ]
 
-[[package]]
-name = "tiktoken-rs"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
-dependencies = [
- "anyhow",
- "base64 0.21.4",
- "bstr",
- "fancy-regex",
- "lazy_static",
- "parking_lot 0.12.1",
- "rustc-hash",
-]
-
 [[package]]
 name = "tiktoken-rs"
 version = "0.5.4"

Cargo.toml

@@ -38,7 +38,7 @@ schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 smol.workspace = true
-tiktoken-rs = "0.4"
+tiktoken-rs = "0.5"
 
 [dev-dependencies]
 editor = { path = "../editor", features = ["test-support"] }
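
Note on the dependency bump: tiktoken-rs 0.5 changes the shape of ChatCompletionRequestMessage, making content an Option<String> and adding a function_call field, which is why the message construction in the assistant code below now wraps the text in Some(..) and passes function_call: None. A minimal token-counting sketch against the 0.5 API (the count_tokens helper is illustrative, not part of this commit):

use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};

// Counts prompt tokens for a single user message, the same call the
// assistant uses below to size its prompts.
fn count_tokens(model: &str, text: String) -> usize {
    let messages = [ChatCompletionRequestMessage {
        role: "user".to_string(),
        content: Some(text), // Option<String> as of tiktoken-rs 0.5
        name: None,
        function_call: None, // field added in 0.5
    }];
    num_tokens_from_messages(model, &messages).unwrap_or(0)
}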


@@ -437,8 +437,15 @@ impl AssistantPanel {
             InlineAssistantEvent::Confirmed {
                 prompt,
                 include_conversation,
+                retrieve_context,
             } => {
-                self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
+                self.confirm_inline_assist(
+                    assist_id,
+                    prompt,
+                    *include_conversation,
+                    cx,
+                    *retrieve_context,
+                );
             }
             InlineAssistantEvent::Canceled => {
                 self.finish_inline_assist(assist_id, true, cx);
@@ -532,6 +539,7 @@ impl AssistantPanel {
         user_prompt: &str,
         include_conversation: bool,
         cx: &mut ViewContext<Self>,
+        retrieve_context: bool,
     ) {
         let conversation = if include_conversation {
             self.active_editor()
@@ -593,42 +601,49 @@ impl AssistantPanel {
         let codegen_kind = codegen.read(cx).kind().clone();
         let user_prompt = user_prompt.to_string();
 
-        let project = if let Some(workspace) = self.workspace.upgrade(cx) {
-            workspace.read(cx).project()
-        } else {
-            return;
-        };
-
-        let project = project.to_owned();
-        let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
-            let search_results = semantic_index.update(cx, |this, cx| {
-                this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
-            });
-
-            cx.background()
-                .spawn(async move { search_results.await.unwrap_or_default() })
-        } else {
-            Task::ready(Vec::new())
-        };
-
-        let snippets = cx.spawn(|_, cx| async move {
-            let mut snippets = Vec::new();
-            for result in search_results.await {
-                snippets.push(result.buffer.read_with(&cx, |buffer, _| {
-                    buffer
-                        .snapshot()
-                        .text_for_range(result.range)
-                        .collect::<String>()
-                }));
-            }
-            snippets
-        });
+        let snippets = if retrieve_context {
+            let project = if let Some(workspace) = self.workspace.upgrade(cx) {
+                workspace.read(cx).project()
+            } else {
+                return;
+            };
+
+            let project = project.to_owned();
+            let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
+                let search_results = semantic_index.update(cx, |this, cx| {
+                    this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
+                });
+
+                cx.background()
+                    .spawn(async move { search_results.await.unwrap_or_default() })
+            } else {
+                Task::ready(Vec::new())
+            };
+
+            let snippets = cx.spawn(|_, cx| async move {
+                let mut snippets = Vec::new();
+                for result in search_results.await {
+                    snippets.push(result.buffer.read_with(&cx, |buffer, _| {
+                        buffer
+                            .snapshot()
+                            .text_for_range(result.range)
+                            .collect::<String>()
+                    }));
+                }
+                snippets
+            });
+            snippets
+        } else {
+            Task::ready(Vec::new())
+        };
+
+        let mut model = settings::get::<AssistantSettings>(cx)
+            .default_open_ai_model
+            .clone();
+        let model_name = model.full_name();
 
         let prompt = cx.background().spawn(async move {
             let snippets = snippets.await;
 
+            for snippet in &snippets {
+                println!("SNIPPET: \n{:?}", snippet);
+            }
+
             let language_name = language_name.as_deref();
             generate_content_prompt(
@@ -638,13 +653,11 @@ impl AssistantPanel {
                 range,
                 codegen_kind,
                 snippets,
+                model_name,
             )
         });
 
         let mut messages = Vec::new();
-        let mut model = settings::get::<AssistantSettings>(cx)
-            .default_open_ai_model
-            .clone();
         if let Some(conversation) = conversation {
             let conversation = conversation.read(cx);
             let buffer = conversation.buffer.read(cx);
@@ -1557,12 +1570,14 @@ impl Conversation {
                         Role::Assistant => "assistant".into(),
                         Role::System => "system".into(),
                     },
-                    content: self
-                        .buffer
-                        .read(cx)
-                        .text_for_range(message.offset_range)
-                        .collect(),
+                    content: Some(
+                        self.buffer
+                            .read(cx)
+                            .text_for_range(message.offset_range)
+                            .collect(),
+                    ),
                     name: None,
+                    function_call: None,
                 })
             })
             .collect::<Vec<_>>();
@@ -2681,6 +2696,7 @@ enum InlineAssistantEvent {
     Confirmed {
         prompt: String,
         include_conversation: bool,
+        retrieve_context: bool,
     },
     Canceled,
     Dismissed,
@@ -2922,6 +2938,7 @@ impl InlineAssistant {
         cx.emit(InlineAssistantEvent::Confirmed {
             prompt,
             include_conversation: self.include_conversation,
+            retrieve_context: self.retrieve_context,
         });
         self.confirmed = true;
         cx.notify();
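
The plumbing above can be summarized as: the inline assistant emits retrieve_context with its Confirmed event, the panel threads the flag into confirm_inline_assist, and the semantic index is only queried when the flag is set. A self-contained sketch of that flow with simplified, illustrative types (not Zed's real gpui event or task machinery):

enum InlineAssistantEvent {
    Confirmed {
        prompt: String,
        include_conversation: bool,
        retrieve_context: bool,
    },
}

fn handle(event: InlineAssistantEvent) {
    match event {
        InlineAssistantEvent::Confirmed {
            prompt,
            include_conversation,
            retrieve_context,
        } => confirm_inline_assist(&prompt, include_conversation, retrieve_context),
    }
}

fn confirm_inline_assist(prompt: &str, _include_conversation: bool, retrieve_context: bool) {
    // Only gather project context when the user asked for it; otherwise the
    // snippet list stays empty, mirroring Task::ready(Vec::new()) above.
    let snippets: Vec<String> = if retrieve_context {
        vec![format!("snippet relevant to: {prompt}")] // stand-in for a semantic search
    } else {
        Vec::new()
    };
    println!("{} snippet(s) retrieved", snippets.len());
}

fn main() {
    handle(InlineAssistantEvent::Confirmed {
        prompt: "add a retry loop".to_string(),
        include_conversation: false,
        retrieve_context: true,
    });
}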


@@ -1,8 +1,10 @@
 use crate::codegen::CodegenKind;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp;
+use std::fmt::Write;
+use std::iter;
 use std::ops::Range;
-use std::{fmt::Write, iter};
+use tiktoken_rs::ChatCompletionRequestMessage;
 
 fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
     #[derive(Debug)]
@@ -122,69 +124,103 @@ pub fn generate_content_prompt(
     range: Range<impl ToOffset>,
     kind: CodegenKind,
     search_results: Vec<String>,
+    model: &str,
 ) -> String {
-    let mut prompt = String::new();
+    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+
+    let mut prompts = Vec::new();
 
     // General Preamble
     if let Some(language_name) = language_name {
-        writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
+        prompts.push(format!("You're an expert {language_name} engineer.\n"));
     } else {
-        writeln!(prompt, "You're an expert engineer.\n").unwrap();
+        prompts.push("You're an expert engineer.\n".to_string());
     }
 
+    // Snippets
+    let mut snippet_position = prompts.len() - 1;
+
     let outline = summarize(buffer, range);
-    writeln!(
-        prompt,
-        "The file you are currently working on has the following outline:"
-    )
-    .unwrap();
+    prompts.push("The file you are currently working on has the following outline:".to_string());
     if let Some(language_name) = language_name {
         let language_name = language_name.to_lowercase();
-        writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
+        prompts.push(format!("```{language_name}\n{outline}\n```"));
     } else {
-        writeln!(prompt, "```\n{outline}\n```").unwrap();
+        prompts.push(format!("```\n{outline}\n```"));
     }
 
     match kind {
         CodegenKind::Generate { position: _ } => {
-            writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
-            writeln!(
-                prompt,
-                "Assume the cursor is located where the `<|START|` marker is."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
-                "Text can't be replaced, so assume your answer will be inserted at the cursor."
-            )
-            .unwrap();
-            writeln!(
-                prompt,
-                "Generate text based on the users prompt: {user_prompt}"
-            )
-            .unwrap();
+            prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
+            prompts
+                .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
+            prompts.push(
+                "Text can't be replaced, so assume your answer will be inserted at the cursor."
+                    .to_string(),
+            );
+            prompts.push(format!(
+                "Generate text based on the users prompt: {user_prompt}"
+            ));
         }
         CodegenKind::Transform { range: _ } => {
-            writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
-            writeln!(
-                prompt,
-                "Modify the users code selected text based upon the users prompt: {user_prompt}"
-            )
-            .unwrap();
-            writeln!(
-                prompt,
-                "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
-            )
-            .unwrap();
+            prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
+            prompts.push(format!(
+                "Modify the users code selected text based upon the users prompt: {user_prompt}"
+            ));
+            prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
         }
     }
 
     if let Some(language_name) = language_name {
-        writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
+        prompts.push(format!("Your answer MUST always be valid {language_name}"));
     }
-    writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
-    writeln!(prompt, "Never make remarks about the output.").unwrap();
+    prompts.push("Always wrap your response in a Markdown codeblock".to_string());
+    prompts.push("Never make remarks about the output.".to_string());
+
+    let current_messages = [ChatCompletionRequestMessage {
+        role: "user".to_string(),
+        content: Some(prompts.join("\n")),
+        function_call: None,
+        name: None,
+    }];
+
+    let remaining_token_count = if let Ok(current_token_count) =
+        tiktoken_rs::num_tokens_from_messages(model, &current_messages)
+    {
+        let max_token_count = tiktoken_rs::model::get_context_size(model);
+        max_token_count - current_token_count
+    } else {
+        // If tiktoken fails to count token count, assume we have no space remaining.
+        0
+    };
+
+    // TODO:
+    //      - add repository name to snippet
+    //      - add file path
+    //      - add language
+    if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
+        let template = "You are working inside a large repository, here are a few code snippets that may be useful";
+
+        for search_result in search_results {
+            let mut snippet_prompt = template.to_string();
+            writeln!(snippet_prompt, "```\n{search_result}\n```").unwrap();
+
+            let token_count = encoding
+                .encode_with_special_tokens(snippet_prompt.as_str())
+                .len();
+            if token_count <= remaining_token_count {
+                if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+                    prompts.insert(snippet_position, snippet_prompt);
+                    snippet_position += 1;
+                }
+            } else {
+                break;
+            }
+        }
+    }
+
+    let prompt = prompts.join("\n");
+
+    println!("PROMPT: {:?}", prompt);
 
     prompt
 }
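
The new prompt assembly budgets retrieved snippets against the model's context window: it measures the base prompt with num_tokens_from_messages, derives the remaining space from get_context_size, and then inserts each snippet only if its BPE token count fits. A simplified sketch of that budgeting loop, assuming tiktoken-rs 0.5 (unlike the loop above, this version also subtracts each accepted snippet from the remaining budget):

// Returns the snippets that still fit after the base prompt has used
// `base_prompt_tokens` of the model's context window.
fn snippets_that_fit(model: &str, base_prompt_tokens: usize, snippets: &[String]) -> Vec<String> {
    const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
    let mut remaining =
        tiktoken_rs::model::get_context_size(model).saturating_sub(base_prompt_tokens);
    let mut kept = Vec::new();
    if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
        for snippet in snippets {
            let tokens = bpe.encode_with_special_tokens(snippet.as_str()).len();
            if tokens > remaining {
                break; // no room left in the context window
            }
            if tokens < MAXIMUM_SNIPPET_TOKEN_COUNT {
                kept.push(snippet.clone());
                remaining -= tokens;
            }
        }
    }
    kept
}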