From 87b0f62041a9bdcd22596002972c36b174bc9f60 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Thu, 30 Jan 2025 15:27:42 -0700 Subject: [PATCH] Implement simpler logic for edit predictions prompt byte limits (#23983) Realized that the logic in #23814 was more than needed, and harder to maintain. Something like that could make sense if using the tokenizer and wanting to precisely hit a token limit. However in the case of edit predictions it's more of a latency+expense vs capability tradeoff, and so such precision is unnecessary. Happily this change didn't require much extra work, just copy-modifying parts of that change was sufficient. Release Notes: - N/A --- crates/collab/src/llm.rs | 9 +- crates/zeta/src/zeta.rs | 387 +++++++++++++++------------------------ 2 files changed, 156 insertions(+), 240 deletions(-) diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index b5ee3713fc..6e0ca40d09 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -42,6 +42,11 @@ use util::ResultExt; pub use token::*; +const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30); + +/// Output token limit. A copy of this constant is also in `crates/zeta/src/zeta.rs`. +const MAX_OUTPUT_TOKENS: u32 = 2048; + pub struct LlmState { pub config: Config, pub executor: Executor, @@ -52,8 +57,6 @@ pub struct LlmState { RwLock, ActiveUserCount)>>, } -const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30); - impl LlmState { pub async fn new(config: Config, executor: Executor) -> Result> { let database_url = config @@ -488,7 +491,7 @@ async fn predict_edits( fireworks::CompletionRequest { model: model.to_string(), prompt: prompt.clone(), - max_tokens: 2048, + max_tokens: MAX_OUTPUT_TOKENS, temperature: 0., prediction: Some(fireworks::Prediction::Content { content: params.input_excerpt.clone(), diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 0fe08cd107..ad53dcc757 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -58,11 +58,19 @@ const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &'static str = /// intentionally low to err on the side of underestimating limits. const BYTES_PER_TOKEN_GUESS: usize = 3; -/// This is based on the output token limit `max_tokens: 2048` in `crates/collab/src/llm.rs`. Number -/// of output tokens is relevant to the size of the input excerpt because the model is tasked with -/// outputting a modified excerpt. `2/3` is chosen so that there are some output tokens remaining -/// for the model to specify insertions. -const BUFFER_EXCERPT_BYTE_LIMIT: usize = (2048 * 2 / 3) * BYTES_PER_TOKEN_GUESS; +/// Output token limit, used to inform the size of the input. A copy of this constant is also in +/// `crates/collab/src/llm.rs`. +const MAX_OUTPUT_TOKENS: usize = 2048; + +/// Total bytes limit for editable region of buffer excerpt. +/// +/// The number of output tokens is relevant to the size of the input excerpt because the model is +/// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens +/// remaining for the model to specify insertions. +const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS; + +/// Total line limit for editable region of buffer excerpt. +const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64; /// Note that this is not the limit for the overall prompt, just for the inputs to the template /// instantiated in `crates/collab/src/llm.rs`. @@ -342,12 +350,11 @@ impl Zeta { F: FnOnce(Arc, LlmApiToken, bool, PredictEditsParams) -> R + 'static, R: Future> + Send + 'static, { - let buffer = buffer.clone(); let snapshot = self.report_changes_for_buffer(&buffer, cx); let cursor_point = cursor.to_point(&snapshot); let cursor_offset = cursor_point.to_offset(&snapshot); let events = self.events.clone(); - let path = snapshot + let path: Arc = snapshot .file() .map(|f| Arc::from(f.full_path(cx).as_path())) .unwrap_or_else(|| Arc::from(Path::new("untitled"))); @@ -356,25 +363,40 @@ impl Zeta { let llm_token = self.llm_token.clone(); let is_staff = cx.is_staff(); + let buffer = buffer.clone(); cx.spawn(|_, cx| async move { let request_sent_at = Instant::now(); - let (input_events, input_excerpt, input_outline, excerpt_range) = cx + let (input_events, input_excerpt, excerpt_range, input_outline) = cx .background_executor() .spawn({ let snapshot = snapshot.clone(); + let path = path.clone(); async move { - let (input_excerpt, excerpt_range) = - prompt_for_excerpt(&snapshot, cursor_point, cursor_offset)?; + let path = path.to_string_lossy(); + let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position( + cursor_point, + BUFFER_EXCERPT_BYTE_LIMIT, + BUFFER_EXCERPT_LINE_LIMIT, + &path, + &snapshot, + )?; + let input_excerpt = prompt_for_excerpt( + cursor_offset, + &excerpt_range, + excerpt_len_guess, + &path, + &snapshot, + ); - let chars_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len()); - let input_events = prompt_for_events(events.iter(), chars_remaining); + let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len()); + let input_events = prompt_for_events(events.iter(), bytes_remaining); // Note that input_outline is not currently used in prompt generation and so // is not counted towards TOTAL_BYTE_LIMIT. let input_outline = prompt_for_outline(&snapshot); - anyhow::Ok((input_events, input_excerpt, input_outline, excerpt_range)) + anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline)) } }) .await?; @@ -998,201 +1020,137 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String { input_outline } -#[derive(Debug, Default)] -struct ExcerptPromptBuilder<'a> { - file_path: Cow<'a, str>, - include_start_of_file_marker: bool, - before_editable_region: Option>, - before_cursor: ReversedStringChunks<'a>, - after_cursor: StringChunks<'a>, - after_editable_region: Option>, -} - -impl<'a> ExcerptPromptBuilder<'a> { - pub fn len(&self) -> usize { - let mut length = 0; - length += "```".len(); - length += self.file_path.len(); - length += 1; - if self.include_start_of_file_marker { - length += START_OF_FILE_MARKER.len(); - length += 1; - } - if let Some(before_editable_region) = &self.before_editable_region { - length += before_editable_region.len(); - length += 1; - } - length += EDITABLE_REGION_START_MARKER.len(); - length += 1; - length += self.before_cursor.len(); - length += CURSOR_MARKER.len(); - length += self.after_cursor.len(); - length += 1; - length += EDITABLE_REGION_END_MARKER.len(); - length += 1; - if let Some(after_editable_region) = &self.after_editable_region { - length += after_editable_region.len(); - length += 1; - } - length += "```".len(); - length - } - - pub fn to_string(&self) -> String { - let length = self.len(); - let mut result = String::with_capacity(length); - result.push_str("```"); - result.push_str(&self.file_path); - result.push('\n'); - if self.include_start_of_file_marker { - result.push_str(START_OF_FILE_MARKER); - result.push('\n'); - } - if let Some(before_editable_region) = &self.before_editable_region { - before_editable_region.add_to_string(&mut result); - result.push('\n'); - } - result.push_str(EDITABLE_REGION_START_MARKER); - result.push('\n'); - self.before_cursor.add_to_string(&mut result); - result.push_str(CURSOR_MARKER); - self.after_cursor.add_to_string(&mut result); - result.push('\n'); - result.push_str(EDITABLE_REGION_END_MARKER); - result.push('\n'); - if let Some(after_editable_region) = &self.after_editable_region { - after_editable_region.add_to_string(&mut result); - result.push('\n'); - } - result.push_str("```"); - debug_assert!( - result.len() == length, - "Expected length: {}, Actual length: {}", - length, - result.len() - ); - result - } -} - -#[derive(Debug, Default)] -pub struct StringChunks<'a> { - chunks: Vec<&'a str>, - length: usize, -} - -#[derive(Debug, Default)] -pub struct ReversedStringChunks<'a>(StringChunks<'a>); - -impl<'a> StringChunks<'a> { - pub fn len(&self) -> usize { - self.length - } - - pub fn extend(&mut self, new_chunks: impl Iterator) { - self.chunks - .extend(new_chunks.inspect(|chunk| self.length += chunk.len())); - } - - pub fn append_from_buffer( - &mut self, - snapshot: &'a BufferSnapshot, - range: Range, - ) { - self.extend(snapshot.text_for_range(range)); - } - - pub fn add_to_string(&self, string: &mut String) { - for chunk in self.chunks.iter() { - string.push_str(chunk); - } - } -} - -impl<'a> ReversedStringChunks<'a> { - pub fn len(&self) -> usize { - self.0.len() - } - - pub fn prepend_from_buffer( - &mut self, - snapshot: &'a BufferSnapshot, - range: Range, - ) { - self.0.extend(snapshot.reversed_chunks_in_range(range)); - } - - pub fn add_to_string(&self, string: &mut String) { - for chunk in self.0.chunks.iter().rev() { - string.push_str(chunk); - } - } -} - -/// Computes a prompt for the excerpt of the buffer around the cursor. This always includes complete -/// lines and the result length will be `<= MAX_INPUT_EXCERPT_BYTES`. fn prompt_for_excerpt( + offset: usize, + excerpt_range: &Range, + mut len_guess: usize, + path: &str, snapshot: &BufferSnapshot, +) -> String { + let point_range = excerpt_range.to_point(snapshot); + + // Include one line of extra context before and after editable range, if those lines are non-empty. + let extra_context_before_range = + if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) { + let range = + (Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot); + len_guess += range.end - range.start; + Some(range) + } else { + None + }; + let extra_context_after_range = if point_range.end.row < snapshot.max_point().row + && !snapshot.is_line_blank(point_range.end.row + 1) + { + let range = (point_range.end + ..Point::new( + point_range.end.row + 1, + snapshot.line_len(point_range.end.row + 1), + )) + .to_offset(snapshot); + len_guess += range.end - range.start; + Some(range) + } else { + None + }; + + let mut prompt_excerpt = String::with_capacity(len_guess); + writeln!(prompt_excerpt, "```{}", path).unwrap(); + + if excerpt_range.start == 0 { + writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap(); + } + + if let Some(extra_context_before_range) = extra_context_before_range { + for chunk in snapshot.text_for_range(extra_context_before_range) { + prompt_excerpt.push_str(chunk); + } + } + writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap(); + for chunk in snapshot.text_for_range(excerpt_range.start..offset) { + prompt_excerpt.push_str(chunk); + } + prompt_excerpt.push_str(CURSOR_MARKER); + for chunk in snapshot.text_for_range(offset..excerpt_range.end) { + prompt_excerpt.push_str(chunk); + } + write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); + + if let Some(extra_context_after_range) = extra_context_after_range { + for chunk in snapshot.text_for_range(extra_context_after_range) { + prompt_excerpt.push_str(chunk); + } + } + + write!(prompt_excerpt, "\n```").unwrap(); + debug_assert!( + prompt_excerpt.len() <= len_guess, + "Excerpt length {} exceeds estimated length {}", + prompt_excerpt.len(), + len_guess + ); + prompt_excerpt +} + +fn excerpt_range_for_position( cursor_point: Point, - cursor_offset: usize, -) -> Result<(String, Range)> { - let mut builder = ExcerptPromptBuilder::default(); - builder.file_path = snapshot.file().map_or(Cow::Borrowed("untitled"), |file| { - file.path().to_string_lossy() - }); - + byte_limit: usize, + line_limit: u32, + path: &str, + snapshot: &BufferSnapshot, +) -> Result<(Range, usize)> { let cursor_row = cursor_point.row; - let cursor_line_start_offset = Point::new(cursor_row, 0).to_offset(snapshot); - let cursor_line_end_offset = - Point::new(cursor_row, snapshot.line_len(cursor_row)).to_offset(snapshot); - builder - .before_cursor - .prepend_from_buffer(snapshot, cursor_line_start_offset..cursor_offset); - builder - .after_cursor - .append_from_buffer(snapshot, cursor_offset..cursor_line_end_offset); + let last_buffer_row = snapshot.max_point().row; - if builder.len() > BUFFER_EXCERPT_BYTE_LIMIT { + // This is an overestimate because it includes parts of prompt_for_excerpt which are + // conditionally skipped. + let mut len_guess = 0; + len_guess += "```".len() + path.len() + 1; + len_guess += START_OF_FILE_MARKER.len() + 1; + len_guess += EDITABLE_REGION_START_MARKER.len() + 1; + len_guess += CURSOR_MARKER.len(); + len_guess += EDITABLE_REGION_END_MARKER.len() + 1; + len_guess += "```".len() + 1; + + len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap(); + + if len_guess > byte_limit { return Err(anyhow!("Current line too long to send to model.")); } - let last_buffer_row = snapshot.max_point().row; - - // Figure out how many lines of the buffer to include in the prompt, walking outwards from the - // cursor. Even if a line before or after the cursor causes the byte limit to be exceeded, - // continues walking in the other direction. - let mut first_included_row = cursor_row; - let mut last_included_row = cursor_row; + let mut excerpt_start_row = cursor_row; + let mut excerpt_end_row = cursor_row; let mut no_more_before = cursor_row == 0; let mut no_more_after = cursor_row >= last_buffer_row; - let mut output_len = builder.len(); let mut row_delta = 1; loop { if !no_more_before { let row = cursor_point.row - row_delta; - let line_len: usize = (snapshot.line_len(row) + 1).try_into().unwrap(); - let mut new_output_len = output_len + line_len; + let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap(); + let mut new_len_guess = len_guess + line_len; if row == 0 { - new_output_len += START_OF_FILE_MARKER.len() + 1; + new_len_guess += START_OF_FILE_MARKER.len() + 1; } - if new_output_len <= BUFFER_EXCERPT_BYTE_LIMIT { - output_len = new_output_len; - first_included_row = row; + if new_len_guess <= byte_limit { + len_guess = new_len_guess; + excerpt_start_row = row; if row == 0 { - builder.include_start_of_file_marker = true; no_more_before = true; } } else { no_more_before = true; } } + if excerpt_end_row - excerpt_start_row >= line_limit { + break; + } if !no_more_after { let row = cursor_point.row + row_delta; - let line_len: usize = (snapshot.line_len(row) + 1).try_into().unwrap(); - let new_output_len = output_len + line_len; - if new_output_len <= BUFFER_EXCERPT_BYTE_LIMIT { - output_len = new_output_len; - last_included_row = row; + let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap(); + let new_len_guess = len_guess + line_len; + if new_len_guess <= byte_limit { + len_guess = new_len_guess; + excerpt_end_row = row; if row >= last_buffer_row { no_more_after = true; } @@ -1200,66 +1158,21 @@ fn prompt_for_excerpt( no_more_after = true; } } + if excerpt_end_row - excerpt_start_row >= line_limit { + break; + } if no_more_before && no_more_after { break; } row_delta += 1; } - // Include a line of context outside the editable region, but only if it is not the first line - // (otherwise the first line of the file would never be uneditable). - let first_editable_row = if first_included_row != 0 - && first_included_row < cursor_row - && !snapshot.is_line_blank(first_included_row) - { - let mut before_editable_region = ReversedStringChunks::default(); - before_editable_region.prepend_from_buffer( - snapshot, - Point::new(first_included_row, 0) - ..Point::new(first_included_row, snapshot.line_len(first_included_row)), - ); - builder.before_editable_region = Some(before_editable_region); - first_included_row + 1 - } else { - first_included_row - }; - - // Include a line of context outside the editable region, but only if it is not the last line - // (otherwise the first line of the file would never be uneditable). - let last_editable_row = if last_included_row < last_buffer_row - && last_included_row > cursor_row - && !snapshot.is_line_blank(last_included_row) - { - let mut after_editable_region = StringChunks::default(); - after_editable_region.append_from_buffer( - snapshot, - Point::new(last_included_row, 0) - ..Point::new(last_included_row, snapshot.line_len(last_included_row)), - ); - builder.after_editable_region = Some(after_editable_region); - last_included_row + 1 - } else { - last_included_row - }; - - let editable_range = (Point::new(first_editable_row, 0) - ..Point::new(last_editable_row, snapshot.line_len(last_editable_row))) - .to_offset(snapshot); - - let before_cursor_row = editable_range.start..cursor_line_start_offset; - let after_cursor_row = cursor_line_end_offset..editable_range.end; - if !before_cursor_row.is_empty() { - builder - .before_cursor - .prepend_from_buffer(snapshot, before_cursor_row); - } - if !after_cursor_row.is_empty() { - builder - .after_cursor - .append_from_buffer(snapshot, after_cursor_row); - } - - anyhow::Ok((builder.to_string(), editable_range)) + let excerpt_start = Point::new(excerpt_start_row, 0); + let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row)); + Ok(( + excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot), + len_guess, + )) } fn prompt_for_events<'a>(