add retrieve context button to inline assistant
This commit is contained in:
parent
e9637267ef
commit
bfe76467b0
4 changed files with 131 additions and 93 deletions
21
Cargo.lock
generated
21
Cargo.lock
generated
|
@ -108,7 +108,7 @@ dependencies = [
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tiktoken-rs 0.5.4",
|
"tiktoken-rs",
|
||||||
"util",
|
"util",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -327,7 +327,7 @@ dependencies = [
|
||||||
"settings",
|
"settings",
|
||||||
"smol",
|
"smol",
|
||||||
"theme",
|
"theme",
|
||||||
"tiktoken-rs 0.4.5",
|
"tiktoken-rs",
|
||||||
"util",
|
"util",
|
||||||
"uuid 1.4.1",
|
"uuid 1.4.1",
|
||||||
"workspace",
|
"workspace",
|
||||||
|
@ -6798,7 +6798,7 @@ dependencies = [
|
||||||
"smol",
|
"smol",
|
||||||
"tempdir",
|
"tempdir",
|
||||||
"theme",
|
"theme",
|
||||||
"tiktoken-rs 0.5.4",
|
"tiktoken-rs",
|
||||||
"tree-sitter",
|
"tree-sitter",
|
||||||
"tree-sitter-cpp",
|
"tree-sitter-cpp",
|
||||||
"tree-sitter-elixir",
|
"tree-sitter-elixir",
|
||||||
|
@ -7875,21 +7875,6 @@ dependencies = [
|
||||||
"weezl",
|
"weezl",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tiktoken-rs"
|
|
||||||
version = "0.4.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
|
|
||||||
dependencies = [
|
|
||||||
"anyhow",
|
|
||||||
"base64 0.21.4",
|
|
||||||
"bstr",
|
|
||||||
"fancy-regex",
|
|
||||||
"lazy_static",
|
|
||||||
"parking_lot 0.12.1",
|
|
||||||
"rustc-hash",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tiktoken-rs"
|
name = "tiktoken-rs"
|
||||||
version = "0.5.4"
|
version = "0.5.4"
|
||||||
|
|
|
@ -38,7 +38,7 @@ schemars.workspace = true
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
smol.workspace = true
|
smol.workspace = true
|
||||||
tiktoken-rs = "0.4"
|
tiktoken-rs = "0.5"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
editor = { path = "../editor", features = ["test-support"] }
|
editor = { path = "../editor", features = ["test-support"] }
|
||||||
|
|
|
@ -437,8 +437,15 @@ impl AssistantPanel {
|
||||||
InlineAssistantEvent::Confirmed {
|
InlineAssistantEvent::Confirmed {
|
||||||
prompt,
|
prompt,
|
||||||
include_conversation,
|
include_conversation,
|
||||||
|
retrieve_context,
|
||||||
} => {
|
} => {
|
||||||
self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
|
self.confirm_inline_assist(
|
||||||
|
assist_id,
|
||||||
|
prompt,
|
||||||
|
*include_conversation,
|
||||||
|
cx,
|
||||||
|
*retrieve_context,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
InlineAssistantEvent::Canceled => {
|
InlineAssistantEvent::Canceled => {
|
||||||
self.finish_inline_assist(assist_id, true, cx);
|
self.finish_inline_assist(assist_id, true, cx);
|
||||||
|
@ -532,6 +539,7 @@ impl AssistantPanel {
|
||||||
user_prompt: &str,
|
user_prompt: &str,
|
||||||
include_conversation: bool,
|
include_conversation: bool,
|
||||||
cx: &mut ViewContext<Self>,
|
cx: &mut ViewContext<Self>,
|
||||||
|
retrieve_context: bool,
|
||||||
) {
|
) {
|
||||||
let conversation = if include_conversation {
|
let conversation = if include_conversation {
|
||||||
self.active_editor()
|
self.active_editor()
|
||||||
|
@ -593,42 +601,49 @@ impl AssistantPanel {
|
||||||
let codegen_kind = codegen.read(cx).kind().clone();
|
let codegen_kind = codegen.read(cx).kind().clone();
|
||||||
let user_prompt = user_prompt.to_string();
|
let user_prompt = user_prompt.to_string();
|
||||||
|
|
||||||
let project = if let Some(workspace) = self.workspace.upgrade(cx) {
|
let snippets = if retrieve_context {
|
||||||
workspace.read(cx).project()
|
let project = if let Some(workspace) = self.workspace.upgrade(cx) {
|
||||||
} else {
|
workspace.read(cx).project()
|
||||||
return;
|
} else {
|
||||||
};
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
let project = project.to_owned();
|
let project = project.to_owned();
|
||||||
let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
|
let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
|
||||||
let search_results = semantic_index.update(cx, |this, cx| {
|
let search_results = semantic_index.update(cx, |this, cx| {
|
||||||
this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
|
this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
|
||||||
|
});
|
||||||
|
|
||||||
|
cx.background()
|
||||||
|
.spawn(async move { search_results.await.unwrap_or_default() })
|
||||||
|
} else {
|
||||||
|
Task::ready(Vec::new())
|
||||||
|
};
|
||||||
|
|
||||||
|
let snippets = cx.spawn(|_, cx| async move {
|
||||||
|
let mut snippets = Vec::new();
|
||||||
|
for result in search_results.await {
|
||||||
|
snippets.push(result.buffer.read_with(&cx, |buffer, _| {
|
||||||
|
buffer
|
||||||
|
.snapshot()
|
||||||
|
.text_for_range(result.range)
|
||||||
|
.collect::<String>()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
snippets
|
||||||
});
|
});
|
||||||
|
snippets
|
||||||
cx.background()
|
|
||||||
.spawn(async move { search_results.await.unwrap_or_default() })
|
|
||||||
} else {
|
} else {
|
||||||
Task::ready(Vec::new())
|
Task::ready(Vec::new())
|
||||||
};
|
};
|
||||||
|
|
||||||
let snippets = cx.spawn(|_, cx| async move {
|
let mut model = settings::get::<AssistantSettings>(cx)
|
||||||
let mut snippets = Vec::new();
|
.default_open_ai_model
|
||||||
for result in search_results.await {
|
.clone();
|
||||||
snippets.push(result.buffer.read_with(&cx, |buffer, _| {
|
let model_name = model.full_name();
|
||||||
buffer
|
|
||||||
.snapshot()
|
|
||||||
.text_for_range(result.range)
|
|
||||||
.collect::<String>()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
snippets
|
|
||||||
});
|
|
||||||
|
|
||||||
let prompt = cx.background().spawn(async move {
|
let prompt = cx.background().spawn(async move {
|
||||||
let snippets = snippets.await;
|
let snippets = snippets.await;
|
||||||
for snippet in &snippets {
|
|
||||||
println!("SNIPPET: \n{:?}", snippet);
|
|
||||||
}
|
|
||||||
|
|
||||||
let language_name = language_name.as_deref();
|
let language_name = language_name.as_deref();
|
||||||
generate_content_prompt(
|
generate_content_prompt(
|
||||||
|
@ -638,13 +653,11 @@ impl AssistantPanel {
|
||||||
range,
|
range,
|
||||||
codegen_kind,
|
codegen_kind,
|
||||||
snippets,
|
snippets,
|
||||||
|
model_name,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
let mut model = settings::get::<AssistantSettings>(cx)
|
|
||||||
.default_open_ai_model
|
|
||||||
.clone();
|
|
||||||
if let Some(conversation) = conversation {
|
if let Some(conversation) = conversation {
|
||||||
let conversation = conversation.read(cx);
|
let conversation = conversation.read(cx);
|
||||||
let buffer = conversation.buffer.read(cx);
|
let buffer = conversation.buffer.read(cx);
|
||||||
|
@ -1557,12 +1570,14 @@ impl Conversation {
|
||||||
Role::Assistant => "assistant".into(),
|
Role::Assistant => "assistant".into(),
|
||||||
Role::System => "system".into(),
|
Role::System => "system".into(),
|
||||||
},
|
},
|
||||||
content: self
|
content: Some(
|
||||||
.buffer
|
self.buffer
|
||||||
.read(cx)
|
.read(cx)
|
||||||
.text_for_range(message.offset_range)
|
.text_for_range(message.offset_range)
|
||||||
.collect(),
|
.collect(),
|
||||||
|
),
|
||||||
name: None,
|
name: None,
|
||||||
|
function_call: None,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
@ -2681,6 +2696,7 @@ enum InlineAssistantEvent {
|
||||||
Confirmed {
|
Confirmed {
|
||||||
prompt: String,
|
prompt: String,
|
||||||
include_conversation: bool,
|
include_conversation: bool,
|
||||||
|
retrieve_context: bool,
|
||||||
},
|
},
|
||||||
Canceled,
|
Canceled,
|
||||||
Dismissed,
|
Dismissed,
|
||||||
|
@ -2922,6 +2938,7 @@ impl InlineAssistant {
|
||||||
cx.emit(InlineAssistantEvent::Confirmed {
|
cx.emit(InlineAssistantEvent::Confirmed {
|
||||||
prompt,
|
prompt,
|
||||||
include_conversation: self.include_conversation,
|
include_conversation: self.include_conversation,
|
||||||
|
retrieve_context: self.retrieve_context,
|
||||||
});
|
});
|
||||||
self.confirmed = true;
|
self.confirmed = true;
|
||||||
cx.notify();
|
cx.notify();
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
use crate::codegen::CodegenKind;
|
use crate::codegen::CodegenKind;
|
||||||
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
|
||||||
use std::cmp;
|
use std::cmp;
|
||||||
|
use std::fmt::Write;
|
||||||
|
use std::iter;
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::{fmt::Write, iter};
|
use tiktoken_rs::ChatCompletionRequestMessage;
|
||||||
|
|
||||||
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
|
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -122,69 +124,103 @@ pub fn generate_content_prompt(
|
||||||
range: Range<impl ToOffset>,
|
range: Range<impl ToOffset>,
|
||||||
kind: CodegenKind,
|
kind: CodegenKind,
|
||||||
search_results: Vec<String>,
|
search_results: Vec<String>,
|
||||||
|
model: &str,
|
||||||
) -> String {
|
) -> String {
|
||||||
let mut prompt = String::new();
|
const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
|
||||||
|
|
||||||
|
let mut prompts = Vec::new();
|
||||||
|
|
||||||
// General Preamble
|
// General Preamble
|
||||||
if let Some(language_name) = language_name {
|
if let Some(language_name) = language_name {
|
||||||
writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
|
prompts.push(format!("You're an expert {language_name} engineer.\n"));
|
||||||
} else {
|
} else {
|
||||||
writeln!(prompt, "You're an expert engineer.\n").unwrap();
|
prompts.push("You're an expert engineer.\n".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Snippets
|
||||||
|
let mut snippet_position = prompts.len() - 1;
|
||||||
|
|
||||||
let outline = summarize(buffer, range);
|
let outline = summarize(buffer, range);
|
||||||
writeln!(
|
prompts.push("The file you are currently working on has the following outline:".to_string());
|
||||||
prompt,
|
|
||||||
"The file you are currently working on has the following outline:"
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
if let Some(language_name) = language_name {
|
if let Some(language_name) = language_name {
|
||||||
let language_name = language_name.to_lowercase();
|
let language_name = language_name.to_lowercase();
|
||||||
writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
|
prompts.push(format!("```{language_name}\n{outline}\n```"));
|
||||||
} else {
|
} else {
|
||||||
writeln!(prompt, "```\n{outline}\n```").unwrap();
|
prompts.push(format!("```\n{outline}\n```"));
|
||||||
}
|
}
|
||||||
|
|
||||||
match kind {
|
match kind {
|
||||||
CodegenKind::Generate { position: _ } => {
|
CodegenKind::Generate { position: _ } => {
|
||||||
writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
|
prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
|
||||||
writeln!(
|
prompts
|
||||||
prompt,
|
.push("Assume the cursor is located where the `<|START|` marker is.".to_string());
|
||||||
"Assume the cursor is located where the `<|START|` marker is."
|
prompts.push(
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
writeln!(
|
|
||||||
prompt,
|
|
||||||
"Text can't be replaced, so assume your answer will be inserted at the cursor."
|
"Text can't be replaced, so assume your answer will be inserted at the cursor."
|
||||||
)
|
.to_string(),
|
||||||
.unwrap();
|
);
|
||||||
writeln!(
|
prompts.push(format!(
|
||||||
prompt,
|
|
||||||
"Generate text based on the users prompt: {user_prompt}"
|
"Generate text based on the users prompt: {user_prompt}"
|
||||||
)
|
));
|
||||||
.unwrap();
|
|
||||||
}
|
}
|
||||||
CodegenKind::Transform { range: _ } => {
|
CodegenKind::Transform { range: _ } => {
|
||||||
writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
|
prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
|
||||||
writeln!(
|
prompts.push(format!(
|
||||||
prompt,
|
|
||||||
"Modify the users code selected text based upon the users prompt: {user_prompt}"
|
"Modify the users code selected text based upon the users prompt: {user_prompt}"
|
||||||
)
|
));
|
||||||
.unwrap();
|
prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
|
||||||
writeln!(
|
|
||||||
prompt,
|
|
||||||
"You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(language_name) = language_name {
|
if let Some(language_name) = language_name {
|
||||||
writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
|
prompts.push(format!("Your answer MUST always be valid {language_name}"));
|
||||||
}
|
}
|
||||||
writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
|
prompts.push("Always wrap your response in a Markdown codeblock".to_string());
|
||||||
writeln!(prompt, "Never make remarks about the output.").unwrap();
|
prompts.push("Never make remarks about the output.".to_string());
|
||||||
|
|
||||||
|
let current_messages = [ChatCompletionRequestMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: Some(prompts.join("\n")),
|
||||||
|
function_call: None,
|
||||||
|
name: None,
|
||||||
|
}];
|
||||||
|
|
||||||
|
let remaining_token_count = if let Ok(current_token_count) =
|
||||||
|
tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
|
||||||
|
{
|
||||||
|
let max_token_count = tiktoken_rs::model::get_context_size(model);
|
||||||
|
max_token_count - current_token_count
|
||||||
|
} else {
|
||||||
|
// If tiktoken fails to count token count, assume we have no space remaining.
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
// - add repository name to snippet
|
||||||
|
// - add file path
|
||||||
|
// - add language
|
||||||
|
if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
|
||||||
|
let template = "You are working inside a large repository, here are a few code snippets that may be useful";
|
||||||
|
|
||||||
|
for search_result in search_results {
|
||||||
|
let mut snippet_prompt = template.to_string();
|
||||||
|
writeln!(snippet_prompt, "```\n{search_result}\n```").unwrap();
|
||||||
|
|
||||||
|
let token_count = encoding
|
||||||
|
.encode_with_special_tokens(snippet_prompt.as_str())
|
||||||
|
.len();
|
||||||
|
if token_count <= remaining_token_count {
|
||||||
|
if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
|
||||||
|
prompts.insert(snippet_position, snippet_prompt);
|
||||||
|
snippet_position += 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let prompt = prompts.join("\n");
|
||||||
|
println!("PROMPT: {:?}", prompt);
|
||||||
prompt
|
prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue