Revert changes to inline assist indentation logic and prompt (#16403)

This PR reverts #16145 and subsequent changes.

This reverts commit a515442a36.

We still have issues with our approach to indentation in Python
unfortunately, but this feels like a safer equilibrium than where we
were.

Release Notes:

- Returned to our previous prompt for inline assist transformations,
since recent changes were introducing issues.
This commit is contained in:
Nathan Sobo 2024-08-17 02:24:55 -06:00 committed by GitHub
parent ebecd7e65f
commit 07d5e22cbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 574 additions and 658 deletions

View file

@ -1,426 +1,61 @@
You are an expert developer assistant working in an AI-enabled text editor. {{#if language_name}}
Your task is to rewrite a specific section of the provided document based on a user-provided prompt. Here's a file of {{language_name}} that I'm going to ask you to make an edit to.
{{else}}
Here's a file of text that I'm going to ask you to make an edit to.
{{/if}}
<guidelines> {{#if is_insert}}
1. Scope: Modify only content within <rewrite_this> tags. Do not alter anything outside these boundaries. The point you'll need to insert at is marked with <insert_here></insert_here>.
2. Precision: Make changes strictly necessary to fulfill the given prompt. Preserve all other content as-is. {{else}}
3. Seamless integration: Ensure rewritten sections flow naturally with surrounding text and maintain document structure. The section you'll need to rewrite is marked with <rewrite_this></rewrite_this> tags.
4. Tag exclusion: Never include <rewrite_this>, </rewrite_this>, <edit_here>, or <insert_here> tags in the output. {{/if}}
5. Indentation: Maintain the original indentation level of the file in rewritten sections.
6. Completeness: Rewrite the entire tagged section, even if only partial changes are needed. Avoid omissions or elisions.
7. Insertions: Replace <insert_here></insert_here> tags with appropriate content as specified by the prompt.
8. Code integrity: Respect existing code structure and functionality when making changes.
9. Consistency: Maintain a uniform style and tone throughout the rewritten text.
</guidelines>
<examples>
<example>
<input>
<document> <document>
use std::cell::Cell;
use std::collections::HashMap;
use std::cmp;
<rewrite_this>
<insert_here></insert_here>
</rewrite_this>
pub struct LruCache<K, V> {
/// The maximum number of items the cache can hold.
capacity: usize,
/// The map storing the cached items.
items: HashMap<K, V>,
}
// The rest of the implementation...
</document>
<prompt>
doc this
</prompt>
</input>
<incorrect_output failure="Over-generation. The text starting with `pub struct AabbTree<T> {` is *after* the rewrite_this tag">
/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
///
/// This structure is used for efficient spatial queries and collision detection.
/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
///
/// # Type Parameters
///
/// * `T`: The type of data associated with each node in the tree.
pub struct AabbTree<T> {
root: Option<usize>,
</incorrect_output>
<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
///
/// This structure is used for efficient spatial queries and collision detection.
/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
///
/// # Type Parameters
///
/// * `T`: The type of data associated with each node in the tree.
</corrected_output>
</example>
<example>
<input>
<document>
import math
def calculate_circle_area(radius):
"""Calculate the area of a circle given its radius."""
return math.pi * radius ** 2
<rewrite_this>
<insert_here></insert_here>
</rewrite_this>
class Circle:
def __init__(self, radius):
self.radius = radius
def area(self):
return math.pi * self.radius ** 2
def circumference(self):
return 2 * math.pi * self.radius
# Usage example
circle = Circle(5)
print(f"Area: {circle.area():.2f}")
print(f"Circumference: {circle.circumference():.2f}")
</document>
<prompt>
write docs
</prompt>
</input>
<incorrect_output failure="Over-generation. The text starting with `class Circle:` is *after* the rewrite_this tag">
"""
Represents a circle with methods to calculate its area and circumference.
This class provides a simple way to work with circles in a geometric context.
It allows for the creation of Circle objects with a specified radius and
offers methods to compute the circle's area and circumference.
Attributes:
radius (float): The radius of the circle.
Methods:
area(): Calculates and returns the area of the circle.
circumference(): Calculates and returns the circumference of the circle.
"""
class Circle:
</incorrect_output>
<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
"""
Represents a circle with methods to calculate its area and circumference.
This class provides a simple way to work with circles in a geometric context.
It allows for the creation of Circle objects with a specified radius and
offers methods to compute the circle's area and circumference.
Attributes:
radius (float): The radius of the circle.
Methods:
area(): Calculates and returns the area of the circle.
circumference(): Calculates and returns the circumference of the circle.
"""
</corrected_output>
</example>
<example>
<input>
<document>
class BankAccount {
private balance: number;
constructor(initialBalance: number) {
this.balance = initialBalance;
}
<rewrite_this>
<insert_here></insert_here>
</rewrite_this>
deposit(amount: number): void {
if (amount > 0) {
this.balance += amount;
}
}
withdraw(amount: number): boolean {
if (amount > 0 && this.balance >= amount) {
this.balance -= amount;
return true;
}
return false;
}
getBalance(): number {
return this.balance;
}
}
// Usage
const account = new BankAccount(1000);
account.deposit(500);
console.log(account.getBalance()); // 1500
account.withdraw(200);
console.log(account.getBalance()); // 1300
</document>
<prompt>
//
</prompt>
</input>
<incorrect_output failure="Over-generation. The text starting with `deposit(amount: number): void {` is *after* the rewrite_this tag">
/**
* Deposits the specified amount into the bank account.
*
* @param amount The amount to deposit. Must be a positive number.
* @throws Error if the amount is not positive.
*/
deposit(amount: number): void {
if (amount > 0) {
this.balance += amount;
} else {
throw new Error("Deposit amount must be positive");
}
}
</incorrect_output>
<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
/**
* Deposits the specified amount into the bank account.
*
* @param amount The amount to deposit. Must be a positive number.
* @throws Error if the amount is not positive.
*/
</corrected_output>
</example>
<example>
<input>
<document>
use std::collections::VecDeque;
pub struct BinaryTree<T> {
root: Option<Node<T>>,
}
<rewrite_this>
<insert_here></insert_here>
</rewrite_this>
struct Node<T> {
value: T,
left: Option<Box<Node<T>>>,
right: Option<Box<Node<T>>>,
}
</document>
<prompt>
derive clone
</prompt>
</input>
<incorrect_output failure="Over-generation below the rewrite_this tags. Extra space between derive annotation and struct definition.">
#[derive(Clone)]
struct Node<T> {
value: T,
left: Option<Box<Node<T>>>,
right: Option<Box<Node<T>>>,
}
</incorrect_output>
<incorrect_output failure="Over-generation above the rewrite_this tags">
pub struct BinaryTree<T> {
root: Option<Node<T>>,
}
#[derive(Clone)]
</incorrect_output>
<incorrect_output failure="Over-generation below the rewrite_this tags">
#[derive(Clone)]
struct Node<T> {
value: T,
left: Option<Box<Node<T>>>,
right: Option<Box<Node<T>>>,
}
impl<T> Node<T> {
fn new(value: T) -> Self {
Node {
value,
left: None,
right: None,
}
}
}
</incorrect_output>
<corrected_output improvement="Only includes the new content within the rewrite_this tags">
#[derive(Clone)]
</corrected_output>
</example>
<example>
<input>
<document>
import math
def calculate_circle_area(radius):
"""Calculate the area of a circle given its radius."""
return math.pi * radius ** 2
<rewrite_this>
<insert_here></insert_here>
</rewrite_this>
class Circle:
def __init__(self, radius):
self.radius = radius
def area(self):
return math.pi * self.radius ** 2
def circumference(self):
return 2 * math.pi * self.radius
# Usage example
circle = Circle(5)
print(f"Area: {circle.area():.2f}")
print(f"Circumference: {circle.circumference():.2f}")
</document>
<prompt>
add dataclass decorator
</prompt>
</input>
<incorrect_output failure="Over-generation. The text starting with `class Circle:` is *after* the rewrite_this tag">
@dataclass
class Circle:
radius: float
def __init__(self, radius):
self.radius = radius
def area(self):
return math.pi * self.radius ** 2
</incorrect_output>
<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
@dataclass
</corrected_output>
</example>
<example>
<input>
<document>
interface ShoppingCart {
items: string[];
total: number;
}
<rewrite_this>
<insert_here></insert_here>class ShoppingCartManager {
</rewrite_this>
private cart: ShoppingCart;
constructor() {
this.cart = { items: [], total: 0 };
}
addItem(item: string, price: number): void {
this.cart.items.push(item);
this.cart.total += price;
}
getTotal(): number {
return this.cart.total;
}
}
// Usage
const manager = new ShoppingCartManager();
manager.addItem("Book", 15.99);
console.log(manager.getTotal()); // 15.99
</document>
<prompt>
add readonly modifier
</prompt>
</input>
<incorrect_output failure="Over-generation. The line starting with ` items: string[];` is *after* the rewrite_this tag">
readonly interface ShoppingCart {
items: string[];
total: number;
}
class ShoppingCartManager {
private readonly cart: ShoppingCart;
constructor() {
this.cart = { items: [], total: 0 };
}
</incorrect_output>
<corrected_output improvement="Only includes the new content within the rewrite_this tags and integrates cleanly into surrounding code">
readonly interface ShoppingCart {
</corrected_output>
</example>
</examples>
With these examples in mind, edit the following file:
<document language="{{ language_name }}">
{{{document_content}}} {{{document_content}}}
</document> </document>
{{#if is_truncated}} {{#if is_truncated}}
The provided document has been truncated (potentially mid-line) for brevity. The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
{{/if}} {{/if}}
<instructions> {{#if is_insert}}
{{#if has_insertion}} You can't replace {{content_type}}, your answer will be inserted in place of the `<insert_here></insert_here>` tags. Don't include the insert_here tags in your output.
Insert text anywhere you see marked with <insert_here></insert_here> tags. It's CRITICAL that you DO NOT include <insert_here> tags in your output.
{{/if}}
{{#if has_replacement}}
Edit text that you see surrounded with <edit_here>...</edit_here> tags. It's CRITICAL that you DO NOT include <edit_here> tags in your output.
{{/if}}
Make no changes to the rewritten content outside these tags.
<snippet language="{{ language_name }}" annotated="true"> Generate {{content_type}} based on the following prompt:
{{{ rewrite_section_prefix }}}
<rewrite_this>
{{{ rewrite_section_with_edits }}}
</rewrite_this>
{{{ rewrite_section_suffix }}}
</snippet>
Rewrite the lines enclosed within the <rewrite_this></rewrite_this> tags in accordance with the provided instructions and the prompt below.
<prompt> <prompt>
{{{user_prompt}}} {{{user_prompt}}}
</prompt> </prompt>
Do not include <insert_here> or <edit_here> annotations in your output. Here is a clean copy of the snippet without annotations for your reference. Match the indentation in the original file in the inserted {{content_type}}, don't include any indentation on blank lines.
<snippet> Immediately start with the following format with no remarks:
{{{ rewrite_section_prefix }}}
```
\{{INSERTED_CODE}}
```
{{else}}
Edit the section of {{content_type}} in <rewrite_this></rewrite_this> tags based on the following prompt:
<prompt>
{{{user_prompt}}}
</prompt>
{{#if rewrite_section}}
And here's the section to rewrite based on that prompt again for reference:
<rewrite_this>
{{{rewrite_section}}} {{{rewrite_section}}}
{{{ rewrite_section_suffix }}} </rewrite_this>
</snippet> {{/if}}
</instructions>
<guidelines_reminder> Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
1. Focus on necessary changes: Modify only what's required to fulfill the prompt.
2. Preserve context: Maintain all surrounding content as-is, ensuring the rewritten section seamlessly integrates with the existing document structure and flow. Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make, always write out the whole section with no unnecessary elisions.
3. Exclude annotation tags: Do not output <rewrite_this>, </rewrite_this>, <edit_here>, or <insert_here> tags.
4. Maintain indentation: Begin at the original file's indentation level.
5. Complete rewrite: Continue until the entire section is rewritten, even if no further changes are needed.
6. Avoid elisions: Always write out the full section without unnecessary omissions. NEVER say `// ...` or `// ...existing code` in your output.
7. Respect content boundaries: Preserve code integrity.
</guidelines_reminder>
Immediately start with the following format with no remarks: Immediately start with the following format with no remarks:
``` ```
\{{REWRITTEN_CODE}} \{{REWRITTEN_CODE}}
``` ```
{{/if}}

View file

@ -34,7 +34,6 @@ use language_model::{
}; };
pub(crate) use model_selector::*; pub(crate) use model_selector::*;
pub use prompts::PromptBuilder; pub use prompts::PromptBuilder;
use prompts::PromptOverrideContext;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore}; use settings::{update_settings_file, Settings, SettingsStore};
@ -181,12 +180,7 @@ impl Assistant {
} }
} }
pub fn init( pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
fs: Arc<dyn Fs>,
client: Arc<Client>,
dev_mode: bool,
cx: &mut AppContext,
) -> Arc<PromptBuilder> {
cx.set_global(Assistant::default()); cx.set_global(Assistant::default());
AssistantSettings::register(cx); AssistantSettings::register(cx);
SlashCommandSettings::register(cx); SlashCommandSettings::register(cx);
@ -223,11 +217,7 @@ pub fn init(
assistant_panel::init(cx); assistant_panel::init(cx);
context_servers::init(cx); context_servers::init(cx);
let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext { let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
dev_mode,
fs: fs.clone(),
cx,
}))
.log_err() .log_err()
.map(Arc::new) .map(Arc::new)
.unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap())); .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));

View file

@ -28,7 +28,7 @@ use gpui::{
FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
UpdateGlobal, View, ViewContext, WeakView, WindowContext, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
}; };
use language::{Buffer, IndentKind, Point, TransactionId}; use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language_model::{ use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
}; };
@ -38,6 +38,7 @@ use rope::Rope;
use settings::Settings; use settings::Settings;
use smol::future::FutureExt; use smol::future::FutureExt;
use std::{ use std::{
cmp,
future::{self, Future}, future::{self, Future},
mem, mem,
ops::{Range, RangeInclusive}, ops::{Range, RangeInclusive},
@ -46,7 +47,6 @@ use std::{
task::{self, Poll}, task::{self, Poll},
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use text::OffsetRangeExt as _;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip}; use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
use util::{RangeExt, ResultExt}; use util::{RangeExt, ResultExt};
@ -140,81 +140,66 @@ impl InlineAssistant {
cx: &mut WindowContext, cx: &mut WindowContext,
) { ) {
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
struct CodegenRange {
transform_range: Range<Point>, let mut selections = Vec::<Selection<Point>>::new();
selection_ranges: Vec<Range<Point>>, let mut newest_selection = None;
focus_assist: bool, for mut selection in editor.read(cx).selections.all::<Point>(cx) {
if selection.end > selection.start {
selection.start.column = 0;
// If the selection ends at the start of the line, we don't want to include it.
if selection.end.column == 0 {
selection.end.row -= 1;
}
selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
} }
let newest_selection_range = editor.read(cx).selections.newest::<Point>(cx).range(); if let Some(prev_selection) = selections.last_mut() {
let mut codegen_ranges: Vec<CodegenRange> = Vec::new(); if selection.start <= prev_selection.end {
prev_selection.end = selection.end;
let selection_ranges = snapshot
.split_ranges(editor.read(cx).selections.disjoint_anchor_ranges())
.map(|range| range.to_point(&snapshot))
.collect::<Vec<Range<Point>>>();
for selection_range in selection_ranges {
let selection_is_newest = newest_selection_range.contains_inclusive(&selection_range);
let mut transform_range = selection_range.start..selection_range.end;
// Expand the transform range to start/end of lines.
// If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
transform_range.start.column = 0;
if transform_range.end.column == 0 && transform_range.end > transform_range.start {
transform_range.end.row -= 1;
}
transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
let selection_range =
selection_range.start..selection_range.end.min(transform_range.end);
// If we intersect the previous transform range,
if let Some(CodegenRange {
transform_range: prev_transform_range,
selection_ranges,
focus_assist,
}) = codegen_ranges.last_mut()
{
if transform_range.start <= prev_transform_range.end {
prev_transform_range.end = transform_range.end;
selection_ranges.push(selection_range);
*focus_assist |= selection_is_newest;
continue; continue;
} }
} }
codegen_ranges.push(CodegenRange { let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
transform_range, if selection.id > latest_selection.id {
selection_ranges: vec![selection_range], *latest_selection = selection.clone();
focus_assist: selection_is_newest, }
}) selections.push(selection);
}
let newest_selection = newest_selection.unwrap();
let mut codegen_ranges = Vec::new();
for (excerpt_id, buffer, buffer_range) in
snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
}))
{
let start = Anchor {
buffer_id: Some(buffer.remote_id()),
excerpt_id,
text_anchor: buffer.anchor_before(buffer_range.start),
};
let end = Anchor {
buffer_id: Some(buffer.remote_id()),
excerpt_id,
text_anchor: buffer.anchor_after(buffer_range.end),
};
codegen_ranges.push(start..end);
} }
let assist_group_id = self.next_assist_group_id.post_inc(); let assist_group_id = self.next_assist_group_id.post_inc();
let prompt_buffer = let prompt_buffer =
cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx)); cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx)); let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
let mut assists = Vec::new(); let mut assists = Vec::new();
let mut assist_to_focus = None; let mut assist_to_focus = None;
for range in codegen_ranges {
for CodegenRange { let assist_id = self.next_assist_id.post_inc();
transform_range,
selection_ranges,
focus_assist,
} in codegen_ranges
{
let transform_range = snapshot.anchor_before(transform_range.start)
..snapshot.anchor_after(transform_range.end);
let selection_ranges = selection_ranges
.iter()
.map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
.collect::<Vec<_>>();
let codegen = cx.new_model(|cx| { let codegen = cx.new_model(|cx| {
Codegen::new( Codegen::new(
editor.read(cx).buffer().clone(), editor.read(cx).buffer().clone(),
transform_range.clone(), range.clone(),
selection_ranges,
None, None,
self.telemetry.clone(), self.telemetry.clone(),
self.prompt_builder.clone(), self.prompt_builder.clone(),
@ -222,7 +207,6 @@ impl InlineAssistant {
) )
}); });
let assist_id = self.next_assist_id.post_inc();
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default())); let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let prompt_editor = cx.new_view(|cx| { let prompt_editor = cx.new_view(|cx| {
PromptEditor::new( PromptEditor::new(
@ -239,16 +223,23 @@ impl InlineAssistant {
) )
}); });
if assist_to_focus.is_none() {
let focus_assist = if newest_selection.reversed {
range.start.to_point(&snapshot) == newest_selection.start
} else {
range.end.to_point(&snapshot) == newest_selection.end
};
if focus_assist { if focus_assist {
assist_to_focus = Some(assist_id); assist_to_focus = Some(assist_id);
} }
}
let [prompt_block_id, end_block_id] = let [prompt_block_id, end_block_id] =
self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx); self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
assists.push(( assists.push((
assist_id, assist_id,
transform_range, range,
prompt_editor, prompt_editor,
prompt_block_id, prompt_block_id,
end_block_id, end_block_id,
@ -315,7 +306,6 @@ impl InlineAssistant {
Codegen::new( Codegen::new(
editor.read(cx).buffer().clone(), editor.read(cx).buffer().clone(),
range.clone(), range.clone(),
vec![range.clone()],
initial_transaction_id, initial_transaction_id,
self.telemetry.clone(), self.telemetry.clone(),
self.prompt_builder.clone(), self.prompt_builder.clone(),
@ -925,7 +915,12 @@ impl InlineAssistant {
assist assist
.codegen .codegen
.update(cx, |codegen, cx| { .update(cx, |codegen, cx| {
codegen.start(user_prompt, assistant_panel_context, cx) codegen.start(
assist.range.clone(),
user_prompt,
assistant_panel_context,
cx,
)
}) })
.log_err(); .log_err();
@ -2120,9 +2115,12 @@ impl InlineAssist {
return future::ready(Err(anyhow!("no user prompt"))).boxed(); return future::ready(Err(anyhow!("no user prompt"))).boxed();
}; };
let assistant_panel_context = self.assistant_panel_context(cx); let assistant_panel_context = self.assistant_panel_context(cx);
self.codegen self.codegen.read(cx).count_tokens(
.read(cx) self.range.clone(),
.count_tokens(user_prompt, assistant_panel_context, cx) user_prompt,
assistant_panel_context,
cx,
)
} }
} }
@ -2143,8 +2141,6 @@ pub struct Codegen {
buffer: Model<MultiBuffer>, buffer: Model<MultiBuffer>,
old_buffer: Model<Buffer>, old_buffer: Model<Buffer>,
snapshot: MultiBufferSnapshot, snapshot: MultiBufferSnapshot,
transform_range: Range<Anchor>,
selected_ranges: Vec<Range<Anchor>>,
edit_position: Option<Anchor>, edit_position: Option<Anchor>,
last_equal_ranges: Vec<Range<Anchor>>, last_equal_ranges: Vec<Range<Anchor>>,
initial_transaction_id: Option<TransactionId>, initial_transaction_id: Option<TransactionId>,
@ -2154,7 +2150,7 @@ pub struct Codegen {
diff: Diff, diff: Diff,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription, _subscription: gpui::Subscription,
prompt_builder: Arc<PromptBuilder>, builder: Arc<PromptBuilder>,
} }
enum CodegenStatus { enum CodegenStatus {
@ -2181,8 +2177,7 @@ impl EventEmitter<CodegenEvent> for Codegen {}
impl Codegen { impl Codegen {
pub fn new( pub fn new(
buffer: Model<MultiBuffer>, buffer: Model<MultiBuffer>,
transform_range: Range<Anchor>, range: Range<Anchor>,
selected_ranges: Vec<Range<Anchor>>,
initial_transaction_id: Option<TransactionId>, initial_transaction_id: Option<TransactionId>,
telemetry: Option<Arc<Telemetry>>, telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>, builder: Arc<PromptBuilder>,
@ -2192,7 +2187,7 @@ impl Codegen {
let (old_buffer, _, _) = buffer let (old_buffer, _, _) = buffer
.read(cx) .read(cx)
.range_to_buffer_ranges(transform_range.clone(), cx) .range_to_buffer_ranges(range.clone(), cx)
.pop() .pop()
.unwrap(); .unwrap();
let old_buffer = cx.new_model(|cx| { let old_buffer = cx.new_model(|cx| {
@ -2223,9 +2218,7 @@ impl Codegen {
telemetry, telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event), _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
initial_transaction_id, initial_transaction_id,
prompt_builder: builder, builder,
transform_range,
selected_ranges,
} }
} }
@ -2250,12 +2243,14 @@ impl Codegen {
pub fn count_tokens( pub fn count_tokens(
&self, &self,
edit_range: Range<Anchor>,
user_prompt: String, user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>, assistant_panel_context: Option<LanguageModelRequest>,
cx: &AppContext, cx: &AppContext,
) -> BoxFuture<'static, Result<TokenCounts>> { ) -> BoxFuture<'static, Result<TokenCounts>> {
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx); let request =
self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
match request { match request {
Ok(request) => { Ok(request) => {
let total_count = model.count_tokens(request.clone(), cx); let total_count = model.count_tokens(request.clone(), cx);
@ -2280,6 +2275,7 @@ impl Codegen {
pub fn start( pub fn start(
&mut self, &mut self,
edit_range: Range<Anchor>,
user_prompt: String, user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>, assistant_panel_context: Option<LanguageModelRequest>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
@ -2294,20 +2290,24 @@ impl Codegen {
}); });
} }
self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot)); self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
let telemetry_id = model.telemetry_id(); let telemetry_id = model.telemetry_id();
let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
if user_prompt.trim().to_lowercase() == "delete" { .trim()
.to_lowercase()
== "delete"
{
async { Ok(stream::empty().boxed()) }.boxed_local() async { Ok(stream::empty().boxed()) }.boxed_local()
} else { } else {
let request = self.build_request(user_prompt, assistant_panel_context, cx)?; let request =
self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
let chunks = let chunks =
cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await }); cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
async move { Ok(chunks.await?.boxed()) }.boxed_local() async move { Ok(chunks.await?.boxed()) }.boxed_local()
}; };
self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx); self.handle_stream(telemetry_id, edit_range, chunks, cx);
Ok(()) Ok(())
} }
@ -2315,10 +2315,11 @@ impl Codegen {
&self, &self,
user_prompt: String, user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>, assistant_panel_context: Option<LanguageModelRequest>,
edit_range: Range<Anchor>,
cx: &AppContext, cx: &AppContext,
) -> Result<LanguageModelRequest> { ) -> Result<LanguageModelRequest> {
let buffer = self.buffer.read(cx).snapshot(cx); let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.transform_range.start); let language = buffer.language_at(edit_range.start);
let language_name = if let Some(language) = language.as_ref() { let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) { if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None None
@ -2343,9 +2344,9 @@ impl Codegen {
}; };
let language_name = language_name.as_deref(); let language_name = language_name.as_deref();
let start = buffer.point_to_buffer_offset(self.transform_range.start); let start = buffer.point_to_buffer_offset(edit_range.start);
let end = buffer.point_to_buffer_offset(self.transform_range.end); let end = buffer.point_to_buffer_offset(edit_range.end);
let (transform_buffer, transform_range) = if let Some((start, end)) = start.zip(end) { let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start; let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end; let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() { if start_buffer.remote_id() == end_buffer.remote_id() {
@ -2357,39 +2358,9 @@ impl Codegen {
return Err(anyhow::anyhow!("invalid transformation range")); return Err(anyhow::anyhow!("invalid transformation range"));
}; };
let mut transform_context_range = transform_range.to_point(&transform_buffer);
transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3);
transform_context_range.start.column = 0;
transform_context_range.end =
(transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point());
transform_context_range.end.column =
transform_buffer.line_len(transform_context_range.end.row);
let transform_context_range = transform_context_range.to_offset(&transform_buffer);
let selected_ranges = self
.selected_ranges
.iter()
.filter_map(|selected_range| {
let start = buffer
.point_to_buffer_offset(selected_range.start)
.map(|(_, offset)| offset)?;
let end = buffer
.point_to_buffer_offset(selected_range.end)
.map(|(_, offset)| offset)?;
Some(start..end)
})
.collect::<Vec<_>>();
let prompt = self let prompt = self
.prompt_builder .builder
.generate_content_prompt( .generate_content_prompt(user_prompt, language_name, buffer, range)
user_prompt,
language_name,
transform_buffer,
transform_range,
selected_ranges,
transform_context_range,
)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut messages = Vec::new(); let mut messages = Vec::new();
@ -2462,19 +2433,84 @@ impl Codegen {
let mut diff = StreamingDiff::new(selected_text.to_string()); let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default(); let mut line_diff = LineDiff::default();
let mut new_text = String::new();
let mut base_indent = None;
let mut line_indent = None;
let mut first_line = true;
while let Some(chunk) = chunks.next().await { while let Some(chunk) = chunks.next().await {
if response_latency.is_none() { if response_latency.is_none() {
response_latency = Some(request_start.elapsed()); response_latency = Some(request_start.elapsed());
} }
let chunk = chunk?; let chunk = chunk?;
let char_ops = diff.push_new(&chunk);
line_diff.push_char_operations(&char_ops, &selected_text); let mut lines = chunk.split('\n').peekable();
while let Some(line) = lines.next() {
new_text.push_str(line);
if line_indent.is_none() {
if let Some(non_whitespace_ch_ix) =
new_text.find(|ch: char| !ch.is_whitespace())
{
line_indent = Some(non_whitespace_ch_ix);
base_indent = base_indent.or(line_indent);
let line_indent = line_indent.unwrap();
let base_indent = base_indent.unwrap();
let indent_delta =
line_indent as i32 - base_indent as i32;
let mut corrected_indent_len = cmp::max(
0,
suggested_line_indent.len as i32 + indent_delta,
)
as usize;
if first_line {
corrected_indent_len = corrected_indent_len
.saturating_sub(
selection_start.column as usize,
);
}
let indent_char = suggested_line_indent.char();
let mut indent_buffer = [0; 4];
let indent_str =
indent_char.encode_utf8(&mut indent_buffer);
new_text.replace_range(
..line_indent,
&indent_str.repeat(corrected_indent_len),
);
}
}
if line_indent.is_some() {
let char_ops = diff.push_new(&new_text);
line_diff
.push_char_operations(&char_ops, &selected_text);
diff_tx diff_tx
.send((char_ops, line_diff.line_operations())) .send((char_ops, line_diff.line_operations()))
.await?; .await?;
new_text.clear();
} }
let char_ops = diff.finish(); if lines.peek().is_some() {
let char_ops = diff.push_new("\n");
line_diff
.push_char_operations(&char_ops, &selected_text);
diff_tx
.send((char_ops, line_diff.line_operations()))
.await?;
if line_indent.is_none() {
// Don't write out the leading indentation in empty lines on the next line
// This is the case where the above if statement didn't clear the buffer
new_text.clear();
}
line_indent = None;
first_line = false;
}
}
}
let mut char_ops = diff.push_new(&new_text);
char_ops.extend(diff.finish());
line_diff.push_char_operations(&char_ops, &selected_text); line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text); line_diff.finish(&selected_text);
diff_tx diff_tx
@ -2938,13 +2974,311 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests { mod tests {
use super::*; use super::*;
use futures::stream::{self}; use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
use language::{
language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
Point,
};
use language_model::LanguageModelRegistry;
use rand::prelude::*;
use serde::Serialize; use serde::Serialize;
use settings::SettingsStore;
use std::{future, sync::Arc};
#[derive(Serialize)] #[derive(Serialize)]
pub struct DummyCompletionRequest { pub struct DummyCompletionRequest {
pub name: String, pub name: String,
} }
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_model::LanguageModelRegistry::test);
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
let x = 0;
for _ in 0..10 {
x += 1;
}
}
"};
let buffer =
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
range.clone(),
None,
None,
prompt_builder,
cx,
)
});
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
String::new(),
range,
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
let mut new_text = concat!(
" let mut x = 0;\n",
" while x < 10 {\n",
" x += 1;\n",
" }",
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
fn main() {
le
}
"};
let buffer =
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
range.clone(),
None,
None,
prompt_builder,
cx,
)
});
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
cx.background_executor.run_until_parked();
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
cx: &mut TestAppContext,
mut rng: StdRng,
) {
cx.update(LanguageModelRegistry::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = concat!(
"fn main() {\n",
" \n",
"}\n" //
);
let buffer =
cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
range.clone(),
None,
None,
prompt_builder,
cx,
)
});
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
cx.background_executor.run_until_parked();
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
" x += 1;\n",
"}", //
);
while !new_text.is_empty() {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
chunks_tx.unbounded_send(chunk.to_string()).unwrap();
new_text = suffix;
cx.background_executor.run_until_parked();
}
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
fn main() {
let mut x = 0;
while x < 10 {
x += 1;
}
}
"}
);
}
#[gpui::test(iterations = 10)]
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
cx.update(LanguageModelRegistry::test);
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
let text = indoc! {"
func main() {
\tx := 0
\tfor i := 0; i < 10; i++ {
\t\tx++
\t}
}
"};
let buffer = cx.new_model(|cx| Buffer::local(text, cx));
let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let codegen = cx.new_model(|cx| {
Codegen::new(
buffer.clone(),
range.clone(),
None,
None,
prompt_builder,
cx,
)
});
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
String::new(),
range.clone(),
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
cx,
)
});
let new_text = concat!(
"func main() {\n",
"\tx := 0\n",
"\tfor x < 10 {\n",
"\t\tx++\n",
"\t}", //
);
chunks_tx.unbounded_send(new_text.to_string()).unwrap();
drop(chunks_tx);
cx.background_executor.run_until_parked();
assert_eq!(
buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
indoc! {"
func main() {
\tx := 0
\tfor x < 10 {
\t\tx++
\t}
}
"}
);
}
#[gpui::test] #[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() { async fn test_strip_invalid_spans_from_codeblock() {
assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await; assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
@ -2984,4 +3318,27 @@ mod tests {
) )
} }
} }
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::language()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
} }

View file

@ -12,15 +12,11 @@ use util::ResultExt;
pub struct ContentPromptContext { pub struct ContentPromptContext {
pub content_type: String, pub content_type: String,
pub language_name: Option<String>, pub language_name: Option<String>,
pub is_insert: bool,
pub is_truncated: bool, pub is_truncated: bool,
pub document_content: String, pub document_content: String,
pub user_prompt: String, pub user_prompt: String,
pub rewrite_section: String, pub rewrite_section: Option<String>,
pub rewrite_section_prefix: String,
pub rewrite_section_suffix: String,
pub rewrite_section_with_edits: String,
pub has_insertion: bool,
pub has_replacement: bool,
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -46,54 +42,41 @@ pub struct PromptBuilder {
handlebars: Arc<Mutex<Handlebars<'static>>>, handlebars: Arc<Mutex<Handlebars<'static>>>,
} }
pub struct PromptOverrideContext<'a> {
pub dev_mode: bool,
pub fs: Arc<dyn Fs>,
pub cx: &'a mut gpui::AppContext,
}
impl PromptBuilder { impl PromptBuilder {
pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> { pub fn new(
fs_and_cx: Option<(Arc<dyn Fs>, &gpui::AppContext)>,
) -> Result<Self, Box<TemplateError>> {
let mut handlebars = Handlebars::new(); let mut handlebars = Handlebars::new();
Self::register_templates(&mut handlebars)?; Self::register_templates(&mut handlebars)?;
let handlebars = Arc::new(Mutex::new(handlebars)); let handlebars = Arc::new(Mutex::new(handlebars));
if let Some(override_cx) = override_cx { if let Some((fs, cx)) = fs_and_cx {
Self::watch_fs_for_template_overrides(override_cx, handlebars.clone()); Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone());
} }
Ok(Self { handlebars }) Ok(Self { handlebars })
} }
fn watch_fs_for_template_overrides( fn watch_fs_for_template_overrides(
PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext, fs: Arc<dyn Fs>,
cx: &gpui::AppContext,
handlebars: Arc<Mutex<Handlebars<'static>>>, handlebars: Arc<Mutex<Handlebars<'static>>>,
) { ) {
let templates_dir = paths::prompt_overrides_dir();
cx.background_executor() cx.background_executor()
.spawn(async move { .spawn(async move {
let templates_dir = if dev_mode {
std::env::current_dir()
.ok()
.and_then(|pwd| {
let pwd_assets_prompts = pwd.join("assets").join("prompts");
pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
})
.unwrap_or_else(|| paths::prompt_overrides_dir().clone())
} else {
paths::prompt_overrides_dir().clone()
};
// Create the prompt templates directory if it doesn't exist // Create the prompt templates directory if it doesn't exist
if !fs.is_dir(&templates_dir).await { if !fs.is_dir(templates_dir).await {
if let Err(e) = fs.create_dir(&templates_dir).await { if let Err(e) = fs.create_dir(templates_dir).await {
log::error!("Failed to create prompt templates directory: {}", e); log::error!("Failed to create prompt templates directory: {}", e);
return; return;
} }
} }
// Initial scan of the prompts directory // Initial scan of the prompts directory
if let Ok(mut entries) = fs.read_dir(&templates_dir).await { if let Ok(mut entries) = fs.read_dir(templates_dir).await {
while let Some(Ok(file_path)) = entries.next().await { while let Some(Ok(file_path)) = entries.next().await {
if file_path.to_string_lossy().ends_with(".hbs") { if file_path.to_string_lossy().ends_with(".hbs") {
if let Ok(content) = fs.load(&file_path).await { if let Ok(content) = fs.load(&file_path).await {
@ -121,7 +104,7 @@ impl PromptBuilder {
} }
// Watch for changes // Watch for changes
let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await; let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await;
while let Some(changed_paths) = changes.next().await { while let Some(changed_paths) = changes.next().await {
for changed_path in changed_paths { for changed_path in changed_paths {
if changed_path.extension().map_or(false, |ext| ext == "hbs") { if changed_path.extension().map_or(false, |ext| ext == "hbs") {
@ -173,9 +156,7 @@ impl PromptBuilder {
user_prompt: String, user_prompt: String,
language_name: Option<&str>, language_name: Option<&str>,
buffer: BufferSnapshot, buffer: BufferSnapshot,
transform_range: Range<usize>, range: Range<usize>,
selected_ranges: Vec<Range<usize>>,
transform_context_range: Range<usize>,
) -> Result<String, RenderError> { ) -> Result<String, RenderError> {
let content_type = match language_name { let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => "text", None | Some("Markdown" | "Plain Text") => "text",
@ -183,20 +164,21 @@ impl PromptBuilder {
}; };
const MAX_CTX: usize = 50000; const MAX_CTX: usize = 50000;
let is_insert = range.is_empty();
let mut is_truncated = false; let mut is_truncated = false;
let before_range = 0..transform_range.start; let before_range = 0..range.start;
let truncated_before = if before_range.len() > MAX_CTX { let truncated_before = if before_range.len() > MAX_CTX {
is_truncated = true; is_truncated = true;
transform_range.start - MAX_CTX..transform_range.start range.start - MAX_CTX..range.start
} else { } else {
before_range before_range
}; };
let after_range = transform_range.end..buffer.len(); let after_range = range.end..buffer.len();
let truncated_after = if after_range.len() > MAX_CTX { let truncated_after = if after_range.len() > MAX_CTX {
is_truncated = true; is_truncated = true;
transform_range.end..transform_range.end + MAX_CTX range.end..range.end + MAX_CTX
} else { } else {
after_range after_range
}; };
@ -205,74 +187,37 @@ impl PromptBuilder {
for chunk in buffer.text_for_range(truncated_before) { for chunk in buffer.text_for_range(truncated_before) {
document_content.push_str(chunk); document_content.push_str(chunk);
} }
if is_insert {
document_content.push_str("<insert_here></insert_here>");
} else {
document_content.push_str("<rewrite_this>\n"); document_content.push_str("<rewrite_this>\n");
for chunk in buffer.text_for_range(transform_range.clone()) { for chunk in buffer.text_for_range(range.clone()) {
document_content.push_str(chunk); document_content.push_str(chunk);
} }
document_content.push_str("\n</rewrite_this>"); document_content.push_str("\n</rewrite_this>");
}
for chunk in buffer.text_for_range(truncated_after) { for chunk in buffer.text_for_range(truncated_after) {
document_content.push_str(chunk); document_content.push_str(chunk);
} }
let mut rewrite_section = String::new(); let rewrite_section = if !is_insert {
for chunk in buffer.text_for_range(transform_range.clone()) { let mut section = String::new();
rewrite_section.push_str(chunk); for chunk in buffer.text_for_range(range.clone()) {
section.push_str(chunk);
} }
Some(section)
let mut rewrite_section_prefix = String::new();
for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) {
rewrite_section_prefix.push_str(chunk);
}
let mut rewrite_section_suffix = String::new();
for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) {
rewrite_section_suffix.push_str(chunk);
}
let rewrite_section_with_edits = {
let mut section_with_selections = String::new();
let mut last_end = 0;
for selected_range in &selected_ranges {
if selected_range.start > last_end {
section_with_selections.push_str(
&rewrite_section[last_end..selected_range.start - transform_range.start],
);
}
if selected_range.start == selected_range.end {
section_with_selections.push_str("<insert_here></insert_here>");
} else { } else {
section_with_selections.push_str("<edit_here>"); None
section_with_selections.push_str(
&rewrite_section[selected_range.start - transform_range.start
..selected_range.end - transform_range.start],
);
section_with_selections.push_str("</edit_here>");
}
last_end = selected_range.end - transform_range.start;
}
if last_end < rewrite_section.len() {
section_with_selections.push_str(&rewrite_section[last_end..]);
}
section_with_selections
}; };
let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
let context = ContentPromptContext { let context = ContentPromptContext {
content_type: content_type.to_string(), content_type: content_type.to_string(),
language_name: language_name.map(|s| s.to_string()), language_name: language_name.map(|s| s.to_string()),
is_insert,
is_truncated, is_truncated,
document_content, document_content,
user_prompt, user_prompt,
rewrite_section, rewrite_section,
rewrite_section_prefix,
rewrite_section_suffix,
rewrite_section_with_edits,
has_insertion,
has_replacement,
}; };
self.handlebars.lock().render("content_prompt", &context) self.handlebars.lock().render("content_prompt", &context)

View file

@ -187,12 +187,7 @@ fn init_common(app_state: Arc<AppState>, cx: &mut AppContext) -> Arc<PromptBuild
); );
snippet_provider::init(cx); snippet_provider::init(cx);
inline_completion_registry::init(app_state.client.telemetry().clone(), cx); inline_completion_registry::init(app_state.client.telemetry().clone(), cx);
let prompt_builder = assistant::init( let prompt_builder = assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
app_state.fs.clone(),
app_state.client.clone(),
stdout_is_a_pty(),
cx,
);
repl::init( repl::init(
app_state.fs.clone(), app_state.fs.clone(),
app_state.client.telemetry().clone(), app_state.client.telemetry().clone(),

View file

@ -1018,8 +1018,6 @@ fn open_settings_file(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::stdout_is_a_pty;
use super::*; use super::*;
use anyhow::anyhow; use anyhow::anyhow;
use assets::Assets; use assets::Assets;
@ -3487,12 +3485,8 @@ mod tests {
app_state.fs.clone(), app_state.fs.clone(),
cx, cx,
); );
let prompt_builder = assistant::init( let prompt_builder =
app_state.fs.clone(), assistant::init(app_state.fs.clone(), app_state.client.clone(), cx);
app_state.client.clone(),
stdout_is_a_pty(),
cx,
);
repl::init( repl::init(
app_state.fs.clone(), app_state.fs.clone(),
app_state.client.telemetry().clone(), app_state.client.telemetry().clone(),