Polish edit predictions (#24732)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: as-cii <as-cii@zed.dev>
Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
This commit is contained in:
Agus Zubiaga 2025-02-12 12:56:31 -03:00 committed by GitHub
parent 2b7d3726b4
commit 51092c4e31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 353 additions and 161 deletions

View file

@ -27,7 +27,10 @@ use gpui::{
};
use http_client::{HttpClient, Method};
use input_excerpt::excerpt_for_cursor_position;
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint};
use language::{
Anchor, Buffer, BufferSnapshot, CharClassifier, CharKind, EditPreview, OffsetRangeExt,
ToOffset, ToPoint,
};
use language_models::LlmApiToken;
use postage::watch;
use project::Project;
@ -57,9 +60,9 @@ const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
const MAX_CONTEXT_TOKENS: usize = 100;
const MAX_REWRITE_TOKENS: usize = 300;
const MAX_EVENT_TOKENS: usize = 400;
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
const MAX_EVENT_TOKENS: usize = 500;
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
@ -834,8 +837,34 @@ and then another
offset: usize,
snapshot: &BufferSnapshot,
) -> Vec<(Range<Anchor>, String)> {
let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
/// Splits `text` into tokens: maximal runs of characters that share the same
/// `CharKind` (word / whitespace / punctuation), with punctuation runs split
/// further so only repeats of the *same* punctuation character stay together.
/// Finer-grained tokens let the word-level diff below produce minimal edits.
///
/// Fix: previously `tokenize("")` returned `[""]` (one bogus empty token);
/// it now returns an empty `Vec`.
fn tokenize(text: &str) -> Vec<&str> {
    let classifier = CharClassifier::new(None).for_completion(true);
    let mut tokens = Vec::new();
    let mut prev_ch: Option<char> = None;
    // Byte offsets of the current token within `text`.
    let mut start = 0;
    let mut end = 0;
    for ch in text.chars() {
        let kind = classifier.kind(ch);
        let prev_kind = prev_ch.map(|ch| classifier.kind(ch));
        // Close the current token on a kind boundary, or between two
        // different punctuation characters. `prev_ch.is_some()` keeps the
        // very first character from closing a (nonexistent) empty token.
        if prev_ch.is_some()
            && (Some(kind) != prev_kind
                || (kind == CharKind::Punctuation && Some(ch) != prev_ch))
        {
            tokens.push(&text[start..end]);
            start = end;
        }
        end += ch.len_utf8();
        prev_ch = Some(ch);
    }
    // Emit the final token; skip it entirely for empty input so we never
    // produce a spurious "" token.
    if !text.is_empty() {
        tokens.push(&text[start..end]);
    }
    tokens
}
let old_tokens = tokenize(&old_text);
let new_tokens = tokenize(new_text);
let diff = similar::TextDiffConfig::default()
.algorithm(similar::Algorithm::Patience)
.diff_slices(&old_tokens, &new_tokens);
let mut edits: Vec<(Range<usize>, String)> = Vec::new();
let mut old_start = offset;
for change in diff.iter_all_changes() {
@ -1705,6 +1734,70 @@ mod tests {
})
}
// Verifies that the prediction diff clean-up turns the model's rewritten
// region into minimal token-level edits instead of whole-line replacements.
#[gpui::test]
async fn test_clean_up_diff(cx: &mut TestAppContext) {
// Minimal global state required for issuing a Zeta request under test.
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
client::init_settings(cx);
});
// Case 1: the model renames `word` to `word_1` in two places. The
// token-level diff should yield two tiny "_1" insertions rather than
// replacing the whole line.
let edits = edits_for_prediction(
indoc! {"
fn main() {
let word_1 = \"lorem\";
let range = word.len()..word.len();
}
"},
indoc! {"
<|editable_region_start|>
fn main() {
let word_1 = \"lorem\";
let range = word_1.len()..word_1.len();
}
<|editable_region_end|>
"},
cx,
)
.await;
assert_eq!(
edits,
[
(Point::new(2, 20)..Point::new(2, 20), "_1".to_string()),
(Point::new(2, 32)..Point::new(2, 32), "_1".to_string()),
]
);
// Case 2: the model extends a string literal and appends a `;`. Expect
// one insertion inside the string plus a separate single-character
// insertion for the semicolon (punctuation tokenizes on its own).
let edits = edits_for_prediction(
indoc! {"
fn main() {
let story = \"the quick\"
}
"},
indoc! {"
<|editable_region_start|>
fn main() {
let story = \"the quick brown fox jumps over the lazy dog\";
}
<|editable_region_end|>
"},
cx,
)
.await;
assert_eq!(
edits,
[
(
Point::new(1, 26)..Point::new(1, 26),
" brown fox jumps over the lazy dog".to_string()
),
(Point::new(1, 27)..Point::new(1, 27), ";".to_string()),
]
);
}
#[gpui::test]
async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
cx.update(|cx| {
@ -1768,6 +1861,58 @@ mod tests {
);
}
// Test helper: drives a full Zeta edit-prediction round trip against a fake
// HTTP endpoint that always returns `completion_response`, then converts the
// resulting edits into (`Point` range, text) pairs for easy assertions.
async fn edits_for_prediction(
buffer_content: &str,
completion_response: &str,
cx: &mut TestAppContext,
) -> Vec<(Range<Point>, String)> {
// Canned model response: every prediction request succeeds (HTTP 200)
// with the provided excerpt as the output.
let completion_response = completion_response.to_string();
let http_client = FakeHttpClient::create(move |_| {
let completion = completion_response.clone();
async move {
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&PredictEditsResponse {
request_id: Uuid::new_v4(),
output_excerpt: completion,
})
.unwrap()
.into(),
)
.unwrap())
}
});
let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
cx.update(|cx| {
RefreshLlmTokenListener::register(client.clone(), cx);
});
let server = FakeServer::for_client(42, &client, cx).await;
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let zeta = cx.new(|cx| Zeta::new(client, user_store, cx));
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
// Arbitrary but fixed cursor (start of line 1) keeps tests deterministic.
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
// Order matters: start the completion request first — it blocks waiting
// for an LLM token, which the fake RPC server grants just below.
let completion_task = zeta.update(cx, |zeta, cx| {
zeta.request_completion(None, &buffer, cursor, false, cx)
});
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
server.respond(
token_request.receipt(),
proto::GetLlmTokenResponse { token: "".into() },
);
let completion = completion_task.await.unwrap().unwrap();
// NOTE(review): `new_text.clone()` suggests `edits` yields borrowed
// strings (e.g. iteration over a shared slice); if they are owned here,
// the clone is redundant — confirm against `request_completion`'s type.
completion
.edits
.into_iter()
.map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone()))
.collect::<Vec<_>>()
}
fn to_completion_edits(
iterator: impl IntoIterator<Item = (Range<usize>, String)>,
buffer: &Entity<Buffer>,