zeta: Add CLI tool for querying edit predictions and related context (#35491)

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
Michael Sloan 2025-08-01 15:08:09 -06:00 committed by GitHub
parent 561ccf86aa
commit 6052115825
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 719 additions and 102 deletions

View file

@ -146,14 +146,14 @@ pub struct InlineCompletion {
input_events: Arc<str>,
input_excerpt: Arc<str>,
output_excerpt: Arc<str>,
request_sent_at: Instant,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
}
impl InlineCompletion {
fn latency(&self) -> Duration {
self.response_received_at
.duration_since(self.request_sent_at)
.duration_since(self.buffer_snapshotted_at)
}
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
@ -391,104 +391,48 @@ impl Zeta {
+ Send
+ 'static,
{
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let snapshot = self.report_changes_for_buffer(&buffer, cx);
let diagnostic_groups = snapshot.diagnostic_groups(None);
let cursor_point = cursor.to_point(&snapshot);
let cursor_offset = cursor_point.to_offset(&snapshot);
let events = self.events.clone();
let path: Arc<Path> = snapshot
.file()
.map(|f| Arc::from(f.full_path(cx).as_path()))
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
let zeta = cx.entity();
let events = self.events.clone();
let client = self.client.clone();
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
let buffer = buffer.clone();
let local_lsp_store =
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
Some(
diagnostic_groups
.into_iter()
.filter_map(|(language_server_id, diagnostic_group)| {
let language_server =
local_lsp_store.running_language_server_for_id(language_server_id)?;
Some((
language_server.name(),
diagnostic_group.resolve::<usize>(&snapshot),
))
})
.collect::<Vec<_>>(),
)
} else {
None
};
let full_path: Arc<Path> = snapshot
.file()
.map(|f| Arc::from(f.full_path(cx).as_path()))
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
let full_path_str = full_path.to_string_lossy().to_string();
let cursor_point = cursor.to_point(&snapshot);
let cursor_offset = cursor_point.to_offset(&snapshot);
let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
let gather_task = gather_context(
project,
full_path_str,
&snapshot,
cursor_point,
make_events_prompt,
can_collect_data,
cx,
);
cx.spawn(async move |this, cx| {
let request_sent_at = Instant::now();
struct BackgroundValues {
input_events: String,
input_excerpt: String,
speculated_output: String,
editable_range: Range<usize>,
input_outline: String,
}
let values = cx
.background_spawn({
let snapshot = snapshot.clone();
let path = path.clone();
async move {
let path = path.to_string_lossy();
let input_excerpt = excerpt_for_cursor_position(
cursor_point,
&path,
&snapshot,
MAX_REWRITE_TOKENS,
MAX_CONTEXT_TOKENS,
);
let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
let input_outline = prompt_for_outline(&snapshot);
anyhow::Ok(BackgroundValues {
input_events,
input_excerpt: input_excerpt.prompt,
speculated_output: input_excerpt.speculated_output,
editable_range: input_excerpt.editable_range.to_offset(&snapshot),
input_outline,
})
}
})
.await?;
let GatherContextOutput {
body,
editable_range,
} = gather_task.await?;
log::debug!(
"Events:\n{}\nExcerpt:\n{:?}",
values.input_events,
values.input_excerpt
body.input_events,
body.input_excerpt
);
let body = PredictEditsBody {
input_events: values.input_events.clone(),
input_excerpt: values.input_excerpt.clone(),
speculated_output: Some(values.speculated_output),
outline: Some(values.input_outline.clone()),
can_collect_data,
diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
diagnostic_groups
.into_iter()
.map(|(name, diagnostic_group)| {
Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
})
.collect::<Result<Vec<_>>>()
.log_err()
}),
};
let input_outline = body.outline.clone().unwrap_or_default();
let input_events = body.input_events.clone();
let input_excerpt = body.input_excerpt.clone();
let response = perform_predict_edits(PerformPredictEditsParams {
client,
@ -546,13 +490,13 @@ impl Zeta {
response,
buffer,
&snapshot,
values.editable_range,
editable_range,
cursor_offset,
path,
values.input_outline,
values.input_events,
values.input_excerpt,
request_sent_at,
full_path,
input_outline,
input_events,
input_excerpt,
buffer_snapshotted_at,
&cx,
)
.await
@ -751,7 +695,7 @@ and then another
)
}
fn perform_predict_edits(
pub fn perform_predict_edits(
params: PerformPredictEditsParams,
) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
async move {
@ -906,7 +850,7 @@ and then another
input_outline: String,
input_events: String,
input_excerpt: String,
request_sent_at: Instant,
buffer_snapshotted_at: Instant,
cx: &AsyncApp,
) -> Task<Result<Option<InlineCompletion>>> {
let snapshot = snapshot.clone();
@ -952,7 +896,7 @@ and then another
input_events: input_events.into(),
input_excerpt: input_excerpt.into(),
output_excerpt,
request_sent_at,
buffer_snapshotted_at,
response_received_at: Instant::now(),
}))
})
@ -1136,7 +1080,7 @@ and then another
}
}
struct PerformPredictEditsParams {
pub struct PerformPredictEditsParams {
pub client: Arc<Client>,
pub llm_token: LlmApiToken,
pub app_version: SemanticVersion,
@ -1211,6 +1155,77 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
.sum()
}
pub struct GatherContextOutput {
pub body: PredictEditsBody,
pub editable_range: Range<usize>,
}
pub fn gather_context(
project: Option<&Entity<Project>>,
full_path_str: String,
snapshot: &BufferSnapshot,
cursor_point: language::Point,
make_events_prompt: impl FnOnce() -> String + Send + 'static,
can_collect_data: bool,
cx: &App,
) -> Task<Result<GatherContextOutput>> {
let local_lsp_store =
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
let diagnostic_groups: Vec<(String, serde_json::Value)> =
if let Some(local_lsp_store) = local_lsp_store {
snapshot
.diagnostic_groups(None)
.into_iter()
.filter_map(|(language_server_id, diagnostic_group)| {
let language_server =
local_lsp_store.running_language_server_for_id(language_server_id)?;
let diagnostic_group = diagnostic_group.resolve::<usize>(&snapshot);
let language_server_name = language_server.name().to_string();
let serialized = serde_json::to_value(diagnostic_group).unwrap();
Some((language_server_name, serialized))
})
.collect::<Vec<_>>()
} else {
Vec::new()
};
cx.background_spawn({
let snapshot = snapshot.clone();
async move {
let diagnostic_groups = if diagnostic_groups.is_empty() {
None
} else {
Some(diagnostic_groups)
};
let input_excerpt = excerpt_for_cursor_position(
cursor_point,
&full_path_str,
&snapshot,
MAX_REWRITE_TOKENS,
MAX_CONTEXT_TOKENS,
);
let input_events = make_events_prompt();
let input_outline = prompt_for_outline(&snapshot);
let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
let body = PredictEditsBody {
input_events,
input_excerpt: input_excerpt.prompt,
speculated_output: Some(input_excerpt.speculated_output),
outline: Some(input_outline),
can_collect_data,
diagnostic_groups,
};
Ok(GatherContextOutput {
body,
editable_range,
})
}
})
}
fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
let mut input_outline = String::new();
@ -1261,7 +1276,7 @@ struct RegisteredBuffer {
}
#[derive(Clone)]
enum Event {
pub enum Event {
BufferChange {
old_snapshot: BufferSnapshot,
new_snapshot: BufferSnapshot,
@ -1845,7 +1860,7 @@ mod tests {
input_events: "".into(),
input_excerpt: "".into(),
output_excerpt: "".into(),
request_sent_at: Instant::now(),
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
};