Re-introduce syntax-based context and use new model (#24469)

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Antonio Scandurra 2025-02-07 20:19:57 +01:00 committed by GitHub
parent fd7fa87939
commit f6e396837c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 427 additions and 385 deletions

View file

@ -0,0 +1,238 @@
use crate::{
tokens_for_bytes, CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
START_OF_FILE_MARKER,
};
use language::{BufferSnapshot, Point};
use std::{fmt::Write, ops::Range};
#[derive(Debug)]
pub struct InputExcerpt {
pub editable_range: Range<Point>,
pub prompt: String,
pub speculated_output: String,
}
pub fn excerpt_for_cursor_position(
position: Point,
path: &str,
snapshot: &BufferSnapshot,
editable_region_token_limit: usize,
context_token_limit: usize,
) -> InputExcerpt {
let mut scope_range = position..position;
let mut remaining_edit_tokens = editable_region_token_limit;
while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
let parent_tokens = tokens_for_bytes(parent.byte_range().len());
let parent_point_range = Point::new(
parent.start_position().row as u32,
parent.start_position().column as u32,
)
..Point::new(
parent.end_position().row as u32,
parent.end_position().column as u32,
);
if parent_point_range == scope_range {
break;
} else if parent_tokens <= editable_region_token_limit {
scope_range = parent_point_range;
remaining_edit_tokens = editable_region_token_limit - parent_tokens;
} else {
break;
}
}
let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
let mut prompt = String::new();
let mut speculated_output = String::new();
writeln!(&mut prompt, "```{path}").unwrap();
if context_range.start == Point::zero() {
writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
}
for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
prompt.push_str(chunk.text);
}
push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
push_editable_range(
position,
snapshot,
editable_range.clone(),
&mut speculated_output,
);
for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
prompt.push_str(chunk.text);
}
write!(prompt, "\n```").unwrap();
InputExcerpt {
editable_range,
prompt,
speculated_output,
}
}
fn push_editable_range(
cursor_position: Point,
snapshot: &BufferSnapshot,
editable_range: Range<Point>,
prompt: &mut String,
) {
writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
prompt.push_str(chunk.text);
}
prompt.push_str(CURSOR_MARKER);
for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
prompt.push_str(chunk.text);
}
write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
}
fn expand_range(
snapshot: &BufferSnapshot,
range: Range<Point>,
mut remaining_tokens: usize,
) -> Range<Point> {
let mut expanded_range = range.clone();
expanded_range.start.column = 0;
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
loop {
let mut expanded = false;
if remaining_tokens > 0 && expanded_range.start.row > 0 {
expanded_range.start.row -= 1;
let line_tokens =
tokens_for_bytes(snapshot.line_len(expanded_range.start.row) as usize);
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
expanded = true;
}
if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
expanded_range.end.row += 1;
expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
let line_tokens = tokens_for_bytes(expanded_range.end.column as usize);
remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
expanded = true;
}
if !expanded {
break;
}
}
expanded_range
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::{App, AppContext};
use indoc::indoc;
use language::{Buffer, Language, LanguageConfig, LanguageMatcher};
use std::sync::Arc;
#[gpui::test]
fn test_excerpt_for_cursor_position(cx: &mut App) {
let text = indoc! {r#"
fn foo() {
let x = 42;
println!("Hello, world!");
}
fn bar() {
let x = 42;
let mut sum = 0;
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
return sum;
}
fn generate_random_numbers() -> Vec<i32> {
let mut rng = rand::thread_rng();
let mut numbers = Vec::new();
for _ in 0..5 {
numbers.push(rng.gen_range(1..101));
}
numbers
}
"#};
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
// Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
// when a larger scope doesn't fit the editable region.
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
assert_eq!(
excerpt.prompt,
indoc! {r#"
```main.rs
let x = 42;
println!("Hello, world!");
<|editable_region_start|>
}
fn bar() {
let x = 42;
let mut sum = 0;
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
r<|user_cursor_is_here|>eturn sum;
}
fn generate_random_numbers() -> Vec<i32> {
<|editable_region_end|>
let mut rng = rand::thread_rng();
let mut numbers = Vec::new();
```"#}
);
// The `bar` function won't fit within the editable region, so we resort to line-based expansion.
let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
assert_eq!(
excerpt.prompt,
indoc! {r#"
```main.rs
fn bar() {
let x = 42;
let mut sum = 0;
<|editable_region_start|>
for i in 0..x {
sum += i;
}
println!("Sum: {}", sum);
r<|user_cursor_is_here|>eturn sum;
}
fn generate_random_numbers() -> Vec<i32> {
let mut rng = rand::thread_rng();
<|editable_region_end|>
let mut numbers = Vec::new();
for _ in 0..5 {
numbers.push(rng.gen_range(1..101));
```"#}
);
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
}
}

View file

@ -1,5 +1,6 @@
mod completion_diff_element;
mod init;
mod input_excerpt;
mod license_detection;
mod onboarding_banner;
mod onboarding_modal;
@ -25,9 +26,8 @@ use gpui::{
actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
};
use http_client::{HttpClient, Method};
use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, Point, ToOffset, ToPoint,
};
use input_excerpt::excerpt_for_cursor_position;
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint};
use language_models::LlmApiToken;
use postage::watch;
use project::Project;
@ -57,38 +57,13 @@ const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// TODO(mgsloan): more systematic way to choose or tune these fairly arbitrary constants?
const MAX_CONTEXT_TOKENS: usize = 100;
const MAX_REWRITE_TOKENS: usize = 300;
const MAX_EVENT_TOKENS: usize = 400;
/// Typical number of string bytes per token for the purposes of limiting model input. This is
/// intentionally low to err on the side of underestimating limits.
const BYTES_PER_TOKEN_GUESS: usize = 3;
/// Output token limit, used to inform the size of the input. A copy of this constant is also in
/// `crates/collab/src/llm.rs`.
const MAX_OUTPUT_TOKENS: usize = 2048;
/// Total bytes limit for editable region of buffer excerpt.
///
/// The number of output tokens is relevant to the size of the input excerpt because the model is
/// tasked with outputting a modified excerpt. `2/3` is chosen so that there are some output tokens
/// remaining for the model to specify insertions.
const BUFFER_EXCERPT_BYTE_LIMIT: usize = (MAX_OUTPUT_TOKENS * 2 / 3) * BYTES_PER_TOKEN_GUESS;
/// Total line limit for editable region of buffer excerpt.
const BUFFER_EXCERPT_LINE_LIMIT: u32 = 64;
/// Note that this is not the limit for the overall prompt, just for the inputs to the template
/// instantiated in `crates/collab/src/llm.rs`.
const TOTAL_BYTE_LIMIT: usize = BUFFER_EXCERPT_BYTE_LIMIT * 2;
/// Maximum number of events to include in the prompt.
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
/// Maximum number of string bytes in a single event. Arbitrarily choosing this to be 4x the size of
/// equally splitting up the the remaining bytes after the largest possible buffer excerpt.
const PER_EVENT_BYTE_LIMIT: usize =
(TOTAL_BYTE_LIMIT - BUFFER_EXCERPT_BYTE_LIMIT) / MAX_EVENT_COUNT * 4;
actions!(edit_prediction, [ClearHistory]);
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
@ -418,7 +393,8 @@ impl Zeta {
struct BackgroundValues {
input_events: String,
input_excerpt: String,
excerpt_range: Range<usize>,
speculated_output: String,
editable_range: Range<usize>,
input_outline: String,
}
@ -429,32 +405,21 @@ impl Zeta {
let path = path.clone();
async move {
let path = path.to_string_lossy();
let (excerpt_range, excerpt_len_guess) = excerpt_range_for_position(
let input_excerpt = excerpt_for_cursor_position(
cursor_point,
BUFFER_EXCERPT_BYTE_LIMIT,
BUFFER_EXCERPT_LINE_LIMIT,
&path,
&snapshot,
)?;
let input_excerpt = prompt_for_excerpt(
cursor_offset,
&excerpt_range,
excerpt_len_guess,
&path,
&snapshot,
MAX_REWRITE_TOKENS,
MAX_CONTEXT_TOKENS,
);
let bytes_remaining = TOTAL_BYTE_LIMIT.saturating_sub(input_excerpt.len());
let input_events = prompt_for_events(events.iter(), bytes_remaining);
// Note that input_outline is not currently used in prompt generation and so
// is not counted towards TOTAL_BYTE_LIMIT.
let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
let input_outline = prompt_for_outline(&snapshot);
anyhow::Ok(BackgroundValues {
input_events,
input_excerpt,
excerpt_range,
input_excerpt: input_excerpt.prompt,
speculated_output: input_excerpt.speculated_output,
editable_range: input_excerpt.editable_range.to_offset(&snapshot),
input_outline,
})
}
@ -462,7 +427,7 @@ impl Zeta {
.await?;
log::debug!(
"Events:\n{}\nExcerpt:\n{}",
"Events:\n{}\nExcerpt:\n{:?}",
values.input_events,
values.input_excerpt
);
@ -470,6 +435,7 @@ impl Zeta {
let body = PredictEditsBody {
input_events: values.input_events.clone(),
input_excerpt: values.input_excerpt.clone(),
speculated_output: Some(values.speculated_output),
outline: Some(values.input_outline.clone()),
can_collect_data,
diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
@ -492,7 +458,7 @@ impl Zeta {
output_excerpt,
buffer,
&snapshot,
values.excerpt_range,
values.editable_range,
cursor_offset,
path,
values.input_outline,
@ -508,6 +474,8 @@ impl Zeta {
// Generates several example completions of various states to fill the Zeta completion modal
#[cfg(any(test, feature = "test-support"))]
pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
use language::Point;
let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
And maybe a short line
@ -697,7 +665,7 @@ and then another
loop {
let request_builder = http_client::Request::builder().method(Method::POST).uri(
http_client
.build_zed_llm_url("/predict_edits", &[])?
.build_zed_llm_url("/predict_edits/v2", &[])?
.as_ref(),
);
let request = request_builder
@ -737,7 +705,7 @@ and then another
output_excerpt: String,
buffer: Entity<Buffer>,
snapshot: &BufferSnapshot,
excerpt_range: Range<usize>,
editable_range: Range<usize>,
cursor_offset: usize,
path: Arc<Path>,
input_outline: String,
@ -754,9 +722,9 @@ and then another
.background_executor()
.spawn({
let output_excerpt = output_excerpt.clone();
let excerpt_range = excerpt_range.clone();
let editable_range = editable_range.clone();
let snapshot = snapshot.clone();
async move { Self::parse_edits(output_excerpt, excerpt_range, &snapshot) }
async move { Self::parse_edits(output_excerpt, editable_range, &snapshot) }
})
.await?
.into();
@ -779,7 +747,7 @@ and then another
Ok(Some(InlineCompletion {
id: InlineCompletionId::new(),
path,
excerpt_range,
excerpt_range: editable_range,
cursor_offset,
edits,
edit_preview,
@ -796,7 +764,7 @@ and then another
fn parse_edits(
output_excerpt: Arc<str>,
excerpt_range: Range<usize>,
editable_range: Range<usize>,
snapshot: &BufferSnapshot,
) -> Result<Vec<(Range<Anchor>, String)>> {
let content = output_excerpt.replace(CURSOR_MARKER, "");
@ -840,13 +808,13 @@ and then another
let new_text = &content[..codefence_end];
let old_text = snapshot
.text_for_range(excerpt_range.clone())
.text_for_range(editable_range.clone())
.collect::<String>();
Ok(Self::compute_edits(
old_text,
new_text,
excerpt_range.start,
editable_range.start,
&snapshot,
))
}
@ -1080,9 +1048,7 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
.unwrap();
if let Some(outline) = snapshot.outline(None) {
let guess_size = outline.items.len() * 15;
input_outline.reserve(guess_size);
for item in outline.items.iter() {
for item in &outline.items {
let spacing = " ".repeat(item.depth);
writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
}
@ -1093,181 +1059,20 @@ fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
input_outline
}
fn prompt_for_excerpt(
offset: usize,
excerpt_range: &Range<usize>,
mut len_guess: usize,
path: &str,
snapshot: &BufferSnapshot,
) -> String {
let point_range = excerpt_range.to_point(snapshot);
// Include one line of extra context before and after editable range, if those lines are non-empty.
let extra_context_before_range =
if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
let range =
(Point::new(point_range.start.row - 1, 0)..point_range.start).to_offset(snapshot);
len_guess += range.end - range.start;
Some(range)
} else {
None
};
let extra_context_after_range = if point_range.end.row < snapshot.max_point().row
&& !snapshot.is_line_blank(point_range.end.row + 1)
{
let range = (point_range.end
..Point::new(
point_range.end.row + 1,
snapshot.line_len(point_range.end.row + 1),
))
.to_offset(snapshot);
len_guess += range.end - range.start;
Some(range)
} else {
None
};
let mut prompt_excerpt = String::with_capacity(len_guess);
writeln!(prompt_excerpt, "```{}", path).unwrap();
if excerpt_range.start == 0 {
writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
}
if let Some(extra_context_before_range) = extra_context_before_range {
for chunk in snapshot.text_for_range(extra_context_before_range) {
prompt_excerpt.push_str(chunk);
}
}
writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
prompt_excerpt.push_str(chunk);
}
prompt_excerpt.push_str(CURSOR_MARKER);
for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
prompt_excerpt.push_str(chunk);
}
write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
if let Some(extra_context_after_range) = extra_context_after_range {
for chunk in snapshot.text_for_range(extra_context_after_range) {
prompt_excerpt.push_str(chunk);
}
}
write!(prompt_excerpt, "\n```").unwrap();
debug_assert!(
prompt_excerpt.len() <= len_guess,
"Excerpt length {} exceeds estimated length {}",
prompt_excerpt.len(),
len_guess
);
prompt_excerpt
}
fn excerpt_range_for_position(
cursor_point: Point,
byte_limit: usize,
line_limit: u32,
path: &str,
snapshot: &BufferSnapshot,
) -> Result<(Range<usize>, usize)> {
let cursor_row = cursor_point.row;
let last_buffer_row = snapshot.max_point().row;
// This is an overestimate because it includes parts of prompt_for_excerpt which are
// conditionally skipped.
let mut len_guess = 0;
len_guess += "```".len() + path.len() + 1;
len_guess += START_OF_FILE_MARKER.len() + 1;
len_guess += EDITABLE_REGION_START_MARKER.len() + 1;
len_guess += CURSOR_MARKER.len();
len_guess += EDITABLE_REGION_END_MARKER.len() + 1;
len_guess += "```".len() + 1;
len_guess += usize::try_from(snapshot.line_len(cursor_row) + 1).unwrap();
if len_guess > byte_limit {
return Err(anyhow!("Current line too long to send to model."));
}
let mut excerpt_start_row = cursor_row;
let mut excerpt_end_row = cursor_row;
let mut no_more_before = cursor_row == 0;
let mut no_more_after = cursor_row >= last_buffer_row;
let mut row_delta = 1;
loop {
if !no_more_before {
let row = cursor_point.row - row_delta;
let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
let mut new_len_guess = len_guess + line_len;
if row == 0 {
new_len_guess += START_OF_FILE_MARKER.len() + 1;
}
if new_len_guess <= byte_limit {
len_guess = new_len_guess;
excerpt_start_row = row;
if row == 0 {
no_more_before = true;
}
} else {
no_more_before = true;
}
}
if excerpt_end_row - excerpt_start_row >= line_limit {
break;
}
if !no_more_after {
let row = cursor_point.row + row_delta;
let line_len: usize = usize::try_from(snapshot.line_len(row) + 1).unwrap();
let new_len_guess = len_guess + line_len;
if new_len_guess <= byte_limit {
len_guess = new_len_guess;
excerpt_end_row = row;
if row >= last_buffer_row {
no_more_after = true;
}
} else {
no_more_after = true;
}
}
if excerpt_end_row - excerpt_start_row >= line_limit {
break;
}
if no_more_before && no_more_after {
break;
}
row_delta += 1;
}
let excerpt_start = Point::new(excerpt_start_row, 0);
let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
Ok((
excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot),
len_guess,
))
}
fn prompt_for_events<'a>(
events: impl Iterator<Item = &'a Event>,
mut bytes_remaining: usize,
) -> String {
fn prompt_for_events(events: &VecDeque<Event>, mut remaining_tokens: usize) -> String {
let mut result = String::new();
for event in events {
if !result.is_empty() {
result.push('\n');
result.push('\n');
}
for event in events.iter().rev() {
let event_string = event.to_prompt();
let len = event_string.len();
if len > PER_EVENT_BYTE_LIMIT {
continue;
}
if len > bytes_remaining {
let event_tokens = tokens_for_bytes(event_string.len());
if event_tokens > remaining_tokens {
break;
}
bytes_remaining -= len;
result.push_str(&event_string);
if !result.is_empty() {
result.insert_str(0, "\n\n");
}
result.insert_str(0, &event_string);
remaining_tokens -= event_tokens;
}
result
}
@ -1750,6 +1555,13 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
}
}
fn tokens_for_bytes(bytes: usize) -> usize {
/// Typical number of string bytes per token for the purposes of limiting model input. This is
/// intentionally low to err on the side of underestimating limits.
const BYTES_PER_TOKEN_GUESS: usize = 3;
bytes / BYTES_PER_TOKEN_GUESS
}
#[cfg(test)]
mod tests {
use client::test::FakeServer;
@ -1757,6 +1569,7 @@ mod tests {
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use indoc::indoc;
use language::Point;
use language_models::RefreshLlmTokenListener;
use rpc::proto;
use settings::SettingsStore;