Rely on model to determine indentation level and always rewrite the full line (#16145)

This PR simplifies our approach to indentation in the inline assistant
in hopes of improving our experience for Python. We tell the model to
generate the correct indentation in the prompt, and always start
generating at the start of the line. This may fall down for less capable
models, but I want to get a solid experience on the best models and then
figure the rest out later.

Also: We now prefer `./assets/prompts` as an overrides directory when
stdout is a PTY, so you can do `cargo run` and then iterate prompts for
the current run inside the current working copy.

cc @trishume @dsp-ant 

Release Notes:

- Zed now allows the model to control indentation when performing inline
transformation. We're hoping this improves the indentation experience in
Python and other indentation-sensitive languages, but it does require
more from the model.

---------

Co-authored-by: Mikayla <mikayla@zed.dev>
This commit is contained in:
Nathan Sobo 2024-08-12 22:41:24 -06:00 committed by GitHub
parent e662bfc74f
commit a515442a36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 235 additions and 565 deletions

View file

@ -1,14 +1,12 @@
Here's a text file that I'm going to ask you to make an edit to.
{{#if language_name}}
Here's a file of {{language_name}} that I'm going to ask you to make an edit to.
{{else}}
Here's a file of text that I'm going to ask you to make an edit to.
The file is in {{language_name}}.
{{/if}}
{{#if is_insert}}
The point you'll need to insert at is marked with <insert_here></insert_here>.
{{else}}
The section you'll need to rewrite is marked with <rewrite_this></rewrite_this> tags.
{{/if}}
You need to rewrite a portion of it.
The section you'll need to edit is marked with <rewrite_this></rewrite_this> tags.
<document>
{{{document_content}}}
@ -18,44 +16,37 @@ The section you'll need to rewrite is marked with <rewrite_this></rewrite_this>
The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
{{/if}}
{{#if is_insert}}
You can't replace {{content_type}}, your answer will be inserted in place of the `<insert_here></insert_here>` tags. Don't include the insert_here tags in your output.
Generate {{content_type}} based on the following prompt:
Rewrite the section of {{content_type}} in <rewrite_this></rewrite_this> tags based on the following prompt:
<prompt>
{{{user_prompt}}}
</prompt>
Match the indentation in the original file in the inserted {{content_type}}, don't include any indentation on blank lines.
Immediately start with the following format with no remarks:
```
{{INSERTED_CODE}}
```
{{else}}
Edit the section of {{content_type}} in <rewrite_this></rewrite_this> tags based on the following prompt:
<prompt>
{{{user_prompt}}}
</prompt>
{{#if rewrite_section}}
And here's the section to rewrite based on that prompt again for reference:
Here's the section to edit based on that prompt again for reference:
<rewrite_this>
{{{rewrite_section}}}
</rewrite_this>
You'll rewrite this entire section, but you will only make changes within certain subsections.
{{#if has_insertion}}
Insert text anywhere you see it marked with with <insert_here></insert_here> tags. Do not include <insert_here> tags in your output.
{{/if}}
{{#if has_replacement}}
Edit edit text that you see surrounded with <edit_here></edit_here> tags. Do not include <edit_here> tags in your output.
{{/if}}
Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
<rewrite_this>
{{{rewrite_section_with_selections}}}
</rewrite_this>
Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make, always write out the whole section with no unnecessary elisions.
Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved. Do not output the <rewrite_this></rewrite this> tags or anything outside of them.
Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make. Always write out the whole section with no unnecessary elisions.
Immediately start with the following format with no remarks:
```
{{REWRITTEN_CODE}}
\{{REWRITTEN_CODE}}
```
{{/if}}

View file

@ -30,6 +30,7 @@ use language_model::{
};
pub(crate) use model_selector::*;
pub use prompts::PromptBuilder;
use prompts::PromptOverrideContext;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
@ -168,7 +169,12 @@ impl Assistant {
}
}
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
pub fn init(
fs: Arc<dyn Fs>,
client: Arc<Client>,
dev_mode: bool,
cx: &mut AppContext,
) -> Arc<PromptBuilder> {
cx.set_global(Assistant::default());
AssistantSettings::register(cx);
@ -203,10 +209,14 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<Pr
assistant_slash_command::init(cx);
assistant_panel::init(cx);
let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
.log_err()
.map(Arc::new)
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
dev_mode,
fs: fs.clone(),
cx,
}))
.log_err()
.map(Arc::new)
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
register_slash_commands(Some(prompt_builder.clone()), cx);
inline_assistant::init(
fs.clone(),

View file

@ -27,7 +27,7 @@ use gpui::{
FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
UpdateGlobal, View, ViewContext, WeakView, WindowContext,
};
use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language::{Buffer, IndentKind, Point, TransactionId};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
@ -37,7 +37,6 @@ use rope::Rope;
use settings::Settings;
use smol::future::FutureExt;
use std::{
cmp,
future::{self, Future},
mem,
ops::{Range, RangeInclusive},
@ -46,6 +45,7 @@ use std::{
task::{self, Poll},
time::{Duration, Instant},
};
use text::ToOffset as _;
use theme::ThemeSettings;
use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
use util::{RangeExt, ResultExt};
@ -140,65 +140,74 @@ impl InlineAssistant {
) {
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let mut selections = Vec::<Selection<Point>>::new();
let mut newest_selection = None;
for mut selection in editor.read(cx).selections.all::<Point>(cx) {
if selection.end > selection.start {
selection.start.column = 0;
// If the selection ends at the start of the line, we don't want to include it.
if selection.end.column == 0 {
selection.end.row -= 1;
}
selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
}
struct CodegenRange {
transform_range: Range<Point>,
selection_ranges: Vec<Range<Point>>,
focus_assist: bool,
}
if let Some(prev_selection) = selections.last_mut() {
if selection.start <= prev_selection.end {
prev_selection.end = selection.end;
let newest_selection = editor.read(cx).selections.newest::<Point>(cx);
let mut codegen_ranges: Vec<CodegenRange> = Vec::new();
for selection in editor.read(cx).selections.all::<Point>(cx) {
let selection_is_newest = selection.id == newest_selection.id;
let mut transform_range = selection.start..selection.end;
// Expand the transform range to start/end of lines.
// If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
transform_range.start.column = 0;
if transform_range.end.column == 0 && transform_range.end > transform_range.start {
transform_range.end.row -= 1;
}
transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
let selection_range = selection.start..selection.end.min(transform_range.end);
// If we intersect the previous transform range,
if let Some(CodegenRange {
transform_range: prev_transform_range,
selection_ranges,
focus_assist,
}) = codegen_ranges.last_mut()
{
if transform_range.start <= prev_transform_range.end {
prev_transform_range.end = transform_range.end;
selection_ranges.push(selection_range);
*focus_assist |= selection_is_newest;
continue;
}
}
let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
if selection.id > latest_selection.id {
*latest_selection = selection.clone();
}
selections.push(selection);
}
let newest_selection = newest_selection.unwrap();
let mut codegen_ranges = Vec::new();
for (excerpt_id, buffer, buffer_range) in
snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
}))
{
let start = Anchor {
buffer_id: Some(buffer.remote_id()),
excerpt_id,
text_anchor: buffer.anchor_before(buffer_range.start),
};
let end = Anchor {
buffer_id: Some(buffer.remote_id()),
excerpt_id,
text_anchor: buffer.anchor_after(buffer_range.end),
};
codegen_ranges.push(start..end);
codegen_ranges.push(CodegenRange {
transform_range,
selection_ranges: vec![selection_range],
focus_assist: selection_is_newest,
})
}
let assist_group_id = self.next_assist_group_id.post_inc();
let prompt_buffer =
cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
let mut assists = Vec::new();
let mut assist_to_focus = None;
for range in codegen_ranges {
let assist_id = self.next_assist_id.post_inc();
for CodegenRange {
transform_range,
selection_ranges,
focus_assist,
} in codegen_ranges
{
let transform_range = snapshot.anchor_before(transform_range.start)
..snapshot.anchor_after(transform_range.end);
let selection_ranges = selection_ranges
.iter()
.map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
.collect::<Vec<_>>();
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
transform_range.clone(),
selection_ranges,
None,
self.telemetry.clone(),
self.prompt_builder.clone(),
@ -206,6 +215,7 @@ impl InlineAssistant {
)
});
let assist_id = self.next_assist_id.post_inc();
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let prompt_editor = cx.new_view(|cx| {
PromptEditor::new(
@ -222,23 +232,16 @@ impl InlineAssistant {
)
});
if assist_to_focus.is_none() {
let focus_assist = if newest_selection.reversed {
range.start.to_point(&snapshot) == newest_selection.start
} else {
range.end.to_point(&snapshot) == newest_selection.end
};
if focus_assist {
assist_to_focus = Some(assist_id);
}
if focus_assist {
assist_to_focus = Some(assist_id);
}
let [prompt_block_id, end_block_id] =
self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx);
assists.push((
assist_id,
range,
transform_range,
prompt_editor,
prompt_block_id,
end_block_id,
@ -305,6 +308,7 @@ impl InlineAssistant {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
vec![range.clone()],
initial_transaction_id,
self.telemetry.clone(),
self.prompt_builder.clone(),
@ -888,12 +892,7 @@ impl InlineAssistant {
assist
.codegen
.update(cx, |codegen, cx| {
codegen.start(
assist.range.clone(),
user_prompt,
assistant_panel_context,
cx,
)
codegen.start(user_prompt, assistant_panel_context, cx)
})
.log_err();
@ -2084,12 +2083,9 @@ impl InlineAssist {
return future::ready(Err(anyhow!("no user prompt"))).boxed();
};
let assistant_panel_context = self.assistant_panel_context(cx);
self.codegen.read(cx).count_tokens(
self.range.clone(),
user_prompt,
assistant_panel_context,
cx,
)
self.codegen
.read(cx)
.count_tokens(user_prompt, assistant_panel_context, cx)
}
}
@ -2110,6 +2106,8 @@ pub struct Codegen {
buffer: Model<MultiBuffer>,
old_buffer: Model<Buffer>,
snapshot: MultiBufferSnapshot,
transform_range: Range<Anchor>,
selected_ranges: Vec<Range<Anchor>>,
edit_position: Option<Anchor>,
last_equal_ranges: Vec<Range<Anchor>>,
initial_transaction_id: Option<TransactionId>,
@ -2119,7 +2117,7 @@ pub struct Codegen {
diff: Diff,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
builder: Arc<PromptBuilder>,
prompt_builder: Arc<PromptBuilder>,
}
enum CodegenStatus {
@ -2146,7 +2144,8 @@ impl EventEmitter<CodegenEvent> for Codegen {}
impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
range: Range<Anchor>,
transform_range: Range<Anchor>,
selected_ranges: Vec<Range<Anchor>>,
initial_transaction_id: Option<TransactionId>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
@ -2156,7 +2155,7 @@ impl Codegen {
let (old_buffer, _, _) = buffer
.read(cx)
.range_to_buffer_ranges(range.clone(), cx)
.range_to_buffer_ranges(transform_range.clone(), cx)
.pop()
.unwrap();
let old_buffer = cx.new_model(|cx| {
@ -2187,7 +2186,9 @@ impl Codegen {
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
initial_transaction_id,
builder,
prompt_builder: builder,
transform_range,
selected_ranges,
}
}
@ -2212,13 +2213,12 @@ impl Codegen {
pub fn count_tokens(
&self,
edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
let request = self.build_request(user_prompt, assistant_panel_context, cx);
match request {
Ok(request) => model.count_tokens(request, cx),
Err(error) => futures::future::ready(Err(error)).boxed(),
@ -2230,7 +2230,6 @@ impl Codegen {
pub fn start(
&mut self,
edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &mut ModelContext<Self>,
@ -2245,24 +2244,20 @@ impl Codegen {
});
}
self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot));
let telemetry_id = model.telemetry_id();
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
.trim()
.to_lowercase()
== "delete"
{
async { Ok(stream::empty().boxed()) }.boxed_local()
} else {
let request =
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
if user_prompt.trim().to_lowercase() == "delete" {
async { Ok(stream::empty().boxed()) }.boxed_local()
} else {
let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
let chunks =
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
async move { Ok(chunks.await?.boxed()) }.boxed_local()
};
self.handle_stream(telemetry_id, edit_range, chunks, cx);
let chunks =
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
async move { Ok(chunks.await?.boxed()) }.boxed_local()
};
self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx);
Ok(())
}
@ -2270,11 +2265,10 @@ impl Codegen {
&self,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
edit_range: Range<Anchor>,
cx: &AppContext,
) -> Result<LanguageModelRequest> {
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(edit_range.start);
let language = buffer.language_at(self.transform_range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
@ -2299,8 +2293,8 @@ impl Codegen {
};
let language_name = language_name.as_deref();
let start = buffer.point_to_buffer_offset(edit_range.start);
let end = buffer.point_to_buffer_offset(edit_range.end);
let start = buffer.point_to_buffer_offset(self.transform_range.start);
let end = buffer.point_to_buffer_offset(self.transform_range.end);
let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
@ -2312,9 +2306,20 @@ impl Codegen {
} else {
return Err(anyhow::anyhow!("invalid transformation range"));
};
let selected_ranges = self
.selected_ranges
.iter()
.map(|range| {
let start = range.start.text_anchor.to_offset(&buffer);
let end = range.end.text_anchor.to_offset(&buffer);
start..end
})
.collect::<Vec<_>>();
let prompt = self
.builder
.generate_content_prompt(user_prompt, language_name, buffer, range)
.prompt_builder
.generate_content_prompt(user_prompt, language_name, buffer, range, selected_ranges)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut messages = Vec::new();
@ -2386,84 +2391,19 @@ impl Codegen {
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
if line_indent.is_some() {
let char_ops = diff.push_new(&new_text);
line_diff
.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
new_text.clear();
}
if lines.peek().is_some() {
let char_ops = diff.push_new("\n");
line_diff
.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
if line_indent.is_none() {
// Don't write out the leading indentation in empty lines on the next line
// This is the case where the above if statement didn't clear the buffer
new_text.clear();
}
line_indent = None;
first_line = false;
}
}
let char_ops = diff.push_new(&chunk);
line_diff.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
}
let mut char_ops = diff.push_new(&new_text);
char_ops.extend(diff.finish());
let char_ops = diff.finish();
line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text);
diff_tx
@ -2824,311 +2764,13 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests {
use super::*;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
use language::{
language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
Point,
};
use language_model::LanguageModelRegistry;
use rand::prelude::*;
use serde::Serialize;
use settings::SettingsStore;
use std::{future, sync::Arc};
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_model::LanguageModelRegistry::test);
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
for _ in 0..10 {
x += 1;
}
}
"};
let buffer =
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
range.clone(),
None,
None,
prompt_builder,
cx,
)
});
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
String::new(),
range,
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
range.clone(),
None,
None,
prompt_builder,
cx,
)
});
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
cx.background_executor.run_until_parked();
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.update(LanguageModelRegistry::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
range.clone(),
None,
None,
prompt_builder,
cx,
)
});
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
cx.background_executor.run_until_parked();
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
cx.update(LanguageModelRegistry::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
func main() {
\tx := 0
\tfor i := 0; i < 10; i++ {
\t\tx++
\t}
}
"};
let buffer = cx.new_model(|cx| Buffer::local(text, cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
range.clone(),
None,
None,
prompt_builder,
cx,
)
});
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
let new_text = concat!(
"func main() {\n",
"\tx := 0\n",
"\tfor x < 10 {\n",
"\t\tx++\n",
"\t}", //
);
chunks_tx.unbounded_send(new_text.to_string()).unwrap();
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
func main() {
\tx := 0
\tfor x < 10 {
\t\tx++
\t}
}
"}
);
}
#[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() {
assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
@ -3168,27 +2810,4 @@ mod tests {
)
}
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}

View file

@ -12,11 +12,13 @@ use util::ResultExt;
pub struct ContentPromptContext {
pub content_type: String,
pub language_name: Option<String>,
pub is_insert: bool,
pub is_truncated: bool,
pub document_content: String,
pub user_prompt: String,
pub rewrite_section: Option<String>,
pub rewrite_section: String,
pub rewrite_section_with_selections: String,
pub has_insertion: bool,
pub has_replacement: bool,
}
#[derive(Serialize)]
@ -33,41 +35,54 @@ pub struct PromptBuilder {
handlebars: Arc<Mutex<Handlebars<'static>>>,
}
pub struct PromptOverrideContext<'a> {
pub dev_mode: bool,
pub fs: Arc<dyn Fs>,
pub cx: &'a mut gpui::AppContext,
}
impl PromptBuilder {
pub fn new(
fs_and_cx: Option<(Arc<dyn Fs>, &gpui::AppContext)>,
) -> Result<Self, Box<TemplateError>> {
pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
let mut handlebars = Handlebars::new();
Self::register_templates(&mut handlebars)?;
let handlebars = Arc::new(Mutex::new(handlebars));
if let Some((fs, cx)) = fs_and_cx {
Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone());
if let Some(override_cx) = override_cx {
Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
}
Ok(Self { handlebars })
}
fn watch_fs_for_template_overrides(
fs: Arc<dyn Fs>,
cx: &gpui::AppContext,
PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
handlebars: Arc<Mutex<Handlebars<'static>>>,
) {
let templates_dir = paths::prompt_overrides_dir();
cx.background_executor()
.spawn(async move {
let templates_dir = if dev_mode {
std::env::current_dir()
.ok()
.and_then(|pwd| {
let pwd_assets_prompts = pwd.join("assets").join("prompts");
pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
})
.unwrap_or_else(|| paths::prompt_overrides_dir().clone())
} else {
paths::prompt_overrides_dir().clone()
};
// Create the prompt templates directory if it doesn't exist
if !fs.is_dir(templates_dir).await {
if let Err(e) = fs.create_dir(templates_dir).await {
if !fs.is_dir(&templates_dir).await {
if let Err(e) = fs.create_dir(&templates_dir).await {
log::error!("Failed to create prompt templates directory: {}", e);
return;
}
}
// Initial scan of the prompts directory
if let Ok(mut entries) = fs.read_dir(templates_dir).await {
if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
while let Some(Ok(file_path)) = entries.next().await {
if file_path.to_string_lossy().ends_with(".hbs") {
if let Ok(content) = fs.load(&file_path).await {
@ -95,7 +110,7 @@ impl PromptBuilder {
}
// Watch for changes
let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await;
let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
while let Some(changed_paths) = changes.next().await {
for changed_path in changed_paths {
if changed_path.extension().map_or(false, |ext| ext == "hbs") {
@ -147,7 +162,8 @@ impl PromptBuilder {
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
transform_range: Range<usize>,
selected_ranges: Vec<Range<usize>>,
) -> Result<String, RenderError> {
let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => "text",
@ -155,21 +171,20 @@ impl PromptBuilder {
};
const MAX_CTX: usize = 50000;
let is_insert = range.is_empty();
let mut is_truncated = false;
let before_range = 0..range.start;
let before_range = 0..transform_range.start;
let truncated_before = if before_range.len() > MAX_CTX {
is_truncated = true;
range.start - MAX_CTX..range.start
transform_range.start - MAX_CTX..transform_range.start
} else {
before_range
};
let after_range = range.end..buffer.len();
let after_range = transform_range.end..buffer.len();
let truncated_after = if after_range.len() > MAX_CTX {
is_truncated = true;
range.end..range.end + MAX_CTX
transform_range.end..transform_range.end + MAX_CTX
} else {
after_range
};
@ -178,37 +193,61 @@ impl PromptBuilder {
for chunk in buffer.text_for_range(truncated_before) {
document_content.push_str(chunk);
}
if is_insert {
document_content.push_str("<insert_here></insert_here>");
} else {
document_content.push_str("<rewrite_this>\n");
for chunk in buffer.text_for_range(range.clone()) {
document_content.push_str(chunk);
}
document_content.push_str("\n</rewrite_this>");
document_content.push_str("<rewrite_this>\n");
for chunk in buffer.text_for_range(transform_range.clone()) {
document_content.push_str(chunk);
}
document_content.push_str("\n</rewrite_this>");
for chunk in buffer.text_for_range(truncated_after) {
document_content.push_str(chunk);
}
let rewrite_section = if !is_insert {
let mut section = String::new();
for chunk in buffer.text_for_range(range.clone()) {
section.push_str(chunk);
let mut rewrite_section = String::new();
for chunk in buffer.text_for_range(transform_range.clone()) {
rewrite_section.push_str(chunk);
}
let rewrite_section_with_selections = {
let mut section_with_selections = String::new();
let mut last_end = 0;
for selected_range in &selected_ranges {
if selected_range.start > last_end {
section_with_selections.push_str(
&rewrite_section[last_end..selected_range.start - transform_range.start],
);
}
if selected_range.start == selected_range.end {
section_with_selections.push_str("<insert_here></insert_here>");
} else {
section_with_selections.push_str("<edit_here>");
section_with_selections.push_str(
&rewrite_section[selected_range.start - transform_range.start
..selected_range.end - transform_range.start],
);
section_with_selections.push_str("</edit_here>");
}
last_end = selected_range.end - transform_range.start;
}
Some(section)
} else {
None
if last_end < rewrite_section.len() {
section_with_selections.push_str(&rewrite_section[last_end..]);
}
section_with_selections
};
let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
let context = ContentPromptContext {
content_type: content_type.to_string(),
language_name: language_name.map(|s| s.to_string()),
is_insert,
is_truncated,
document_content,
user_prompt,
rewrite_section,
rewrite_section_with_selections,
has_insertion,
has_replacement,
};
self.handlebars.lock().render("content_prompt", &context)

View file

@ -187,7 +187,12 @@ fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) -> Arc<PromptBuild
);
snippet_provider::init(cx);
inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
let prompt_builder = assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
let prompt_builder = assistant::init(
app_state.fs.clone(),
app_state.client.clone(),
stdout_is_a_pty(),
cx,
);
repl::init(
app_state.fs.clone(),
app_state.client.telemetry().clone(),

View file

@ -1018,6 +1018,8 @@ fn open_settings_file(
#[cfg(test)]
mod tests {
use crate::stdout_is_a_pty;
use super::*;
use anyhow::anyhow;
use assets::Assets;
@ -3485,8 +3487,12 @@ mod tests {
app_state.fs.clone(),
cx,
);
let prompt_builder =
assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
let prompt_builder = assistant::init(
app_state.fs.clone(),
app_state.client.clone(),
stdout_is_a_pty(),
cx,
);
repl::init(
app_state.fs.clone(),
app_state.client.telemetry().clone(),