diff --git a/assets/prompts/content_prompt.hbs b/assets/prompts/content_prompt.hbs index cd618a6761..cf4141349b 100644 --- a/assets/prompts/content_prompt.hbs +++ b/assets/prompts/content_prompt.hbs @@ -1,426 +1,61 @@ -You are an expert developer assistant working in an AI-enabled text editor. -Your task is to rewrite a specific section of the provided document based on a user-provided prompt. +{{#if language_name}} +Here's a file of {{language_name}} that I'm going to ask you to make an edit to. +{{else}} +Here's a file of text that I'm going to ask you to make an edit to. +{{/if}} - -1. Scope: Modify only content within tags. Do not alter anything outside these boundaries. -2. Precision: Make changes strictly necessary to fulfill the given prompt. Preserve all other content as-is. -3. Seamless integration: Ensure rewritten sections flow naturally with surrounding text and maintain document structure. -4. Tag exclusion: Never include , , , or tags in the output. -5. Indentation: Maintain the original indentation level of the file in rewritten sections. -6. Completeness: Rewrite the entire tagged section, even if only partial changes are needed. Avoid omissions or elisions. -7. Insertions: Replace tags with appropriate content as specified by the prompt. -8. Code integrity: Respect existing code structure and functionality when making changes. -9. Consistency: Maintain a uniform style and tone throughout the rewritten text. - +{{#if is_insert}} +The point you'll need to insert at is marked with . +{{else}} +The section you'll need to rewrite is marked with tags. +{{/if}} - - - -use std::cell::Cell; -use std::collections::HashMap; -use std::cmp; - - - - -pub struct LruCache { - /// The maximum number of items the cache can hold. - capacity: usize, - /// The map storing the cached items. - items: HashMap, -} - -// The rest of the implementation... - - -doc this - - - - -/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure. -/// -/// This structure is used for efficient spatial queries and collision detection. -/// It organizes objects in a hierarchical tree structure based on their bounding boxes. -/// -/// # Type Parameters -/// -/// * `T`: The type of data associated with each node in the tree. -pub struct AabbTree { - root: Option, - - -/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure. -/// -/// This structure is used for efficient spatial queries and collision detection. -/// It organizes objects in a hierarchical tree structure based on their bounding boxes. -/// -/// # Type Parameters -/// -/// * `T`: The type of data associated with each node in the tree. - - - - - - -import math - -def calculate_circle_area(radius): - """Calculate the area of a circle given its radius.""" - return math.pi * radius ** 2 - - - - -class Circle: - def __init__(self, radius): - self.radius = radius - - def area(self): - return math.pi * self.radius ** 2 - - def circumference(self): - return 2 * math.pi * self.radius - -# Usage example -circle = Circle(5) -print(f"Area: {circle.area():.2f}") -print(f"Circumference: {circle.circumference():.2f}") - - -write docs - - - - -""" -Represents a circle with methods to calculate its area and circumference. - -This class provides a simple way to work with circles in a geometric context. -It allows for the creation of Circle objects with a specified radius and -offers methods to compute the circle's area and circumference. - -Attributes: - radius (float): The radius of the circle. - -Methods: - area(): Calculates and returns the area of the circle. - circumference(): Calculates and returns the circumference of the circle. -""" -class Circle: - - -""" -Represents a circle with methods to calculate its area and circumference. - -This class provides a simple way to work with circles in a geometric context. -It allows for the creation of Circle objects with a specified radius and -offers methods to compute the circle's area and circumference. - -Attributes: - radius (float): The radius of the circle. - -Methods: - area(): Calculates and returns the area of the circle. - circumference(): Calculates and returns the circumference of the circle. -""" - - - - - - -class BankAccount { - private balance: number; - - constructor(initialBalance: number) { - this.balance = initialBalance; - } - - - - - deposit(amount: number): void { - if (amount > 0) { - this.balance += amount; - } - } - - withdraw(amount: number): boolean { - if (amount > 0 && this.balance >= amount) { - this.balance -= amount; - return true; - } - return false; - } - - getBalance(): number { - return this.balance; - } -} - -// Usage -const account = new BankAccount(1000); -account.deposit(500); -console.log(account.getBalance()); // 1500 -account.withdraw(200); -console.log(account.getBalance()); // 1300 - - -// - - - - - /** - * Deposits the specified amount into the bank account. - * - * @param amount The amount to deposit. Must be a positive number. - * @throws Error if the amount is not positive. - */ - deposit(amount: number): void { - if (amount > 0) { - this.balance += amount; - } else { - throw new Error("Deposit amount must be positive"); - } - } - - - /** - * Deposits the specified amount into the bank account. - * - * @param amount The amount to deposit. Must be a positive number. - * @throws Error if the amount is not positive. - */ - - - - - - -use std::collections::VecDeque; - -pub struct BinaryTree { - root: Option>, -} - - - - -struct Node { - value: T, - left: Option>>, - right: Option>>, -} - - -derive clone - - - - -#[derive(Clone)] - -struct Node { - value: T, - left: Option>>, - right: Option>>, -} - - - -pub struct BinaryTree { - root: Option>, -} - -#[derive(Clone)] - - - -#[derive(Clone)] -struct Node { - value: T, - left: Option>>, - right: Option>>, -} - -impl Node { - fn new(value: T) -> Self { - Node { - value, - left: None, - right: None, - } - } -} - - -#[derive(Clone)] - - - - - - -import math - -def calculate_circle_area(radius): - """Calculate the area of a circle given its radius.""" - return math.pi * radius ** 2 - - - - -class Circle: - def __init__(self, radius): - self.radius = radius - - def area(self): - return math.pi * self.radius ** 2 - - def circumference(self): - return 2 * math.pi * self.radius - -# Usage example -circle = Circle(5) -print(f"Area: {circle.area():.2f}") -print(f"Circumference: {circle.circumference():.2f}") - - -add dataclass decorator - - - - -@dataclass -class Circle: - radius: float - - def __init__(self, radius): - self.radius = radius - - def area(self): - return math.pi * self.radius ** 2 - - -@dataclass - - - - - - -interface ShoppingCart { - items: string[]; - total: number; -} - - -class ShoppingCartManager { - - private cart: ShoppingCart; - - constructor() { - this.cart = { items: [], total: 0 }; - } - - addItem(item: string, price: number): void { - this.cart.items.push(item); - this.cart.total += price; - } - - getTotal(): number { - return this.cart.total; - } -} - -// Usage -const manager = new ShoppingCartManager(); -manager.addItem("Book", 15.99); -console.log(manager.getTotal()); // 15.99 - - -add readonly modifier - - - - -readonly interface ShoppingCart { - items: string[]; - total: number; -} - -class ShoppingCartManager { - private readonly cart: ShoppingCart; - - constructor() { - this.cart = { items: [], total: 0 }; - } - - -readonly interface ShoppingCart { - - - - - -With these examples in mind, edit the following file: - - -{{{ document_content }}} +{{{document_content}}} {{#if is_truncated}} -The provided document has been truncated (potentially mid-line) for brevity. +The context around the relevant section has been truncated (possibly in the middle of a line) for brevity. {{/if}} - -{{#if has_insertion}} -Insert text anywhere you see marked with tags. It's CRITICAL that you DO NOT include tags in your output. -{{/if}} -{{#if has_replacement}} -Edit text that you see surrounded with ... tags. It's CRITICAL that you DO NOT include tags in your output. -{{/if}} -Make no changes to the rewritten content outside these tags. +{{#if is_insert}} +You can't replace {{content_type}}, your answer will be inserted in place of the `` tags. Don't include the insert_here tags in your output. - -{{{ rewrite_section_prefix }}} - -{{{ rewrite_section_with_edits }}} - -{{{ rewrite_section_suffix }}} - - -Rewrite the lines enclosed within the tags in accordance with the provided instructions and the prompt below. +Generate {{content_type}} based on the following prompt: -{{{ user_prompt }}} +{{{user_prompt}}} -Do not include or annotations in your output. Here is a clean copy of the snippet without annotations for your reference. +Match the indentation in the original file in the inserted {{content_type}}, don't include any indentation on blank lines. - -{{{ rewrite_section_prefix }}} -{{{ rewrite_section }}} -{{{ rewrite_section_suffix }}} - - +Immediately start with the following format with no remarks: - -1. Focus on necessary changes: Modify only what's required to fulfill the prompt. -2. Preserve context: Maintain all surrounding content as-is, ensuring the rewritten section seamlessly integrates with the existing document structure and flow. -3. Exclude annotation tags: Do not output , , , or tags. -4. Maintain indentation: Begin at the original file's indentation level. -5. Complete rewrite: Continue until the entire section is rewritten, even if no further changes are needed. -6. Avoid elisions: Always write out the full section without unnecessary omissions. NEVER say `// ...` or `// ...existing code` in your output. -7. Respect content boundaries: Preserve code integrity. - +``` +\{{INSERTED_CODE}} +``` +{{else}} +Edit the section of {{content_type}} in tags based on the following prompt: + + +{{{user_prompt}}} + + +{{#if rewrite_section}} +And here's the section to rewrite based on that prompt again for reference: + + +{{{rewrite_section}}} + +{{/if}} + +Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved. + +Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make, always write out the whole section with no unnecessary elisions. Immediately start with the following format with no remarks: ``` \{{REWRITTEN_CODE}} ``` +{{/if}} diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 085358385f..dbe497bafb 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -34,7 +34,6 @@ use language_model::{ }; pub(crate) use model_selector::*; pub use prompts::PromptBuilder; -use prompts::PromptOverrideContext; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; use settings::{update_settings_file, Settings, SettingsStore}; @@ -181,12 +180,7 @@ impl Assistant { } } -pub fn init( - fs: Arc, - client: Arc, - dev_mode: bool, - cx: &mut AppContext, -) -> Arc { +pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) -> Arc { cx.set_global(Assistant::default()); AssistantSettings::register(cx); SlashCommandSettings::register(cx); @@ -223,14 +217,10 @@ pub fn init( assistant_panel::init(cx); context_servers::init(cx); - let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext { - dev_mode, - fs: fs.clone(), - cx, - })) - .log_err() - .map(Arc::new) - .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap())); + let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx))) + .log_err() + .map(Arc::new) + .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap())); register_slash_commands(Some(prompt_builder.clone()), cx); inline_assistant::init( fs.clone(), diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index 533107c1d5..dbb750f512 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -28,7 +28,7 @@ use gpui::{ FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext, }; -use language::{Buffer, IndentKind, Point, TransactionId}; +use language::{Buffer, IndentKind, Point, Selection, TransactionId}; use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; @@ -38,6 +38,7 @@ use rope::Rope; use settings::Settings; use smol::future::FutureExt; use std::{ + cmp, future::{self, Future}, mem, ops::{Range, RangeInclusive}, @@ -46,7 +47,6 @@ use std::{ task::{self, Poll}, time::{Duration, Instant}, }; -use text::OffsetRangeExt as _; use theme::ThemeSettings; use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip}; use util::{RangeExt, ResultExt}; @@ -140,81 +140,66 @@ impl InlineAssistant { cx: &mut WindowContext, ) { let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); - struct CodegenRange { - transform_range: Range, - selection_ranges: Vec>, - focus_assist: bool, - } - let newest_selection_range = editor.read(cx).selections.newest::(cx).range(); - let mut codegen_ranges: Vec = Vec::new(); - - let selection_ranges = snapshot - .split_ranges(editor.read(cx).selections.disjoint_anchor_ranges()) - .map(|range| range.to_point(&snapshot)) - .collect::>>(); - - for selection_range in selection_ranges { - let selection_is_newest = newest_selection_range.contains_inclusive(&selection_range); - let mut transform_range = selection_range.start..selection_range.end; - - // Expand the transform range to start/end of lines. - // If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line. - transform_range.start.column = 0; - if transform_range.end.column == 0 && transform_range.end > transform_range.start { - transform_range.end.row -= 1; + let mut selections = Vec::>::new(); + let mut newest_selection = None; + for mut selection in editor.read(cx).selections.all::(cx) { + if selection.end > selection.start { + selection.start.column = 0; + // If the selection ends at the start of the line, we don't want to include it. + if selection.end.column == 0 { + selection.end.row -= 1; + } + selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row)); } - transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row)); - let selection_range = - selection_range.start..selection_range.end.min(transform_range.end); - // If we intersect the previous transform range, - if let Some(CodegenRange { - transform_range: prev_transform_range, - selection_ranges, - focus_assist, - }) = codegen_ranges.last_mut() - { - if transform_range.start <= prev_transform_range.end { - prev_transform_range.end = transform_range.end; - selection_ranges.push(selection_range); - *focus_assist |= selection_is_newest; + if let Some(prev_selection) = selections.last_mut() { + if selection.start <= prev_selection.end { + prev_selection.end = selection.end; continue; } } - codegen_ranges.push(CodegenRange { - transform_range, - selection_ranges: vec![selection_range], - focus_assist: selection_is_newest, - }) + let latest_selection = newest_selection.get_or_insert_with(|| selection.clone()); + if selection.id > latest_selection.id { + *latest_selection = selection.clone(); + } + selections.push(selection); + } + let newest_selection = newest_selection.unwrap(); + + let mut codegen_ranges = Vec::new(); + for (excerpt_id, buffer, buffer_range) in + snapshot.excerpts_in_ranges(selections.iter().map(|selection| { + snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end) + })) + { + let start = Anchor { + buffer_id: Some(buffer.remote_id()), + excerpt_id, + text_anchor: buffer.anchor_before(buffer_range.start), + }; + let end = Anchor { + buffer_id: Some(buffer.remote_id()), + excerpt_id, + text_anchor: buffer.anchor_after(buffer_range.end), + }; + codegen_ranges.push(start..end); } let assist_group_id = self.next_assist_group_id.post_inc(); let prompt_buffer = cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx)); let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx)); + let mut assists = Vec::new(); let mut assist_to_focus = None; - - for CodegenRange { - transform_range, - selection_ranges, - focus_assist, - } in codegen_ranges - { - let transform_range = snapshot.anchor_before(transform_range.start) - ..snapshot.anchor_after(transform_range.end); - let selection_ranges = selection_ranges - .iter() - .map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)) - .collect::>(); - + for range in codegen_ranges { + let assist_id = self.next_assist_id.post_inc(); let codegen = cx.new_model(|cx| { Codegen::new( editor.read(cx).buffer().clone(), - transform_range.clone(), - selection_ranges, + range.clone(), None, self.telemetry.clone(), self.prompt_builder.clone(), @@ -222,7 +207,6 @@ impl InlineAssistant { ) }); - let assist_id = self.next_assist_id.post_inc(); let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default())); let prompt_editor = cx.new_view(|cx| { PromptEditor::new( @@ -239,16 +223,23 @@ impl InlineAssistant { ) }); - if focus_assist { - assist_to_focus = Some(assist_id); + if assist_to_focus.is_none() { + let focus_assist = if newest_selection.reversed { + range.start.to_point(&snapshot) == newest_selection.start + } else { + range.end.to_point(&snapshot) == newest_selection.end + }; + if focus_assist { + assist_to_focus = Some(assist_id); + } } let [prompt_block_id, end_block_id] = - self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx); + self.insert_assist_blocks(editor, &range, &prompt_editor, cx); assists.push(( assist_id, - transform_range, + range, prompt_editor, prompt_block_id, end_block_id, @@ -315,7 +306,6 @@ impl InlineAssistant { Codegen::new( editor.read(cx).buffer().clone(), range.clone(), - vec![range.clone()], initial_transaction_id, self.telemetry.clone(), self.prompt_builder.clone(), @@ -925,7 +915,12 @@ impl InlineAssistant { assist .codegen .update(cx, |codegen, cx| { - codegen.start(user_prompt, assistant_panel_context, cx) + codegen.start( + assist.range.clone(), + user_prompt, + assistant_panel_context, + cx, + ) }) .log_err(); @@ -2120,9 +2115,12 @@ impl InlineAssist { return future::ready(Err(anyhow!("no user prompt"))).boxed(); }; let assistant_panel_context = self.assistant_panel_context(cx); - self.codegen - .read(cx) - .count_tokens(user_prompt, assistant_panel_context, cx) + self.codegen.read(cx).count_tokens( + self.range.clone(), + user_prompt, + assistant_panel_context, + cx, + ) } } @@ -2143,8 +2141,6 @@ pub struct Codegen { buffer: Model, old_buffer: Model, snapshot: MultiBufferSnapshot, - transform_range: Range, - selected_ranges: Vec>, edit_position: Option, last_equal_ranges: Vec>, initial_transaction_id: Option, @@ -2154,7 +2150,7 @@ pub struct Codegen { diff: Diff, telemetry: Option>, _subscription: gpui::Subscription, - prompt_builder: Arc, + builder: Arc, } enum CodegenStatus { @@ -2181,8 +2177,7 @@ impl EventEmitter for Codegen {} impl Codegen { pub fn new( buffer: Model, - transform_range: Range, - selected_ranges: Vec>, + range: Range, initial_transaction_id: Option, telemetry: Option>, builder: Arc, @@ -2192,7 +2187,7 @@ impl Codegen { let (old_buffer, _, _) = buffer .read(cx) - .range_to_buffer_ranges(transform_range.clone(), cx) + .range_to_buffer_ranges(range.clone(), cx) .pop() .unwrap(); let old_buffer = cx.new_model(|cx| { @@ -2223,9 +2218,7 @@ impl Codegen { telemetry, _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), initial_transaction_id, - prompt_builder: builder, - transform_range, - selected_ranges, + builder, } } @@ -2250,12 +2243,14 @@ impl Codegen { pub fn count_tokens( &self, + edit_range: Range, user_prompt: String, assistant_panel_context: Option, cx: &AppContext, ) -> BoxFuture<'static, Result> { if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() { - let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx); + let request = + self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx); match request { Ok(request) => { let total_count = model.count_tokens(request.clone(), cx); @@ -2280,6 +2275,7 @@ impl Codegen { pub fn start( &mut self, + edit_range: Range, user_prompt: String, assistant_panel_context: Option, cx: &mut ModelContext, @@ -2294,20 +2290,24 @@ impl Codegen { }); } - self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot)); + self.edit_position = Some(edit_range.start.bias_right(&self.snapshot)); let telemetry_id = model.telemetry_id(); - let chunks: LocalBoxFuture>>> = - if user_prompt.trim().to_lowercase() == "delete" { - async { Ok(stream::empty().boxed()) }.boxed_local() - } else { - let request = self.build_request(user_prompt, assistant_panel_context, cx)?; + let chunks: LocalBoxFuture>>> = if user_prompt + .trim() + .to_lowercase() + == "delete" + { + async { Ok(stream::empty().boxed()) }.boxed_local() + } else { + let request = + self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?; - let chunks = - cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await }); - async move { Ok(chunks.await?.boxed()) }.boxed_local() - }; - self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx); + let chunks = + cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await }); + async move { Ok(chunks.await?.boxed()) }.boxed_local() + }; + self.handle_stream(telemetry_id, edit_range, chunks, cx); Ok(()) } @@ -2315,10 +2315,11 @@ impl Codegen { &self, user_prompt: String, assistant_panel_context: Option, + edit_range: Range, cx: &AppContext, ) -> Result { let buffer = self.buffer.read(cx).snapshot(cx); - let language = buffer.language_at(self.transform_range.start); + let language = buffer.language_at(edit_range.start); let language_name = if let Some(language) = language.as_ref() { if Arc::ptr_eq(language, &language::PLAIN_TEXT) { None @@ -2343,9 +2344,9 @@ impl Codegen { }; let language_name = language_name.as_deref(); - let start = buffer.point_to_buffer_offset(self.transform_range.start); - let end = buffer.point_to_buffer_offset(self.transform_range.end); - let (transform_buffer, transform_range) = if let Some((start, end)) = start.zip(end) { + let start = buffer.point_to_buffer_offset(edit_range.start); + let end = buffer.point_to_buffer_offset(edit_range.end); + let (buffer, range) = if let Some((start, end)) = start.zip(end) { let (start_buffer, start_buffer_offset) = start; let (end_buffer, end_buffer_offset) = end; if start_buffer.remote_id() == end_buffer.remote_id() { @@ -2357,39 +2358,9 @@ impl Codegen { return Err(anyhow::anyhow!("invalid transformation range")); }; - let mut transform_context_range = transform_range.to_point(&transform_buffer); - transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3); - transform_context_range.start.column = 0; - transform_context_range.end = - (transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point()); - transform_context_range.end.column = - transform_buffer.line_len(transform_context_range.end.row); - let transform_context_range = transform_context_range.to_offset(&transform_buffer); - - let selected_ranges = self - .selected_ranges - .iter() - .filter_map(|selected_range| { - let start = buffer - .point_to_buffer_offset(selected_range.start) - .map(|(_, offset)| offset)?; - let end = buffer - .point_to_buffer_offset(selected_range.end) - .map(|(_, offset)| offset)?; - Some(start..end) - }) - .collect::>(); - let prompt = self - .prompt_builder - .generate_content_prompt( - user_prompt, - language_name, - transform_buffer, - transform_range, - selected_ranges, - transform_context_range, - ) + .builder + .generate_content_prompt(user_prompt, language_name, buffer, range) .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?; let mut messages = Vec::new(); @@ -2462,19 +2433,84 @@ impl Codegen { let mut diff = StreamingDiff::new(selected_text.to_string()); let mut line_diff = LineDiff::default(); + let mut new_text = String::new(); + let mut base_indent = None; + let mut line_indent = None; + let mut first_line = true; + while let Some(chunk) = chunks.next().await { if response_latency.is_none() { response_latency = Some(request_start.elapsed()); } let chunk = chunk?; - let char_ops = diff.push_new(&chunk); - line_diff.push_char_operations(&char_ops, &selected_text); - diff_tx - .send((char_ops, line_diff.line_operations())) - .await?; + + let mut lines = chunk.split('\n').peekable(); + while let Some(line) = lines.next() { + new_text.push_str(line); + if line_indent.is_none() { + if let Some(non_whitespace_ch_ix) = + new_text.find(|ch: char| !ch.is_whitespace()) + { + line_indent = Some(non_whitespace_ch_ix); + base_indent = base_indent.or(line_indent); + + let line_indent = line_indent.unwrap(); + let base_indent = base_indent.unwrap(); + let indent_delta = + line_indent as i32 - base_indent as i32; + let mut corrected_indent_len = cmp::max( + 0, + suggested_line_indent.len as i32 + indent_delta, + ) + as usize; + if first_line { + corrected_indent_len = corrected_indent_len + .saturating_sub( + selection_start.column as usize, + ); + } + + let indent_char = suggested_line_indent.char(); + let mut indent_buffer = [0; 4]; + let indent_str = + indent_char.encode_utf8(&mut indent_buffer); + new_text.replace_range( + ..line_indent, + &indent_str.repeat(corrected_indent_len), + ); + } + } + + if line_indent.is_some() { + let char_ops = diff.push_new(&new_text); + line_diff + .push_char_operations(&char_ops, &selected_text); + diff_tx + .send((char_ops, line_diff.line_operations())) + .await?; + new_text.clear(); + } + + if lines.peek().is_some() { + let char_ops = diff.push_new("\n"); + line_diff + .push_char_operations(&char_ops, &selected_text); + diff_tx + .send((char_ops, line_diff.line_operations())) + .await?; + if line_indent.is_none() { + // Don't write out the leading indentation in empty lines on the next line + // This is the case where the above if statement didn't clear the buffer + new_text.clear(); + } + line_indent = None; + first_line = false; + } + } } - let char_ops = diff.finish(); + let mut char_ops = diff.push_new(&new_text); + char_ops.extend(diff.finish()); line_diff.push_char_operations(&char_ops, &selected_text); line_diff.finish(&selected_text); diff_tx @@ -2938,13 +2974,311 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { mod tests { use super::*; use futures::stream::{self}; + use gpui::{Context, TestAppContext}; + use indoc::indoc; + use language::{ + language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher, + Point, + }; + use language_model::LanguageModelRegistry; + use rand::prelude::*; use serde::Serialize; + use settings::SettingsStore; + use std::{future, sync::Arc}; #[derive(Serialize)] pub struct DummyCompletionRequest { pub name: String, } + #[gpui::test(iterations = 10)] + async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_model::LanguageModelRegistry::test); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + let x = 0; + for _ in 0..10 { + x += 1; + } + } + "}; + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + Codegen::new( + buffer.clone(), + range.clone(), + None, + None, + prompt_builder, + cx, + ) + }); + + let (chunks_tx, chunks_rx) = mpsc::unbounded(); + codegen.update(cx, |codegen, cx| { + codegen.handle_stream( + String::new(), + range, + future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), + cx, + ) + }); + + let mut new_text = concat!( + " let mut x = 0;\n", + " while x < 10 {\n", + " x += 1;\n", + " }", + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + chunks_tx.unbounded_send(chunk.to_string()).unwrap(); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + drop(chunks_tx); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_past_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + ) { + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + fn main() { + le + } + "}; + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + Codegen::new( + buffer.clone(), + range.clone(), + None, + None, + prompt_builder, + cx, + ) + }); + + let (chunks_tx, chunks_rx) = mpsc::unbounded(); + codegen.update(cx, |codegen, cx| { + codegen.handle_stream( + String::new(), + range.clone(), + future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), + cx, + ) + }); + + cx.background_executor.run_until_parked(); + + let mut new_text = concat!( + "t mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + chunks_tx.unbounded_send(chunk.to_string()).unwrap(); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + drop(chunks_tx); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_when_generating_before_indentation( + cx: &mut TestAppContext, + mut rng: StdRng, + ) { + cx.update(LanguageModelRegistry::test); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = concat!( + "fn main() {\n", + " \n", + "}\n" // + ); + let buffer = + cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + Codegen::new( + buffer.clone(), + range.clone(), + None, + None, + prompt_builder, + cx, + ) + }); + + let (chunks_tx, chunks_rx) = mpsc::unbounded(); + codegen.update(cx, |codegen, cx| { + codegen.handle_stream( + String::new(), + range.clone(), + future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), + cx, + ) + }); + + cx.background_executor.run_until_parked(); + + let mut new_text = concat!( + "let mut x = 0;\n", + "while x < 10 {\n", + " x += 1;\n", + "}", // + ); + while !new_text.is_empty() { + let max_len = cmp::min(new_text.len(), 10); + let len = rng.gen_range(1..=max_len); + let (chunk, suffix) = new_text.split_at(len); + chunks_tx.unbounded_send(chunk.to_string()).unwrap(); + new_text = suffix; + cx.background_executor.run_until_parked(); + } + drop(chunks_tx); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + fn main() { + let mut x = 0; + while x < 10 { + x += 1; + } + } + "} + ); + } + + #[gpui::test(iterations = 10)] + async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) { + cx.update(LanguageModelRegistry::test); + cx.set_global(cx.update(SettingsStore::test)); + cx.update(language_settings::init); + + let text = indoc! {" + func main() { + \tx := 0 + \tfor i := 0; i < 10; i++ { + \t\tx++ + \t} + } + "}; + let buffer = cx.new_model(|cx| Buffer::local(text, cx)); + let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx)); + let range = buffer.read_with(cx, |buffer, cx| { + let snapshot = buffer.snapshot(cx); + snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2)) + }); + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let codegen = cx.new_model(|cx| { + Codegen::new( + buffer.clone(), + range.clone(), + None, + None, + prompt_builder, + cx, + ) + }); + + let (chunks_tx, chunks_rx) = mpsc::unbounded(); + codegen.update(cx, |codegen, cx| { + codegen.handle_stream( + String::new(), + range.clone(), + future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), + cx, + ) + }); + + let new_text = concat!( + "func main() {\n", + "\tx := 0\n", + "\tfor x < 10 {\n", + "\t\tx++\n", + "\t}", // + ); + chunks_tx.unbounded_send(new_text.to_string()).unwrap(); + drop(chunks_tx); + cx.background_executor.run_until_parked(); + + assert_eq!( + buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()), + indoc! {" + func main() { + \tx := 0 + \tfor x < 10 { + \t\tx++ + \t} + } + "} + ); + } + #[gpui::test] async fn test_strip_invalid_spans_from_codeblock() { assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await; @@ -2984,4 +3318,27 @@ mod tests { ) } } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ) + .with_indents_query( + r#" + (call_expression) @indent + (field_expression) @indent + (_ "(" ")" @end) @indent + (_ "{" "}" @end) @indent + "#, + ) + .unwrap() + } } diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index ed3324f54f..f8e645ef46 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -12,15 +12,11 @@ use util::ResultExt; pub struct ContentPromptContext { pub content_type: String, pub language_name: Option, + pub is_insert: bool, pub is_truncated: bool, pub document_content: String, pub user_prompt: String, - pub rewrite_section: String, - pub rewrite_section_prefix: String, - pub rewrite_section_suffix: String, - pub rewrite_section_with_edits: String, - pub has_insertion: bool, - pub has_replacement: bool, + pub rewrite_section: Option, } #[derive(Serialize)] @@ -46,54 +42,41 @@ pub struct PromptBuilder { handlebars: Arc>>, } -pub struct PromptOverrideContext<'a> { - pub dev_mode: bool, - pub fs: Arc, - pub cx: &'a mut gpui::AppContext, -} - impl PromptBuilder { - pub fn new(override_cx: Option) -> Result> { + pub fn new( + fs_and_cx: Option<(Arc, &gpui::AppContext)>, + ) -> Result> { let mut handlebars = Handlebars::new(); Self::register_templates(&mut handlebars)?; let handlebars = Arc::new(Mutex::new(handlebars)); - if let Some(override_cx) = override_cx { - Self::watch_fs_for_template_overrides(override_cx, handlebars.clone()); + if let Some((fs, cx)) = fs_and_cx { + Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone()); } Ok(Self { handlebars }) } fn watch_fs_for_template_overrides( - PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext, + fs: Arc, + cx: &gpui::AppContext, handlebars: Arc>>, ) { + let templates_dir = paths::prompt_overrides_dir(); + cx.background_executor() .spawn(async move { - let templates_dir = if dev_mode { - std::env::current_dir() - .ok() - .and_then(|pwd| { - let pwd_assets_prompts = pwd.join("assets").join("prompts"); - pwd_assets_prompts.exists().then_some(pwd_assets_prompts) - }) - .unwrap_or_else(|| paths::prompt_overrides_dir().clone()) - } else { - paths::prompt_overrides_dir().clone() - }; - // Create the prompt templates directory if it doesn't exist - if !fs.is_dir(&templates_dir).await { - if let Err(e) = fs.create_dir(&templates_dir).await { + if !fs.is_dir(templates_dir).await { + if let Err(e) = fs.create_dir(templates_dir).await { log::error!("Failed to create prompt templates directory: {}", e); return; } } // Initial scan of the prompts directory - if let Ok(mut entries) = fs.read_dir(&templates_dir).await { + if let Ok(mut entries) = fs.read_dir(templates_dir).await { while let Some(Ok(file_path)) = entries.next().await { if file_path.to_string_lossy().ends_with(".hbs") { if let Ok(content) = fs.load(&file_path).await { @@ -121,7 +104,7 @@ impl PromptBuilder { } // Watch for changes - let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await; + let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await; while let Some(changed_paths) = changes.next().await { for changed_path in changed_paths { if changed_path.extension().map_or(false, |ext| ext == "hbs") { @@ -173,9 +156,7 @@ impl PromptBuilder { user_prompt: String, language_name: Option<&str>, buffer: BufferSnapshot, - transform_range: Range, - selected_ranges: Vec>, - transform_context_range: Range, + range: Range, ) -> Result { let content_type = match language_name { None | Some("Markdown" | "Plain Text") => "text", @@ -183,20 +164,21 @@ impl PromptBuilder { }; const MAX_CTX: usize = 50000; + let is_insert = range.is_empty(); let mut is_truncated = false; - let before_range = 0..transform_range.start; + let before_range = 0..range.start; let truncated_before = if before_range.len() > MAX_CTX { is_truncated = true; - transform_range.start - MAX_CTX..transform_range.start + range.start - MAX_CTX..range.start } else { before_range }; - let after_range = transform_range.end..buffer.len(); + let after_range = range.end..buffer.len(); let truncated_after = if after_range.len() > MAX_CTX { is_truncated = true; - transform_range.end..transform_range.end + MAX_CTX + range.end..range.end + MAX_CTX } else { after_range }; @@ -205,74 +187,37 @@ impl PromptBuilder { for chunk in buffer.text_for_range(truncated_before) { document_content.push_str(chunk); } - - document_content.push_str("\n"); - for chunk in buffer.text_for_range(transform_range.clone()) { - document_content.push_str(chunk); + if is_insert { + document_content.push_str(""); + } else { + document_content.push_str("\n"); + for chunk in buffer.text_for_range(range.clone()) { + document_content.push_str(chunk); + } + document_content.push_str("\n"); } - document_content.push_str("\n"); - for chunk in buffer.text_for_range(truncated_after) { document_content.push_str(chunk); } - let mut rewrite_section = String::new(); - for chunk in buffer.text_for_range(transform_range.clone()) { - rewrite_section.push_str(chunk); - } - - let mut rewrite_section_prefix = String::new(); - for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) { - rewrite_section_prefix.push_str(chunk); - } - - let mut rewrite_section_suffix = String::new(); - for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) { - rewrite_section_suffix.push_str(chunk); - } - - let rewrite_section_with_edits = { - let mut section_with_selections = String::new(); - let mut last_end = 0; - for selected_range in &selected_ranges { - if selected_range.start > last_end { - section_with_selections.push_str( - &rewrite_section[last_end..selected_range.start - transform_range.start], - ); - } - if selected_range.start == selected_range.end { - section_with_selections.push_str(""); - } else { - section_with_selections.push_str(""); - section_with_selections.push_str( - &rewrite_section[selected_range.start - transform_range.start - ..selected_range.end - transform_range.start], - ); - section_with_selections.push_str(""); - } - last_end = selected_range.end - transform_range.start; + let rewrite_section = if !is_insert { + let mut section = String::new(); + for chunk in buffer.text_for_range(range.clone()) { + section.push_str(chunk); } - if last_end < rewrite_section.len() { - section_with_selections.push_str(&rewrite_section[last_end..]); - } - section_with_selections + Some(section) + } else { + None }; - let has_insertion = selected_ranges.iter().any(|range| range.start == range.end); - let has_replacement = selected_ranges.iter().any(|range| range.start != range.end); - let context = ContentPromptContext { content_type: content_type.to_string(), language_name: language_name.map(|s| s.to_string()), + is_insert, is_truncated, document_content, user_prompt, rewrite_section, - rewrite_section_prefix, - rewrite_section_suffix, - rewrite_section_with_edits, - has_insertion, - has_replacement, }; self.handlebars.lock().render("content_prompt", &context) diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 173387b840..2eb5b3fc05 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -187,12 +187,7 @@ fn init_common(app_state: Arc, cx: &mut AppContext) -> Arc