From c64b26110cb57f4a68b1a58b2a28c8d913ed2776 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 4 Feb 2025 19:32:17 +0100 Subject: [PATCH] Revert "edit prediction: Try to expand context to parent treesitter region" (#24214) Reverts zed-industries/zed#24186 --- crates/rpc/src/llm.rs | 1 - crates/zeta/src/input_excerpt.rs | 238 ---------------------------- crates/zeta/src/zeta.rs | 257 ++++++++++++++++++++++++------- 3 files changed, 203 insertions(+), 293 deletions(-) delete mode 100644 crates/zeta/src/input_excerpt.rs diff --git a/crates/rpc/src/llm.rs b/crates/rpc/src/llm.rs index 92cd6dccca..93ac5bdee8 100644 --- a/crates/rpc/src/llm.rs +++ b/crates/rpc/src/llm.rs @@ -39,7 +39,6 @@ pub struct PredictEditsParams { pub outline: Option, pub input_events: String, pub input_excerpt: String, - pub speculated_output: String, /// Whether the user provided consent for sampling this interaction. #[serde(default)] pub data_collection_permission: bool, diff --git a/crates/zeta/src/input_excerpt.rs b/crates/zeta/src/input_excerpt.rs deleted file mode 100644 index 103f03750a..0000000000 --- a/crates/zeta/src/input_excerpt.rs +++ /dev/null @@ -1,238 +0,0 @@ -use crate::{ - BYTES_PER_TOKEN_GUESS, CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, - START_OF_FILE_MARKER, -}; -use language::{BufferSnapshot, Point}; -use std::{fmt::Write, ops::Range}; - -pub struct InputExcerpt { - pub editable_range: Range, - pub prompt: String, - pub speculated_output: String, -} - -pub fn excerpt_for_cursor_position( - position: Point, - path: &str, - snapshot: &BufferSnapshot, - editable_region_token_limit: usize, - context_token_limit: usize, -) -> InputExcerpt { - let mut scope_range = position..position; - let mut remaining_edit_tokens = editable_region_token_limit; - - while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { - let parent_tokens = tokens_for_bytes(parent.byte_range().len()); - if parent_tokens <= editable_region_token_limit { - scope_range = Point::new( - parent.start_position().row as u32, - parent.start_position().column as u32, - ) - ..Point::new( - parent.end_position().row as u32, - parent.end_position().column as u32, - ); - remaining_edit_tokens = editable_region_token_limit - parent_tokens; - } else { - break; - } - } - - let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens); - let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); - - let mut prompt = String::new(); - let mut speculated_output = String::new(); - - writeln!(&mut prompt, "```{path}").unwrap(); - if context_range.start == Point::zero() { - writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap(); - } - - for chunk in snapshot.chunks(context_range.start..editable_range.start, false) { - prompt.push_str(chunk.text); - } - - push_editable_range(position, snapshot, editable_range.clone(), &mut prompt); - push_editable_range( - position, - snapshot, - editable_range.clone(), - &mut speculated_output, - ); - - for chunk in snapshot.chunks(editable_range.end..context_range.end, false) { - prompt.push_str(chunk.text); - } - write!(prompt, "\n```").unwrap(); - - InputExcerpt { - editable_range, - prompt, - speculated_output, - } -} - -fn push_editable_range( - cursor_position: Point, - snapshot: &BufferSnapshot, - editable_range: Range, - prompt: &mut String, -) { - writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap(); - for chunk in snapshot.chunks(editable_range.start..cursor_position, false) { - prompt.push_str(chunk.text); - } - prompt.push_str(CURSOR_MARKER); - for chunk in snapshot.chunks(cursor_position..editable_range.end, false) { - prompt.push_str(chunk.text); - } - write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); -} - -fn expand_range( - snapshot: &BufferSnapshot, - range: Range, - mut remaining_tokens: usize, -) -> Range { - let mut expanded_range = range.clone(); - expanded_range.start.column = 0; - expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - loop { - let mut expanded = false; - - if remaining_tokens > 0 && expanded_range.start.row > 0 { - expanded_range.start.row -= 1; - let line_tokens = - tokens_for_bytes(snapshot.line_len(expanded_range.start.row) as usize); - remaining_tokens = remaining_tokens.saturating_sub(line_tokens); - expanded = true; - } - - if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { - expanded_range.end.row += 1; - expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - let line_tokens = tokens_for_bytes(expanded_range.end.column as usize); - remaining_tokens = remaining_tokens.saturating_sub(line_tokens); - expanded = true; - } - - if !expanded { - break; - } - } - expanded_range -} - -fn tokens_for_bytes(bytes: usize) -> usize { - bytes / BYTES_PER_TOKEN_GUESS -} - -#[cfg(test)] -mod tests { - use super::*; - use gpui::{App, AppContext}; - use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher}; - use std::sync::Arc; - - #[gpui::test] - fn test_excerpt_for_cursor_position(cx: &mut App) { - let text = indoc! {r#" - fn foo() { - let x = 42; - println!("Hello, world!"); - } - - fn bar() { - let x = 42; - let mut sum = 0; - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - return sum; - } - - fn generate_random_numbers() -> Vec { - let mut rng = rand::thread_rng(); - let mut numbers = Vec::new(); - for _ in 0..5 { - numbers.push(rng.gen_range(1..101)); - } - numbers - } - "#}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let snapshot = buffer.read(cx).snapshot(); - - // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion - // when a larger scope doesn't fit the editable region. - let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32); - assert_eq!( - excerpt.prompt, - indoc! {r#" - ```main.rs - let x = 42; - println!("Hello, world!"); - <|editable_region_start|> - } - - fn bar() { - let x = 42; - let mut sum = 0; - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - r<|user_cursor_is_here|>eturn sum; - } - - fn generate_random_numbers() -> Vec { - <|editable_region_end|> - let mut rng = rand::thread_rng(); - let mut numbers = Vec::new(); - ```"#} - ); - - // The `bar` function won't fit within the editable region, so we resort to line-based expansion. - let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32); - assert_eq!( - excerpt.prompt, - indoc! {r#" - ```main.rs - fn bar() { - let x = 42; - let mut sum = 0; - <|editable_region_start|> - for i in 0..x { - sum += i; - } - println!("Sum: {}", sum); - r<|user_cursor_is_here|>eturn sum; - } - - fn generate_random_numbers() -> Vec { - let mut rng = rand::thread_rng(); - <|editable_region_end|> - let mut numbers = Vec::new(); - for _ in 0..5 { - numbers.push(rng.gen_range(1..101)); - ```"#} - ); - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - } -} diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 135a38297d..7ed79caf6d 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1,6 +1,5 @@ mod completion_diff_element; mod init; -mod input_excerpt; mod license_detection; mod onboarding_banner; mod onboarding_modal; @@ -26,7 +25,7 @@ use gpui::{ use http_client::{HttpClient, Method}; use language::{ language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, EditPreview, - OffsetRangeExt, ToOffset, ToPoint, + OffsetRangeExt, Point, ToOffset, ToPoint, }; use language_models::LlmApiToken; use postage::watch; @@ -62,26 +61,26 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch /// intentionally low to err on the side of underestimating limits. const BYTES_PER_TOKEN_GUESS: usize = 3; -/// Input token limit, used to inform the size of the input. A copy of this constant is also in +/// Output token limit, used to inform the size of the input. A copy of this constant is also in /// `crates/collab/src/llm.rs`. -const MAX_INPUT_TOKENS: usize = 2048; - -const MAX_CONTEXT_TOKENS: usize = 64; -const MAX_OUTPUT_TOKENS: usize = 256; +const MAX_OUTPUT_TOKENS: usize = 2048; /// Total bytes limit for editable region of buffer excerpt. /// /// The number of output tokens is relevant to the size of the input excerpt because the model is /// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens /// remaining for the model to specify insertions. -const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_INPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS; +const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS; + +/// Total line limit for editable region of buffer excerpt. +const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64; /// Note that this is not the limit for the overall prompt, just for the inputs to the template /// instantiated in `crates/collab/src/llm.rs`. const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2; /// Maximum number of events to include in the prompt. -const MAX_EVENT_COUNT: usize = 8; +const MAX_EVENT_COUNT: usize = 16; /// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size of /// equally splitting up the the remaining bytes after the largest possible buffer excerpt. @@ -374,8 +373,8 @@ impl Zeta { R: Future> + Send + 'static, { let snapshot = self.report_changes_for_buffer(&buffer, cx); - let cursor_position = cursor.to_point(&snapshot); - let cursor_offset = cursor_position.to_offset(&snapshot); + let cursor_point = cursor.to_point(&snapshot); + let cursor_offset = cursor_point.to_offset(&snapshot); let events = self.events.clone(); let path: Arc = snapshot .file() @@ -390,47 +389,45 @@ impl Zeta { cx.spawn(|_, cx| async move { let request_sent_at = Instant::now(); - let (input_events, input_excerpt, editable_range, input_outline, speculated_output) = - cx.background_executor() - .spawn({ - let snapshot = snapshot.clone(); - let path = path.clone(); - async move { - let path = path.to_string_lossy(); - let input_excerpt = input_excerpt::excerpt_for_cursor_position( - cursor_position, - &path, - &snapshot, - MAX_OUTPUT_TOKENS, - MAX_CONTEXT_TOKENS, - ); + let (input_events, input_excerpt, excerpt_range, input_outline) = cx + .background_executor() + .spawn({ + let snapshot = snapshot.clone(); + let path = path.clone(); + async move { + let path = path.to_string_lossy(); + let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position( + cursor_point, + BUFFER_EXCERPT_BYTE_LIMIT, + BUFFER_EXCERPT_LINE_LIMIT, + &path, + &snapshot, + )?; + let input_excerpt = prompt_for_excerpt( + cursor_offset, + &excerpt_range, + excerpt_len_guess, + &path, + &snapshot, + ); - let bytes_remaining = - TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.prompt.len()); - let input_events = prompt_for_events(events.iter(), bytes_remaining); + let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len()); + let input_events = prompt_for_events(events.iter(), bytes_remaining); - // Note that input_outline is not currently used in prompt generation and so - // is not counted towards TOTAL_BYTE_LIMIT. - let input_outline = prompt_for_outline(&snapshot); + // Note that input_outline is not currently used in prompt generation and so + // is not counted towards TOTAL_BYTE_LIMIT. + let input_outline = prompt_for_outline(&snapshot); - let editable_range = input_excerpt.editable_range.to_offset(&snapshot); - anyhow::Ok(( - input_events, - input_excerpt.prompt, - editable_range, - input_outline, - input_excerpt.speculated_output, - )) - } - }) - .await?; + anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline)) + } + }) + .await?; log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt); let body = PredictEditsParams { input_events: input_events.clone(), input_excerpt: input_excerpt.clone(), - speculated_output, outline: Some(input_outline.clone()), data_collection_permission, }; @@ -444,7 +441,7 @@ impl Zeta { output_excerpt, buffer, &snapshot, - editable_range, + excerpt_range, cursor_offset, path, input_outline, @@ -460,8 +457,6 @@ impl Zeta { // Generates several example completions of various states to fill the Zeta completion modal #[cfg(any(test, feature = "test-support"))] pub fn fill_with_fake_completions(&mut self, cx: &mut Context) -> Task<()> { - use language::Point; - let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line And maybe a short line @@ -680,7 +675,7 @@ and then another output_excerpt: String, buffer: Entity, snapshot: &BufferSnapshot, - editable_range: Range, + excerpt_range: Range, cursor_offset: usize, path: Arc, input_outline: String, @@ -697,9 +692,9 @@ and then another .background_executor() .spawn({ let output_excerpt = output_excerpt.clone(); - let editable_range = editable_range.clone(); + let excerpt_range = excerpt_range.clone(); let snapshot = snapshot.clone(); - async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) } + async move { Self::parse_edits(output_excerpt, excerpt_range, &snapshot) } }) .await? .into(); @@ -722,7 +717,7 @@ and then another Ok(Some(InlineCompletion { id: InlineCompletionId::new(), path, - excerpt_range: editable_range, + excerpt_range, cursor_offset, edits, edit_preview, @@ -739,7 +734,7 @@ and then another fn parse_edits( output_excerpt: Arc, - editable_range: Range, + excerpt_range: Range, snapshot: &BufferSnapshot, ) -> Result, String)>> { let content = output_excerpt.replace(CURSOR_MARKER, ""); @@ -783,13 +778,13 @@ and then another let new_text = &content[..codefence_end]; let old_text = snapshot - .text_for_range(editable_range.clone()) + .text_for_range(excerpt_range.clone()) .collect::(); Ok(Self::compute_edits( old_text, new_text, - editable_range.start, + excerpt_range.start, &snapshot, )) } @@ -1016,6 +1011,161 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { input_outline } +fn prompt_for_excerpt( + offset: usize, + excerpt_range: &Range, + mut len_guess: usize, + path: &str, + snapshot: &BufferSnapshot, +) -> String { + let point_range = excerpt_range.to_point(snapshot); + + // Include one line of extra context before and after editable range, if those lines are non-empty. + let extra_context_before_range = + if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) { + let range = + (Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot); + len_guess += range.end - range.start; + Some(range) + } else { + None + }; + let extra_context_after_range = if point_range.end.row < snapshot.max_point().row + && !snapshot.is_line_blank(point_range.end.row + 1) + { + let range = (point_range.end + ..Point::new( + point_range.end.row + 1, + snapshot.line_len(point_range.end.row + 1), + )) + .to_offset(snapshot); + len_guess += range.end - range.start; + Some(range) + } else { + None + }; + + let mut prompt_excerpt = String::with_capacity(len_guess); + writeln!(prompt_excerpt, "```{}", path).unwrap(); + + if excerpt_range.start == 0 { + writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap(); + } + + if let Some(extra_context_before_range) = extra_context_before_range { + for chunk in snapshot.text_for_range(extra_context_before_range) { + prompt_excerpt.push_str(chunk); + } + } + writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap(); + for chunk in snapshot.text_for_range(excerpt_range.start..offset) { + prompt_excerpt.push_str(chunk); + } + prompt_excerpt.push_str(CURSOR_MARKER); + for chunk in snapshot.text_for_range(offset..excerpt_range.end) { + prompt_excerpt.push_str(chunk); + } + write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); + + if let Some(extra_context_after_range) = extra_context_after_range { + for chunk in snapshot.text_for_range(extra_context_after_range) { + prompt_excerpt.push_str(chunk); + } + } + + write!(prompt_excerpt, "\n```").unwrap(); + debug_assert!( + prompt_excerpt.len() <= len_guess, + "Excerpt length {} exceeds estimated length {}", + prompt_excerpt.len(), + len_guess + ); + prompt_excerpt +} + +fn excerpt_range_for_position( + cursor_point: Point, + byte_limit: usize, + line_limit: u32, + path: &str, + snapshot: &BufferSnapshot, +) -> Result<(Range, usize)> { + let cursor_row = cursor_point.row; + let last_buffer_row = snapshot.max_point().row; + + // This is an overestimate because it includes parts of prompt_for_excerpt which are + // conditionally skipped. + let mut len_guess = 0; + len_guess += "```".len() + path.len() + 1; + len_guess += START_OF_FILE_MARKER.len() + 1; + len_guess += EDITABLE_REGION_START_MARKER.len() + 1; + len_guess += CURSOR_MARKER.len(); + len_guess += EDITABLE_REGION_END_MARKER.len() + 1; + len_guess += "```".len() + 1; + + len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap(); + + if len_guess > byte_limit { + return Err(anyhow!("Current line too long to send to model.")); + } + + let mut excerpt_start_row = cursor_row; + let mut excerpt_end_row = cursor_row; + let mut no_more_before = cursor_row == 0; + let mut no_more_after = cursor_row >= last_buffer_row; + let mut row_delta = 1; + loop { + if !no_more_before { + let row = cursor_point.row - row_delta; + let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap(); + let mut new_len_guess = len_guess + line_len; + if row == 0 { + new_len_guess += START_OF_FILE_MARKER.len() + 1; + } + if new_len_guess <= byte_limit { + len_guess = new_len_guess; + excerpt_start_row = row; + if row == 0 { + no_more_before = true; + } + } else { + no_more_before = true; + } + } + if excerpt_end_row - excerpt_start_row >= line_limit { + break; + } + if !no_more_after { + let row = cursor_point.row + row_delta; + let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap(); + let new_len_guess = len_guess + line_len; + if new_len_guess <= byte_limit { + len_guess = new_len_guess; + excerpt_end_row = row; + if row >= last_buffer_row { + no_more_after = true; + } + } else { + no_more_after = true; + } + } + if excerpt_end_row - excerpt_start_row >= line_limit { + break; + } + if no_more_before && no_more_after { + break; + } + row_delta += 1; + } + + let excerpt_start = Point::new(excerpt_start_row, 0); + let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row)); + Ok(( + excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot), + len_guess, + )) +} + fn prompt_for_events<'a>( events: impl Iterator, mut bytes_remaining: usize, @@ -1527,7 +1677,6 @@ mod tests { use gpui::TestAppContext; use http_client::FakeHttpClient; use indoc::indoc; - use language::Point; use language_models::RefreshLlmTokenListener; use rpc::proto; use settings::SettingsStore;