zeta: Send up diagnostics with prediction requests (#24384)
This PR makes it so we send up the diagnostic groups as additional data with the edit prediction request. We're not yet making use of them, but we are recording them so we can use them later (e.g., to train the model). Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
13089d7ec6
commit
09967ac3d0
16 changed files with 145 additions and 31 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
@ -6428,6 +6428,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"gpui",
|
||||
"language",
|
||||
"project",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -7160,7 +7161,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"windows-targets 0.48.5",
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -10215,7 +10216,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
|
||||
dependencies = [
|
||||
"bytes 1.10.0",
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.12.1",
|
||||
"log",
|
||||
"multimap 0.10.0",
|
||||
|
@ -15484,7 +15485,7 @@ version = "0.1.9"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
||||
dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -16729,11 +16730,12 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zed_llm_client"
|
||||
version = "0.1.2"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ab9496dc5c80b2c5fb9654a76d7208d31b53130fb282085fcdde07653831843"
|
||||
checksum = "9ea4d8ead1e1158e5ebdd6735df25973781da70de5c8008e3a13595865ca4f31"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -16956,6 +16958,7 @@ dependencies = [
|
|||
"log",
|
||||
"menu",
|
||||
"postage",
|
||||
"project",
|
||||
"regex",
|
||||
"reqwest_client",
|
||||
"rpc",
|
||||
|
|
|
@ -557,7 +557,7 @@ wasmtime = { version = "24", default-features = false, features = [
|
|||
wasmtime-wasi = "24"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.201"
|
||||
zed_llm_client = "0.1.1"
|
||||
zed_llm_client = "0.2"
|
||||
zstd = "0.11"
|
||||
metal = "0.31"
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ use anyhow::Result;
|
|||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use inline_completion::{Direction, InlineCompletion, InlineCompletionProvider};
|
||||
use language::{language_settings::AllLanguageSettings, Buffer, OffsetRangeExt, ToOffset};
|
||||
use project::Project;
|
||||
use settings::Settings;
|
||||
use std::{path::Path, time::Duration};
|
||||
|
||||
|
@ -79,6 +80,7 @@ impl InlineCompletionProvider for CopilotCompletionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
|
|
|
@ -4648,7 +4648,13 @@ impl Editor {
|
|||
}
|
||||
|
||||
self.update_visible_inline_completion(window, cx);
|
||||
provider.refresh(buffer, cursor_buffer_position, debounce, cx);
|
||||
provider.refresh(
|
||||
self.project.clone(),
|
||||
buffer,
|
||||
cursor_buffer_position,
|
||||
debounce,
|
||||
cx,
|
||||
);
|
||||
Some(())
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ use indoc::indoc;
|
|||
use inline_completion::InlineCompletionProvider;
|
||||
use language::{Language, LanguageConfig};
|
||||
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
|
||||
use project::Project;
|
||||
use std::{num::NonZeroU32, ops::Range, sync::Arc};
|
||||
use text::{Point, ToOffset};
|
||||
|
||||
|
@ -394,6 +395,7 @@ impl InlineCompletionProvider for FakeInlineCompletionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
_buffer: gpui::Entity<language::Buffer>,
|
||||
_cursor_position: language::Anchor,
|
||||
_debounce: bool,
|
||||
|
|
|
@ -14,3 +14,4 @@ path = "src/inline_completion.rs"
|
|||
[dependencies]
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
project.workspace = true
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use gpui::{App, Context, Entity};
|
||||
use language::Buffer;
|
||||
use project::Project;
|
||||
use std::ops::Range;
|
||||
|
||||
// TODO: Find a better home for `Direction`.
|
||||
|
@ -58,6 +59,7 @@ pub trait InlineCompletionProvider: 'static + Sized {
|
|||
fn is_refreshing(&self) -> bool;
|
||||
fn refresh(
|
||||
&mut self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
|
@ -101,6 +103,7 @@ pub trait InlineCompletionProviderHandle {
|
|||
fn is_refreshing(&self, cx: &App) -> bool;
|
||||
fn refresh(
|
||||
&self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
|
@ -174,13 +177,14 @@ where
|
|||
|
||||
fn refresh(
|
||||
&self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: language::Anchor,
|
||||
debounce: bool,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.update(cx, |this, cx| {
|
||||
this.refresh(buffer, cursor_position, debounce, cx)
|
||||
this.refresh(project, buffer, cursor_position, debounce, cx)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -197,7 +197,7 @@ struct SelectionSet {
|
|||
}
|
||||
|
||||
/// A diagnostic associated with a certain range of a buffer.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Diagnostic {
|
||||
/// The name of the service that produced this diagnostic.
|
||||
pub source: Option<String>,
|
||||
|
|
|
@ -2,6 +2,7 @@ use crate::{range_to_lsp, Diagnostic};
|
|||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use lsp::LanguageServerId;
|
||||
use serde::Serialize;
|
||||
use std::{
|
||||
cmp::{Ordering, Reverse},
|
||||
iter,
|
||||
|
@ -25,7 +26,7 @@ pub struct DiagnosticSet {
|
|||
/// the diagnostics are stored internally as [`Anchor`]s, but can be
|
||||
/// resolved to different coordinates types like [`usize`] byte offsets or
|
||||
/// [`Point`](gpui::Point)s.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
|
||||
pub struct DiagnosticEntry<T> {
|
||||
/// The range of the buffer where the diagnostic applies.
|
||||
pub range: Range<T>,
|
||||
|
@ -35,7 +36,7 @@ pub struct DiagnosticEntry<T> {
|
|||
|
||||
/// A group of related diagnostics, ordered by their start position
|
||||
/// in the buffer.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct DiagnosticGroup<T> {
|
||||
/// The diagnostics.
|
||||
pub entries: Vec<DiagnosticEntry<T>>,
|
||||
|
@ -43,6 +44,20 @@ pub struct DiagnosticGroup<T> {
|
|||
pub primary_ix: usize,
|
||||
}
|
||||
|
||||
impl DiagnosticGroup<Anchor> {
|
||||
/// Converts the entries in this [`DiagnosticGroup`] to a different buffer coordinate type.
|
||||
pub fn resolve<O: FromAnchor>(&self, buffer: &text::BufferSnapshot) -> DiagnosticGroup<O> {
|
||||
DiagnosticGroup {
|
||||
entries: self
|
||||
.entries
|
||||
.iter()
|
||||
.map(|entry| entry.resolve(buffer))
|
||||
.collect(),
|
||||
primary_ix: self.primary_ix,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Summary {
|
||||
start: Anchor,
|
||||
|
|
|
@ -32,10 +32,7 @@ use gpui::{App, AsyncApp, Entity, SharedString, Task};
|
|||
pub use highlight_map::HighlightMap;
|
||||
use http_client::HttpClient;
|
||||
pub use language_registry::{LanguageName, LoadedLanguage};
|
||||
use lsp::{
|
||||
CodeActionKind, InitializeParams, LanguageServerBinary, LanguageServerBinaryOptions,
|
||||
LanguageServerName,
|
||||
};
|
||||
use lsp::{CodeActionKind, InitializeParams, LanguageServerBinary, LanguageServerBinaryOptions};
|
||||
use parking_lot::Mutex;
|
||||
use regex::Regex;
|
||||
use schemars::{
|
||||
|
@ -73,12 +70,12 @@ use util::serde::default_true;
|
|||
|
||||
pub use buffer::Operation;
|
||||
pub use buffer::*;
|
||||
pub use diagnostic_set::DiagnosticEntry;
|
||||
pub use diagnostic_set::{DiagnosticEntry, DiagnosticGroup};
|
||||
pub use language_registry::{
|
||||
AvailableLanguage, LanguageNotFound, LanguageQueries, LanguageRegistry,
|
||||
LanguageServerBinaryStatus, QUERY_FILENAME_PREFIXES,
|
||||
};
|
||||
pub use lsp::LanguageServerId;
|
||||
pub use lsp::{LanguageServerId, LanguageServerName};
|
||||
pub use outline::*;
|
||||
pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer, ToTreeSitterPoint, TreeSitterOptions};
|
||||
pub use text::{AnchorRangeExt, LineEnding};
|
||||
|
|
|
@ -2,7 +2,6 @@ use anyhow::Context as _;
|
|||
use gpui::{App, UpdateGlobal};
|
||||
use json::json_task_context;
|
||||
pub use language::*;
|
||||
use lsp::LanguageServerName;
|
||||
use node_runtime::NodeRuntime;
|
||||
use python::{PythonContextProvider, PythonToolchainProvider};
|
||||
use rust_embed::RustEmbed;
|
||||
|
|
|
@ -166,6 +166,19 @@ pub struct LocalLspStore {
|
|||
}
|
||||
|
||||
impl LocalLspStore {
|
||||
/// Returns the running language server for the given ID. Note if the language server is starting, it will not be returned.
|
||||
pub fn running_language_server_for_id(
|
||||
&self,
|
||||
id: LanguageServerId,
|
||||
) -> Option<&Arc<LanguageServer>> {
|
||||
let language_server_state = self.language_servers.get(&id)?;
|
||||
|
||||
match language_server_state {
|
||||
LanguageServerState::Running { server, .. } => Some(server),
|
||||
LanguageServerState::Starting(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn start_language_server(
|
||||
&mut self,
|
||||
worktree_handle: &Entity<Worktree>,
|
||||
|
|
|
@ -22,6 +22,7 @@ inline_completion.workspace = true
|
|||
language.workspace = true
|
||||
log.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
|
|
|
@ -4,6 +4,7 @@ use futures::StreamExt as _;
|
|||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use inline_completion::{Direction, InlineCompletion, InlineCompletionProvider};
|
||||
use language::{Anchor, Buffer, BufferSnapshot};
|
||||
use project::Project;
|
||||
use std::{
|
||||
ops::{AddAssign, Range},
|
||||
path::Path,
|
||||
|
@ -123,6 +124,7 @@ impl InlineCompletionProvider for SupermavenCompletionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
buffer_handle: Entity<Buffer>,
|
||||
cursor_position: Anchor,
|
||||
debounce: bool,
|
||||
|
|
|
@ -37,6 +37,7 @@ language_models.workspace = true
|
|||
log.workspace = true
|
||||
menu.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
|
|
@ -30,6 +30,7 @@ use language::{
|
|||
};
|
||||
use language_models::LlmApiToken;
|
||||
use postage::watch;
|
||||
use project::Project;
|
||||
use settings::WorktreeId;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
|
@ -363,6 +364,7 @@ impl Zeta {
|
|||
|
||||
pub fn request_completion_impl<F, R>(
|
||||
&mut self,
|
||||
project: Option<&Entity<Project>>,
|
||||
buffer: &Entity<Buffer>,
|
||||
cursor: language::Anchor,
|
||||
can_collect_data: bool,
|
||||
|
@ -374,6 +376,7 @@ impl Zeta {
|
|||
R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
|
||||
{
|
||||
let snapshot = self.report_changes_for_buffer(&buffer, cx);
|
||||
let diagnostic_groups = snapshot.diagnostic_groups(None);
|
||||
let cursor_point = cursor.to_point(&snapshot);
|
||||
let cursor_offset = cursor_point.to_offset(&snapshot);
|
||||
let events = self.events.clone();
|
||||
|
@ -387,10 +390,39 @@ impl Zeta {
|
|||
let is_staff = cx.is_staff();
|
||||
|
||||
let buffer = buffer.clone();
|
||||
|
||||
let local_lsp_store =
|
||||
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
|
||||
let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store {
|
||||
Some(
|
||||
diagnostic_groups
|
||||
.into_iter()
|
||||
.filter_map(|(language_server_id, diagnostic_group)| {
|
||||
let language_server =
|
||||
local_lsp_store.running_language_server_for_id(language_server_id)?;
|
||||
|
||||
Some((
|
||||
language_server.name(),
|
||||
diagnostic_group.resolve::<usize>(&snapshot),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
cx.spawn(|_, cx| async move {
|
||||
let request_sent_at = Instant::now();
|
||||
|
||||
let (input_events, input_excerpt, excerpt_range, input_outline) = cx
|
||||
struct BackgroundValues {
|
||||
input_events: String,
|
||||
input_excerpt: String,
|
||||
excerpt_range: Range<usize>,
|
||||
input_outline: String,
|
||||
}
|
||||
|
||||
let values = cx
|
||||
.background_executor()
|
||||
.spawn({
|
||||
let snapshot = snapshot.clone();
|
||||
|
@ -419,18 +451,36 @@ impl Zeta {
|
|||
// is not counted towards TOTAL_BYTE_LIMIT.
|
||||
let input_outline = prompt_for_outline(&snapshot);
|
||||
|
||||
anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline))
|
||||
anyhow::Ok(BackgroundValues {
|
||||
input_events,
|
||||
input_excerpt,
|
||||
excerpt_range,
|
||||
input_outline,
|
||||
})
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
|
||||
log::debug!(
|
||||
"Events:\n{}\nExcerpt:\n{}",
|
||||
values.input_events,
|
||||
values.input_excerpt
|
||||
);
|
||||
|
||||
let body = PredictEditsBody {
|
||||
input_events: input_events.clone(),
|
||||
input_excerpt: input_excerpt.clone(),
|
||||
outline: Some(input_outline.clone()),
|
||||
input_events: values.input_events.clone(),
|
||||
input_excerpt: values.input_excerpt.clone(),
|
||||
outline: Some(values.input_outline.clone()),
|
||||
can_collect_data,
|
||||
diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| {
|
||||
diagnostic_groups
|
||||
.into_iter()
|
||||
.map(|(name, diagnostic_group)| {
|
||||
Ok((name.to_string(), serde_json::to_value(diagnostic_group)?))
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()
|
||||
.log_err()
|
||||
}),
|
||||
};
|
||||
|
||||
let response = perform_predict_edits(client, llm_token, is_staff, body).await?;
|
||||
|
@ -442,12 +492,12 @@ impl Zeta {
|
|||
output_excerpt,
|
||||
buffer,
|
||||
&snapshot,
|
||||
excerpt_range,
|
||||
values.excerpt_range,
|
||||
cursor_offset,
|
||||
path,
|
||||
input_outline,
|
||||
input_events,
|
||||
input_excerpt,
|
||||
values.input_outline,
|
||||
values.input_events,
|
||||
values.input_excerpt,
|
||||
request_sent_at,
|
||||
&cx,
|
||||
)
|
||||
|
@ -466,11 +516,13 @@ impl Zeta {
|
|||
and then another
|
||||
"#};
|
||||
|
||||
let project = None;
|
||||
let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
|
||||
let position = buffer.read(cx).anchor_before(Point::new(1, 0));
|
||||
|
||||
let completion_tasks = vec![
|
||||
self.fake_completion(
|
||||
project,
|
||||
&buffer,
|
||||
position,
|
||||
PredictEditsResponse {
|
||||
|
@ -486,6 +538,7 @@ and then another
|
|||
cx,
|
||||
),
|
||||
self.fake_completion(
|
||||
project,
|
||||
&buffer,
|
||||
position,
|
||||
PredictEditsResponse {
|
||||
|
@ -501,6 +554,7 @@ and then another
|
|||
cx,
|
||||
),
|
||||
self.fake_completion(
|
||||
project,
|
||||
&buffer,
|
||||
position,
|
||||
PredictEditsResponse {
|
||||
|
@ -517,6 +571,7 @@ and then another
|
|||
cx,
|
||||
),
|
||||
self.fake_completion(
|
||||
project,
|
||||
&buffer,
|
||||
position,
|
||||
PredictEditsResponse {
|
||||
|
@ -533,6 +588,7 @@ and then another
|
|||
cx,
|
||||
),
|
||||
self.fake_completion(
|
||||
project,
|
||||
&buffer,
|
||||
position,
|
||||
PredictEditsResponse {
|
||||
|
@ -548,6 +604,7 @@ and then another
|
|||
cx,
|
||||
),
|
||||
self.fake_completion(
|
||||
project,
|
||||
&buffer,
|
||||
position,
|
||||
PredictEditsResponse {
|
||||
|
@ -562,6 +619,7 @@ and then another
|
|||
cx,
|
||||
),
|
||||
self.fake_completion(
|
||||
project,
|
||||
&buffer,
|
||||
position,
|
||||
PredictEditsResponse {
|
||||
|
@ -594,6 +652,7 @@ and then another
|
|||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn fake_completion(
|
||||
&mut self,
|
||||
project: Option<&Entity<Project>>,
|
||||
buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
response: PredictEditsResponse,
|
||||
|
@ -601,19 +660,21 @@ and then another
|
|||
) -> Task<Result<Option<InlineCompletion>>> {
|
||||
use std::future::ready;
|
||||
|
||||
self.request_completion_impl(buffer, position, false, cx, |_, _, _, _| {
|
||||
self.request_completion_impl(project, buffer, position, false, cx, |_, _, _, _| {
|
||||
ready(Ok(response))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn request_completion(
|
||||
&mut self,
|
||||
project: Option<&Entity<Project>>,
|
||||
buffer: &Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
can_collect_data: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Option<InlineCompletion>>> {
|
||||
self.request_completion_impl(
|
||||
project,
|
||||
buffer,
|
||||
position,
|
||||
can_collect_data,
|
||||
|
@ -1494,6 +1555,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
position: language::Anchor,
|
||||
_debounce: bool,
|
||||
|
@ -1529,7 +1591,13 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
|
|||
let completion_request = this.update(&mut cx, |this, cx| {
|
||||
this.last_request_timestamp = Instant::now();
|
||||
this.zeta.update(cx, |zeta, cx| {
|
||||
zeta.request_completion(&buffer, position, can_collect_data, cx)
|
||||
zeta.request_completion(
|
||||
project.as_ref(),
|
||||
&buffer,
|
||||
position,
|
||||
can_collect_data,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
|
@ -1858,7 +1926,7 @@ mod tests {
|
|||
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
|
||||
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
|
||||
let completion_task = zeta.update(cx, |zeta, cx| {
|
||||
zeta.request_completion(&buffer, cursor, false, cx)
|
||||
zeta.request_completion(None, &buffer, cursor, false, cx)
|
||||
});
|
||||
|
||||
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue