zeta: Add CLI tool for querying edit predictions and related context (#35491)
Release Notes: - N/A --------- Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
561ccf86aa
commit
6052115825
9 changed files with 719 additions and 102 deletions
|
@ -146,14 +146,14 @@ pub struct InlineCompletion {
|
|||
input_events: Arc<str>,
|
||||
input_excerpt: Arc<str>,
|
||||
output_excerpt: Arc<str>,
|
||||
request_sent_at: Instant,
|
||||
buffer_snapshotted_at: Instant,
|
||||
response_received_at: Instant,
|
||||
}
|
||||
|
||||
impl InlineCompletion {
|
||||
fn latency(&self) -> Duration {
|
||||
self.response_received_at
|
||||
.duration_since(self.request_sent_at)
|
||||
.duration_since(self.buffer_snapshotted_at)
|
||||
}
|
||||
|
||||
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
|
||||
|
@ -391,104 +391,48 @@ impl Zeta {
|
|||
+ Send
|
||||
+ 'static,
|
||||
{
|
||||
let buffer = buffer.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let snapshot = self.report_changes_for_buffer(&buffer, cx);
|
||||
let diagnostic_groups = snapshot.diagnostic_groups(None);
|
||||
let cursor_point = cursor.to_point(&snapshot);
|
||||
let cursor_offset = cursor_point.to_offset(&snapshot);
|
||||
let events = self.events.clone();
|
||||
let path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|f| Arc::from(f.full_path(cx).as_path()))
|
||||
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
|
||||
|
||||
let zeta = cx.entity();
|
||||
let events = self.events.clone();
|
||||
let client = self.client.clone();
|
||||
let llm_token = self.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
let buffer = buffer.clone();
|
||||
|
||||
let local_lsp_store =
|
||||
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
|
||||
let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
|
||||
Some(
|
||||
diagnostic_groups
|
||||
.into_iter()
|
||||
.filter_map(|(language_server_id, diagnostic_group)| {
|
||||
let language_server =
|
||||
local_lsp_store.running_language_server_for_id(language_server_id)?;
|
||||
|
||||
Some((
|
||||
language_server.name(),
|
||||
diagnostic_group.resolve::<usize>(&snapshot),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
.file()
|
||||
.map(|f| Arc::from(f.full_path(cx).as_path()))
|
||||
.unwrap_or_else(|| Arc::from(Path::new("untitled")));
|
||||
let full_path_str = full_path.to_string_lossy().to_string();
|
||||
let cursor_point = cursor.to_point(&snapshot);
|
||||
let cursor_offset = cursor_point.to_offset(&snapshot);
|
||||
let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
|
||||
let gather_task = gather_context(
|
||||
project,
|
||||
full_path_str,
|
||||
&snapshot,
|
||||
cursor_point,
|
||||
make_events_prompt,
|
||||
can_collect_data,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let request_sent_at = Instant::now();
|
||||
|
||||
struct BackgroundValues {
|
||||
input_events: String,
|
||||
input_excerpt: String,
|
||||
speculated_output: String,
|
||||
editable_range: Range<usize>,
|
||||
input_outline: String,
|
||||
}
|
||||
|
||||
let values = cx
|
||||
.background_spawn({
|
||||
let snapshot = snapshot.clone();
|
||||
let path = path.clone();
|
||||
async move {
|
||||
let path = path.to_string_lossy();
|
||||
let input_excerpt = excerpt_for_cursor_position(
|
||||
cursor_point,
|
||||
&path,
|
||||
&snapshot,
|
||||
MAX_REWRITE_TOKENS,
|
||||
MAX_CONTEXT_TOKENS,
|
||||
);
|
||||
let input_events = prompt_for_events(&events, MAX_EVENT_TOKENS);
|
||||
let input_outline = prompt_for_outline(&snapshot);
|
||||
|
||||
anyhow::Ok(BackgroundValues {
|
||||
input_events,
|
||||
input_excerpt: input_excerpt.prompt,
|
||||
speculated_output: input_excerpt.speculated_output,
|
||||
editable_range: input_excerpt.editable_range.to_offset(&snapshot),
|
||||
input_outline,
|
||||
})
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
let GatherContextOutput {
|
||||
body,
|
||||
editable_range,
|
||||
} = gather_task.await?;
|
||||
|
||||
log::debug!(
|
||||
"Events:\n{}\nExcerpt:\n{:?}",
|
||||
values.input_events,
|
||||
values.input_excerpt
|
||||
body.input_events,
|
||||
body.input_excerpt
|
||||
);
|
||||
|
||||
let body = PredictEditsBody {
|
||||
input_events: values.input_events.clone(),
|
||||
input_excerpt: values.input_excerpt.clone(),
|
||||
speculated_output: Some(values.speculated_output),
|
||||
outline: Some(values.input_outline.clone()),
|
||||
can_collect_data,
|
||||
diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
|
||||
diagnostic_groups
|
||||
.into_iter()
|
||||
.map(|(name, diagnostic_group)| {
|
||||
Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()
|
||||
.log_err()
|
||||
}),
|
||||
};
|
||||
let input_outline = body.outline.clone().unwrap_or_default();
|
||||
let input_events = body.input_events.clone();
|
||||
let input_excerpt = body.input_excerpt.clone();
|
||||
|
||||
let response = perform_predict_edits(PerformPredictEditsParams {
|
||||
client,
|
||||
|
@ -546,13 +490,13 @@ impl Zeta {
|
|||
response,
|
||||
buffer,
|
||||
&snapshot,
|
||||
values.editable_range,
|
||||
editable_range,
|
||||
cursor_offset,
|
||||
path,
|
||||
values.input_outline,
|
||||
values.input_events,
|
||||
values.input_excerpt,
|
||||
request_sent_at,
|
||||
full_path,
|
||||
input_outline,
|
||||
input_events,
|
||||
input_excerpt,
|
||||
buffer_snapshotted_at,
|
||||
&cx,
|
||||
)
|
||||
.await
|
||||
|
@ -751,7 +695,7 @@ and then another
|
|||
)
|
||||
}
|
||||
|
||||
fn perform_predict_edits(
|
||||
pub fn perform_predict_edits(
|
||||
params: PerformPredictEditsParams,
|
||||
) -> impl Future<Output = Result<(PredictEditsResponse, Option<EditPredictionUsage>)>> {
|
||||
async move {
|
||||
|
@ -906,7 +850,7 @@ and then another
|
|||
input_outline: String,
|
||||
input_events: String,
|
||||
input_excerpt: String,
|
||||
request_sent_at: Instant,
|
||||
buffer_snapshotted_at: Instant,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Option<InlineCompletion>>> {
|
||||
let snapshot = snapshot.clone();
|
||||
|
@ -952,7 +896,7 @@ and then another
|
|||
input_events: input_events.into(),
|
||||
input_excerpt: input_excerpt.into(),
|
||||
output_excerpt,
|
||||
request_sent_at,
|
||||
buffer_snapshotted_at,
|
||||
response_received_at: Instant::now(),
|
||||
}))
|
||||
})
|
||||
|
@ -1136,7 +1080,7 @@ and then another
|
|||
}
|
||||
}
|
||||
|
||||
struct PerformPredictEditsParams {
|
||||
pub struct PerformPredictEditsParams {
|
||||
pub client: Arc<Client>,
|
||||
pub llm_token: LlmApiToken,
|
||||
pub app_version: SemanticVersion,
|
||||
|
@ -1211,6 +1155,77 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
|
|||
.sum()
|
||||
}
|
||||
|
||||
pub struct GatherContextOutput {
|
||||
pub body: PredictEditsBody,
|
||||
pub editable_range: Range<usize>,
|
||||
}
|
||||
|
||||
pub fn gather_context(
|
||||
project: Option<&Entity<Project>>,
|
||||
full_path_str: String,
|
||||
snapshot: &BufferSnapshot,
|
||||
cursor_point: language::Point,
|
||||
make_events_prompt: impl FnOnce() -> String + Send + 'static,
|
||||
can_collect_data: bool,
|
||||
cx: &App,
|
||||
) -> Task<Result<GatherContextOutput>> {
|
||||
let local_lsp_store =
|
||||
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
|
||||
let diagnostic_groups: Vec<(String, serde_json::Value)> =
|
||||
if let Some(local_lsp_store) = local_lsp_store {
|
||||
snapshot
|
||||
.diagnostic_groups(None)
|
||||
.into_iter()
|
||||
.filter_map(|(language_server_id, diagnostic_group)| {
|
||||
let language_server =
|
||||
local_lsp_store.running_language_server_for_id(language_server_id)?;
|
||||
let diagnostic_group = diagnostic_group.resolve::<usize>(&snapshot);
|
||||
let language_server_name = language_server.name().to_string();
|
||||
let serialized = serde_json::to_value(diagnostic_group).unwrap();
|
||||
Some((language_server_name, serialized))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
cx.background_spawn({
|
||||
let snapshot = snapshot.clone();
|
||||
async move {
|
||||
let diagnostic_groups = if diagnostic_groups.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(diagnostic_groups)
|
||||
};
|
||||
|
||||
let input_excerpt = excerpt_for_cursor_position(
|
||||
cursor_point,
|
||||
&full_path_str,
|
||||
&snapshot,
|
||||
MAX_REWRITE_TOKENS,
|
||||
MAX_CONTEXT_TOKENS,
|
||||
);
|
||||
let input_events = make_events_prompt();
|
||||
let input_outline = prompt_for_outline(&snapshot);
|
||||
let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
|
||||
|
||||
let body = PredictEditsBody {
|
||||
input_events,
|
||||
input_excerpt: input_excerpt.prompt,
|
||||
speculated_output: Some(input_excerpt.speculated_output),
|
||||
outline: Some(input_outline),
|
||||
can_collect_data,
|
||||
diagnostic_groups,
|
||||
};
|
||||
|
||||
Ok(GatherContextOutput {
|
||||
body,
|
||||
editable_range,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
|
||||
let mut input_outline = String::new();
|
||||
|
||||
|
@ -1261,7 +1276,7 @@ struct RegisteredBuffer {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum Event {
|
||||
pub enum Event {
|
||||
BufferChange {
|
||||
old_snapshot: BufferSnapshot,
|
||||
new_snapshot: BufferSnapshot,
|
||||
|
@ -1845,7 +1860,7 @@ mod tests {
|
|||
input_events: "".into(),
|
||||
input_excerpt: "".into(),
|
||||
output_excerpt: "".into(),
|
||||
request_sent_at: Instant::now(),
|
||||
buffer_snapshotted_at: Instant::now(),
|
||||
response_received_at: Instant::now(),
|
||||
};
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue