zeta: Record recently active files when data collection is enabled

This commit is contained in:
Michael Sloan 2025-08-25 21:56:22 -06:00
parent f8667a8379
commit b696a32518
No known key found for this signature in database
4 changed files with 140 additions and 46 deletions

View file

@ -152,6 +152,8 @@ pub struct PredictEditsBody {
/// Info about the git repository state, only present when can_collect_data is true. /// Info about the git repository state, only present when can_collect_data is true.
#[serde(skip_serializing_if = "Option::is_none", default)] #[serde(skip_serializing_if = "Option::is_none", default)]
pub git_info: Option<PredictEditsGitInfo>, pub git_info: Option<PredictEditsGitInfo>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub recent_files: Option<Vec<PredictEditsRecentFile>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -167,6 +169,12 @@ pub struct PredictEditsGitInfo {
pub remote_upstream_url: Option<String>, pub remote_upstream_url: Option<String>,
} }
/// A recently active file, sent in the optional `recent_files` list of `PredictEditsBody`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRecentFile {
/// Path of the file expressed as a repository path (always valid UTF-8, since non-UTF-8 paths
/// are not reported). NOTE(review): presumably relative to the repository root — confirm.
pub repo_path: String,
/// Milliseconds between the file last becoming the active entry and the reference instant
/// used when building the request. Ages that do not fit in a `u32` are not reported.
pub active_to_now_ms: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse { pub struct PredictEditsResponse {
pub request_id: Uuid, pub request_id: Uuid,

View file

@ -204,10 +204,7 @@ fn assign_edit_prediction_provider(
} }
} }
let workspace = window let workspace = window.root::<Workspace>().flatten();
.root::<Workspace>()
.flatten()
.map(|workspace| workspace.downgrade());
let zeta = let zeta =
zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx); zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx);

View file

@ -6,19 +6,21 @@ mod onboarding_modal;
mod onboarding_telemetry; mod onboarding_telemetry;
mod rate_completion_modal; mod rate_completion_modal;
use arrayvec::ArrayVec;
pub(crate) use completion_diff_element::*; pub(crate) use completion_diff_element::*;
use db::kvp::{Dismissable, KEY_VALUE_STORE}; use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction::DataCollectionState; use edit_prediction::DataCollectionState;
pub use init::*; pub use init::*;
use license_detection::LicenseDetectionWatcher; use license_detection::LicenseDetectionWatcher;
use project::git_store::Repository;
pub use rate_completion_modal::*; pub use rate_completion_modal::*;
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore}; use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::{ use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
ZED_VERSION_HEADER_NAME,
}; };
use collections::{HashMap, HashSet, VecDeque}; use collections::{HashMap, HashSet, VecDeque};
use futures::AsyncReadExt; use futures::AsyncReadExt;
@ -32,7 +34,7 @@ use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff, Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
}; };
use language_model::{LlmApiToken, RefreshLlmTokenListener}; use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::{Project, ProjectPath}; use project::{Project, ProjectEntryId, ProjectPath};
use release_channel::AppVersion; use release_channel::AppVersion;
use settings::WorktreeId; use settings::WorktreeId;
use std::str::FromStr; use std::str::FromStr;
@ -70,6 +72,12 @@ const MAX_DIAGNOSTIC_GROUPS: usize = 10;
/// Maximum number of events to track. /// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16; const MAX_EVENT_COUNT: usize = 16;
/// Maximum number of recent files to track.
const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
/// Maximum number of edit predictions to store for feedback.
const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
actions!( actions!(
edit_prediction, edit_prediction,
[ [
@ -212,7 +220,7 @@ impl std::fmt::Debug for EditPrediction {
} }
pub struct Zeta { pub struct Zeta {
workspace: Option<WeakEntity<Workspace>>, workspace: WeakEntity<Workspace>,
client: Arc<Client>, client: Arc<Client>,
events: VecDeque<Event>, events: VecDeque<Event>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>, registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
@ -225,6 +233,7 @@ pub struct Zeta {
update_required: bool, update_required: bool,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>, license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
recent_project_entries: VecDeque<(ProjectEntryId, Instant)>,
} }
impl Zeta { impl Zeta {
@ -233,7 +242,7 @@ impl Zeta {
} }
pub fn register( pub fn register(
workspace: Option<WeakEntity<Workspace>>, workspace: Option<Entity<Workspace>>,
worktree: Option<Entity<Worktree>>, worktree: Option<Entity<Worktree>>,
client: Arc<Client>, client: Arc<Client>,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
@ -266,7 +275,7 @@ impl Zeta {
} }
fn new( fn new(
workspace: Option<WeakEntity<Workspace>>, workspace: Option<Entity<Workspace>>,
client: Arc<Client>, client: Arc<Client>,
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
@ -276,11 +285,27 @@ impl Zeta {
let data_collection_choice = Self::load_data_collection_choices(); let data_collection_choice = Self::load_data_collection_choices();
let data_collection_choice = cx.new(|_| data_collection_choice); let data_collection_choice = cx.new(|_| data_collection_choice);
if let Some(workspace) = &workspace {
cx.subscribe(
&workspace.read(cx).project().clone(),
|this, _workspace, event, _cx| match event {
project::Event::ActiveEntryChanged(Some(project_entry_id)) => {
this.push_recent_project_entry(*project_entry_id)
}
_ => {}
},
)
.detach();
}
Self { Self {
workspace, workspace: workspace.map_or_else(
|| WeakEntity::new_invalid(),
|workspace| workspace.downgrade(),
),
client, client,
events: VecDeque::new(), events: VecDeque::with_capacity(MAX_EVENT_COUNT),
shown_completions: VecDeque::new(), shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
rated_completions: HashSet::default(), rated_completions: HashSet::default(),
registered_buffers: HashMap::default(), registered_buffers: HashMap::default(),
data_collection_choice, data_collection_choice,
@ -300,6 +325,7 @@ impl Zeta {
update_required: false, update_required: false,
license_detection_watchers: HashMap::default(), license_detection_watchers: HashMap::default(),
user_store, user_store,
recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT),
} }
} }
@ -327,11 +353,12 @@ impl Zeta {
} }
} }
self.events.push_back(event);
if self.events.len() >= MAX_EVENT_COUNT { if self.events.len() >= MAX_EVENT_COUNT {
// These are halved instead of popping to improve prompt caching. // These are halved instead of popping to improve prompt caching.
self.events.drain(..MAX_EVENT_COUNT / 2); self.events.drain(..MAX_EVENT_COUNT / 2);
} }
self.events.push_back(event);
} }
pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) { pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
@ -393,12 +420,17 @@ impl Zeta {
let llm_token = self.llm_token.clone(); let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx); let app_version = AppVersion::global(cx);
let git_info = if let (true, Some(project), Some(file)) = let (git_info, recent_files) = if let (true, Some(project), Some(file)) =
(can_collect_data, project, snapshot.file()) (can_collect_data, project, snapshot.file())
&& let Some(repository) =
git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
{ {
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) let repository = repository.read(cx);
let git_info = make_predict_edits_git_info(repository);
let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
(git_info, Some(recent_files))
} else { } else {
None (None, None)
}; };
let full_path: Arc<Path> = snapshot let full_path: Arc<Path> = snapshot
@ -417,6 +449,7 @@ impl Zeta {
make_events_prompt, make_events_prompt,
can_collect_data, can_collect_data,
git_info, git_info,
recent_files,
cx, cx,
); );
@ -702,12 +735,8 @@ and then another
can_collect_data: bool, can_collect_data: bool,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> { ) -> Task<Result<Option<EditPrediction>>> {
let workspace = self
.workspace
.as_ref()
.and_then(|workspace| workspace.upgrade());
self.request_completion_impl( self.request_completion_impl(
workspace, self.workspace.upgrade(),
project, project,
buffer, buffer,
position, position,
@ -1021,11 +1050,11 @@ and then another
} }
pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) { pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context<Self>) {
self.shown_completions.push_front(completion.clone()); if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT {
if self.shown_completions.len() > 50 {
let completion = self.shown_completions.pop_back().unwrap(); let completion = self.shown_completions.pop_back().unwrap();
self.rated_completions.remove(&completion.id); self.rated_completions.remove(&completion.id);
} }
self.shown_completions.push_front(completion.clone());
cx.notify(); cx.notify();
} }
@ -1099,6 +1128,63 @@ and then another
None => DataCollectionChoice::NotAnswered, None => DataCollectionChoice::NotAnswered,
} }
} }
/// Records `project_entry_id` as the most recently active project entry.
///
/// Any earlier record of the same entry is dropped so that each entry appears at most once, and
/// the oldest record is evicted once `MAX_RECENT_PROJECT_ENTRIES_COUNT` entries are tracked.
fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
    // Timestamp is taken up front, before the history is touched.
    let activated_at = Instant::now();
    let entries = &mut self.recent_project_entries;

    // Search from the back: a re-activated entry is most likely a recent one.
    let previous_ix = entries.iter().rposition(|(id, _)| *id == project_entry_id);
    if let Some(ix) = previous_ix {
        entries.remove(ix);
    }

    // Make room at the front (oldest) before appending at the back (newest).
    if entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT {
        entries.pop_front();
    }
    entries.push_back((project_entry_id, activated_at));
}
/// Builds the list of recently active files that belong to `repository`.
///
/// The tracked history is walked from back to front, so the most recently pushed entries come
/// first in the result. `now` is the instant each entry's age is measured against.
///
/// While walking, entries that can no longer be reported are pruned from the history: entries
/// whose project path cannot be resolved, whose repository path is not valid UTF-8, or whose age
/// in milliseconds does not fit in a `u32`. Entries that merely belong to a different repository
/// are skipped but kept.
///
/// Returns an empty list when the workspace cannot be read.
fn recent_files(
    &mut self,
    now: &Instant,
    repository: &Repository,
    cx: &Context<Self>,
) -> Vec<PredictEditsRecentFile> {
    let project = match self
        .workspace
        .read_with(cx, |workspace, _cx| workspace.project().clone())
    {
        Ok(project) => project,
        // Without a readable workspace no entry can be resolved to a path.
        Err(_) => return Vec::new(),
    };

    let mut recent = Vec::new();
    // Count down so that removing the current index never shifts an index still to be visited.
    let mut ix = self.recent_project_entries.len();
    while ix > 0 {
        ix -= 1;
        let (entry_id, last_active_at) = self.recent_project_entries[ix];

        let project_path = match project.read(cx).path_for_entry(entry_id, cx) {
            Some(path) => path,
            None => {
                // The entry no longer resolves to a path in the project.
                self.recent_project_entries.remove(ix);
                continue;
            }
        };

        let repo_path = match repository.project_path_to_repo_path(&project_path, cx) {
            Some(path) => path,
            // entry not removed since queries involving other repositories might occur later
            None => continue,
        };

        let repo_path_str = match repo_path.to_str() {
            Some(path) => path,
            None => {
                // paths may not be valid UTF-8
                self.recent_project_entries.remove(ix);
                continue;
            }
        };

        let elapsed_ms = now.duration_since(last_active_at).as_millis();
        let active_to_now_ms = match u32::try_from(elapsed_ms) {
            Ok(ms) => ms,
            Err(_) => {
                // Too old to be represented in the request; stop tracking it.
                self.recent_project_entries.remove(ix);
                continue;
            }
        };

        recent.push(PredictEditsRecentFile {
            repo_path: repo_path_str.to_string(),
            active_to_now_ms,
        });
    }
    recent
}
} }
pub struct PerformPredictEditsParams { pub struct PerformPredictEditsParams {
@ -1123,33 +1209,32 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
.sum() .sum()
} }
fn git_info_for_file( fn git_repository_for_file(
project: &Entity<Project>, project: &Entity<Project>,
project_path: &ProjectPath, project_path: &ProjectPath,
cx: &App, cx: &App,
) -> Option<PredictEditsGitInfo> { ) -> Option<Entity<Repository>> {
let git_store = project.read(cx).git_store().read(cx); let git_store = project.read(cx).git_store().read(cx);
if let Some((repository, _repo_path)) = git_store
git_store.repository_and_path_for_project_path(project_path, cx) .repository_and_path_for_project_path(project_path, cx)
{ .map(|(repo, _repo_path)| repo)
let repository = repository.read(cx); }
let head_sha = repository
.head_commit fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
.as_ref() let head_sha = repository
.map(|head_commit| head_commit.sha.to_string()); .head_commit
let remote_origin_url = repository.remote_origin_url.clone(); .as_ref()
let remote_upstream_url = repository.remote_upstream_url.clone(); .map(|head_commit| head_commit.sha.to_string());
if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { let remote_origin_url = repository.remote_origin_url.clone();
return None; let remote_upstream_url = repository.remote_upstream_url.clone();
} if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
Some(PredictEditsGitInfo { return None;
head_sha,
remote_origin_url,
remote_upstream_url,
})
} else {
None
} }
Some(PredictEditsGitInfo {
head_sha,
remote_origin_url,
remote_upstream_url,
})
} }
pub struct GatherContextOutput { pub struct GatherContextOutput {
@ -1165,6 +1250,7 @@ pub fn gather_context(
make_events_prompt: impl FnOnce() -> String + Send + 'static, make_events_prompt: impl FnOnce() -> String + Send + 'static,
can_collect_data: bool, can_collect_data: bool,
git_info: Option<PredictEditsGitInfo>, git_info: Option<PredictEditsGitInfo>,
recent_files: Option<Vec<PredictEditsRecentFile>>,
cx: &App, cx: &App,
) -> Task<Result<GatherContextOutput>> { ) -> Task<Result<GatherContextOutput>> {
let local_lsp_store = let local_lsp_store =
@ -1216,6 +1302,7 @@ pub fn gather_context(
git_info, git_info,
outline: None, outline: None,
speculated_output: None, speculated_output: None,
recent_files,
}; };
Ok(GatherContextOutput { Ok(GatherContextOutput {

View file

@ -174,6 +174,7 @@ async fn get_context(
// Enable gathering extra data not currently needed for edit predictions // Enable gathering extra data not currently needed for edit predictions
let can_collect_data = true; let can_collect_data = true;
let git_info = None; let git_info = None;
let recent_files = None;
let mut gather_context_output = cx let mut gather_context_output = cx
.update(|cx| { .update(|cx| {
gather_context( gather_context(
@ -184,6 +185,7 @@ async fn get_context(
move || events, move || events,
can_collect_data, can_collect_data,
git_info, git_info,
recent_files,
cx, cx,
) )
})? })?