zeta: Send up diagnostics with prediction requests (#24384)

This PR sends up the diagnostic groups as additional data with the edit prediction request.

We're not yet making use of them, but we are recording them so we can
use them later (e.g., to train the model).
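
For illustration, here is a minimal, self-contained sketch of the shape this adds to the request body. The struct and field values below (`PredictEditsBodySketch`, `DiagnosticGroupSketch`, the sample rust-analyzer diagnostic) are made up for the example and are not Zed's actual types; only the pairing of a language server name with a JSON-serialized diagnostic group mirrors the change in the diff.

```rust
use serde::Serialize;

// Illustrative stand-in for the request body; field names follow the diff,
// but this is not the real `PredictEditsBody` type.
#[derive(Serialize)]
struct PredictEditsBodySketch {
    input_events: String,
    input_excerpt: String,
    outline: Option<String>,
    can_collect_data: bool,
    // New in this PR: (language server name, serialized diagnostic group) pairs.
    diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
}

// Hypothetical diagnostic payload, just to have something to serialize.
#[derive(Serialize)]
struct DiagnosticGroupSketch {
    severity: String,
    message: String,
    byte_range: (usize, usize),
}

fn main() -> Result<(), serde_json::Error> {
    let group = DiagnosticGroupSketch {
        severity: "error".into(),
        message: "cannot find value `foo` in this scope".into(),
        byte_range: (120, 123),
    };

    let body = PredictEditsBodySketch {
        input_events: "User edited \"main.rs\"".into(),
        input_excerpt: "fn main() { foo; }".into(),
        outline: Some("fn main".into()),
        can_collect_data: true,
        diagnostic_groups: Some(vec![(
            "rust-analyzer".to_string(),
            serde_json::to_value(&group)?,
        )]),
    };

    println!("{}", serde_json::to_string_pretty(&body)?);
    Ok(())
}
```

Serializing each group to an opaque `serde_json::Value` (as the diff does with `serde_json::to_value`) keeps the request schema stable even if the diagnostic payload changes shape later.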

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Commit 09967ac3d0 (parent 13089d7ec6)
Author: Marshall Bowers
Committed: 2025-02-06 13:07:26 -05:00 via GitHub
16 changed files with 145 additions and 31 deletions

@@ -37,6 +37,7 @@ language_models.workspace = true
log.workspace = true
menu.workspace = true
postage.workspace = true
project.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true

@@ -30,6 +30,7 @@ use language::{
};
use language_models::LlmApiToken;
use postage::watch;
use project::Project;
use settings::WorktreeId;
use std::{
borrow::Cow,
@@ -363,6 +364,7 @@ impl Zeta {
pub fn request_completion_impl<F, R>(
&mut self,
project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
can_collect_data: bool,
@@ -374,6 +376,7 @@ impl Zeta {
R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
{
let snapshot = self.report_changes_for_buffer(&buffer, cx);
let diagnostic_groups = snapshot.diagnostic_groups(None);
let cursor_point = cursor.to_point(&snapshot);
let cursor_offset = cursor_point.to_offset(&snapshot);
let events = self.events.clone();
@@ -387,10 +390,39 @@ impl Zeta {
let is_staff = cx.is_staff();
let buffer = buffer.clone();
let local_lsp_store =
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
Some(
diagnostic_groups
.into_iter()
.filter_map(|(language_server_id, diagnostic_group)| {
let language_server =
local_lsp_store.running_language_server_for_id(language_server_id)?;
Some((
language_server.name(),
diagnostic_group.resolve::<usize>(&snapshot),
))
})
.collect::<Vec<_>>(),
)
} else {
None
};
cx.spawn(|_, cx| async move {
let request_sent_at = Instant::now();
let (input_events, input_excerpt, excerpt_range, input_outline) = cx
struct BackgroundValues {
input_events: String,
input_excerpt: String,
excerpt_range: Range<usize>,
input_outline: String,
}
let values = cx
.background_executor()
.spawn({
let snapshot = snapshot.clone();
@@ -419,18 +451,36 @@ impl Zeta {
// is not counted towards TOTAL_BYTE_LIMIT.
let input_outline = prompt_for_outline(&snapshot);
anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline))
anyhow::Ok(BackgroundValues {
input_events,
input_excerpt,
excerpt_range,
input_outline,
})
}
})
.await?;
log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
log::debug!(
"Events:\n{}\nExcerpt:\n{}",
values.input_events,
values.input_excerpt
);
let body = PredictEditsBody {
input_events: input_events.clone(),
input_excerpt: input_excerpt.clone(),
outline: Some(input_outline.clone()),
input_events: values.input_events.clone(),
input_excerpt: values.input_excerpt.clone(),
outline: Some(values.input_outline.clone()),
can_collect_data,
diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
diagnostic_groups
.into_iter()
.map(|(name, diagnostic_group)| {
Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
})
.collect::<Result<Vec<_>>>()
.log_err()
}),
};
let response = perform_predict_edits(client, llm_token, is_staff, body).await?;
@@ -442,12 +492,12 @@ impl Zeta {
output_excerpt,
buffer,
&snapshot,
excerpt_range,
values.excerpt_range,
cursor_offset,
path,
input_outline,
input_events,
input_excerpt,
values.input_outline,
values.input_events,
values.input_excerpt,
request_sent_at,
&cx,
)
@@ -466,11 +516,13 @@ impl Zeta {
and then another
"#};
let project = None;
let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
let position = buffer.read(cx).anchor_before(Point::new(1, 0));
let completion_tasks = vec![
self.fake_completion(
project,
&buffer,
position,
PredictEditsResponse {
@@ -486,6 +538,7 @@ and then another
cx,
),
self.fake_completion(
project,
&buffer,
position,
PredictEditsResponse {
@@ -501,6 +554,7 @@ and then another
cx,
),
self.fake_completion(
project,
&buffer,
position,
PredictEditsResponse {
@@ -517,6 +571,7 @@ and then another
cx,
),
self.fake_completion(
project,
&buffer,
position,
PredictEditsResponse {
@@ -533,6 +588,7 @@ and then another
cx,
),
self.fake_completion(
project,
&buffer,
position,
PredictEditsResponse {
@@ -548,6 +604,7 @@ and then another
cx,
),
self.fake_completion(
project,
&buffer,
position,
PredictEditsResponse {
@@ -562,6 +619,7 @@ and then another
cx,
),
self.fake_completion(
project,
&buffer,
position,
PredictEditsResponse {
@@ -594,6 +652,7 @@ and then another
#[cfg(any(test, feature = "test-support"))]
pub fn fake_completion(
&mut self,
project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>,
position: language::Anchor,
response: PredictEditsResponse,
@@ -601,19 +660,21 @@ and then another
) -> Task<Result<Option<InlineCompletion>>> {
use std::future::ready;
self.request_completion_impl(buffer, position, false, cx, |_, _, _, _| {
self.request_completion_impl(project, buffer, position, false, cx, |_, _, _, _| {
ready(Ok(response))
})
}
pub fn request_completion(
&mut self,
project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>,
position: language::Anchor,
can_collect_data: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<InlineCompletion>>> {
self.request_completion_impl(
project,
buffer,
position,
can_collect_data,
@@ -1494,6 +1555,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
fn refresh(
&mut self,
project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
position: language::Anchor,
_debounce: bool,
@@ -1529,7 +1591,13 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
let completion_request = this.update(&mut cx, |this, cx| {
this.last_request_timestamp = Instant::now();
this.zeta.update(cx, |zeta, cx| {
zeta.request_completion(&buffer, position, can_collect_data, cx)
zeta.request_completion(
project.as_ref(),
&buffer,
position,
can_collect_data,
cx,
)
})
});
@@ -1858,7 +1926,7 @@ mod tests {
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
let completion_task = zeta.update(cx, |zeta, cx| {
zeta.request_completion(&buffer, cursor, false, cx)
zeta.request_completion(None, &buffer, cursor, false, cx)
});
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();