diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 741945af10..7d1548b322 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -152,6 +152,8 @@ pub struct PredictEditsBody { /// Info about the git repository state, only present when can_collect_data is true. #[serde(skip_serializing_if = "Option::is_none", default)] pub git_info: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub recent_files: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -167,6 +169,12 @@ pub struct PredictEditsGitInfo { pub remote_upstream_url: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsRecentFile { + pub repo_path: String, + pub active_to_now_ms: u32, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsResponse { pub request_id: Uuid, diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index bc2d757fd1..06ed524e65 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -204,10 +204,7 @@ fn assign_edit_prediction_provider( } } - let workspace = window - .root::() - .flatten() - .map(|workspace| workspace.downgrade()); + let workspace = window.root::().flatten(); let zeta = zeta::Zeta::register(workspace, worktree, client.clone(), user_store, cx); diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 7b14d12796..45cd7b1614 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -6,19 +6,21 @@ mod onboarding_modal; mod onboarding_telemetry; mod rate_completion_modal; +use arrayvec::ArrayVec; pub(crate) use completion_diff_element::*; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use edit_prediction::DataCollectionState; pub use init::*; use license_detection::LicenseDetectionWatcher; +use project::git_store::Repository; pub use rate_completion_modal::*; use anyhow::{Context as _, Result, anyhow}; -use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME, + PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse, + ZED_VERSION_HEADER_NAME, }; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; @@ -32,7 +34,7 @@ use language::{ Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff, }; use language_model::{LlmApiToken, RefreshLlmTokenListener}; -use project::{Project, ProjectPath}; +use project::{Project, ProjectEntryId, ProjectPath}; use release_channel::AppVersion; use settings::WorktreeId; use std::str::FromStr; @@ -70,6 +72,12 @@ const MAX_DIAGNOSTIC_GROUPS: usize = 10; /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; +/// Maximum number of recent files to track. +const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16; + +/// Maximum number of edit predictions to store for feedback. +const MAX_SHOWN_COMPLETION_COUNT: usize = 50; + actions!( edit_prediction, [ @@ -212,7 +220,7 @@ impl std::fmt::Debug for EditPrediction { } pub struct Zeta { - workspace: Option>, + workspace: WeakEntity, client: Arc, events: VecDeque, registered_buffers: HashMap, @@ -225,6 +233,7 @@ pub struct Zeta { update_required: bool, user_store: Entity, license_detection_watchers: HashMap>, + recent_project_entries: VecDeque<(ProjectEntryId, Instant)>, } impl Zeta { @@ -233,7 +242,7 @@ impl Zeta { } pub fn register( - workspace: Option>, + workspace: Option>, worktree: Option>, client: Arc, user_store: Entity, @@ -266,7 +275,7 @@ impl Zeta { } fn new( - workspace: Option>, + workspace: Option>, client: Arc, user_store: Entity, cx: &mut Context, @@ -276,11 +285,27 @@ impl Zeta { let data_collection_choice = Self::load_data_collection_choices(); let data_collection_choice = cx.new(|_| data_collection_choice); + if let Some(workspace) = &workspace { + cx.subscribe( + &workspace.read(cx).project().clone(), + |this, _workspace, event, _cx| match event { + project::Event::ActiveEntryChanged(Some(project_entry_id)) => { + this.push_recent_project_entry(*project_entry_id) + } + _ => {} + }, + ) + .detach(); + } + Self { - workspace, + workspace: workspace.map_or_else( + || WeakEntity::new_invalid(), + |workspace| workspace.downgrade(), + ), client, - events: VecDeque::new(), - shown_completions: VecDeque::new(), + events: VecDeque::with_capacity(MAX_EVENT_COUNT), + shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT), rated_completions: HashSet::default(), registered_buffers: HashMap::default(), data_collection_choice, @@ -300,6 +325,7 @@ impl Zeta { update_required: false, license_detection_watchers: HashMap::default(), user_store, + recent_project_entries: VecDeque::with_capacity(MAX_RECENT_PROJECT_ENTRIES_COUNT), } } @@ -327,11 +353,12 @@ impl Zeta { } } - self.events.push_back(event); if self.events.len() >= MAX_EVENT_COUNT { // These are halved instead of popping to improve prompt caching. self.events.drain(..MAX_EVENT_COUNT / 2); } + + self.events.push_back(event); } pub fn register_buffer(&mut self, buffer: &Entity, cx: &mut Context) { @@ -393,12 +420,17 @@ impl Zeta { let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let git_info = if let (true, Some(project), Some(file)) = + let (git_info, recent_files) = if let (true, Some(project), Some(file)) = (can_collect_data, project, snapshot.file()) + && let Some(repository) = + git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) { - git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) + let repository = repository.read(cx); + let git_info = make_predict_edits_git_info(repository); + let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx); + (git_info, Some(recent_files)) } else { - None + (None, None) }; let full_path: Arc = snapshot @@ -417,6 +449,7 @@ impl Zeta { make_events_prompt, can_collect_data, git_info, + recent_files, cx, ); @@ -702,12 +735,8 @@ and then another can_collect_data: bool, cx: &mut Context, ) -> Task>> { - let workspace = self - .workspace - .as_ref() - .and_then(|workspace| workspace.upgrade()); self.request_completion_impl( - workspace, + self.workspace.upgrade(), project, buffer, position, @@ -1021,11 +1050,11 @@ and then another } pub fn completion_shown(&mut self, completion: &EditPrediction, cx: &mut Context) { - self.shown_completions.push_front(completion.clone()); - if self.shown_completions.len() > 50 { + if self.shown_completions.len() >= MAX_SHOWN_COMPLETION_COUNT { let completion = self.shown_completions.pop_back().unwrap(); self.rated_completions.remove(&completion.id); } + self.shown_completions.push_front(completion.clone()); cx.notify(); } @@ -1099,6 +1128,63 @@ and then another None => DataCollectionChoice::NotAnswered, } } + + fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) { + let now = Instant::now(); + if let Some(existing_ix) = self + .recent_project_entries + .iter() + .rposition(|(id, _)| *id == project_entry_id) + { + self.recent_project_entries.remove(existing_ix); + } + if self.recent_project_entries.len() >= MAX_RECENT_PROJECT_ENTRIES_COUNT { + self.recent_project_entries.pop_front(); + } + self.recent_project_entries + .push_back((project_entry_id, now)); + } + + fn recent_files( + &mut self, + now: &Instant, + repository: &Repository, + cx: &Context, + ) -> Vec { + let Ok(project) = self + .workspace + .read_with(cx, |workspace, _cx| workspace.project().clone()) + else { + return Vec::new(); + }; + let mut results = Vec::new(); + for ix in (0..self.recent_project_entries.len()).rev() { + let (id, last_active_at) = &self.recent_project_entries[ix]; + let Some(project_path) = project.read(cx).path_for_entry(*id, cx) else { + self.recent_project_entries.remove(ix); + continue; + }; + let Some(repo_path) = repository.project_path_to_repo_path(&project_path, cx) else { + // entry not removed since queries involving other repositories might occur later + continue; + }; + let Some(repo_path) = repo_path.to_str() else { + // paths may not be valid UTF-8 + self.recent_project_entries.remove(ix); + continue; + }; + let Ok(active_to_now_ms) = now.duration_since(*last_active_at).as_millis().try_into() + else { + self.recent_project_entries.remove(ix); + continue; + }; + results.push(PredictEditsRecentFile { + repo_path: repo_path.to_string(), + active_to_now_ms, + }); + } + results + } } pub struct PerformPredictEditsParams { @@ -1123,33 +1209,32 @@ fn common_prefix, T2: Iterator>(a: T1, b: .sum() } -fn git_info_for_file( +fn git_repository_for_file( project: &Entity, project_path: &ProjectPath, cx: &App, -) -> Option { +) -> Option> { let git_store = project.read(cx).git_store().read(cx); - if let Some((repository, _repo_path)) = - git_store.repository_and_path_for_project_path(project_path, cx) - { - let repository = repository.read(cx); - let head_sha = repository - .head_commit - .as_ref() - .map(|head_commit| head_commit.sha.to_string()); - let remote_origin_url = repository.remote_origin_url.clone(); - let remote_upstream_url = repository.remote_upstream_url.clone(); - if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { - return None; - } - Some(PredictEditsGitInfo { - head_sha, - remote_origin_url, - remote_upstream_url, - }) - } else { - None + git_store + .repository_and_path_for_project_path(project_path, cx) + .map(|(repo, _repo_path)| repo) +} + +fn make_predict_edits_git_info(repository: &Repository) -> Option { + let head_sha = repository + .head_commit + .as_ref() + .map(|head_commit| head_commit.sha.to_string()); + let remote_origin_url = repository.remote_origin_url.clone(); + let remote_upstream_url = repository.remote_upstream_url.clone(); + if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { + return None; } + Some(PredictEditsGitInfo { + head_sha, + remote_origin_url, + remote_upstream_url, + }) } pub struct GatherContextOutput { @@ -1165,6 +1250,7 @@ pub fn gather_context( make_events_prompt: impl FnOnce() -> String + Send + 'static, can_collect_data: bool, git_info: Option, + recent_files: Option>, cx: &App, ) -> Task> { let local_lsp_store = @@ -1216,6 +1302,7 @@ pub fn gather_context( git_info, outline: None, speculated_output: None, + recent_files, }; Ok(GatherContextOutput { diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 5b2d4cf615..76f638057a 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -174,6 +174,7 @@ async fn get_context( // Enable gathering extra data not currently needed for edit predictions let can_collect_data = true; let git_info = None; + let recent_files = None; let mut gather_context_output = cx .update(|cx| { gather_context( @@ -184,6 +185,7 @@ async fn get_context( move || events, can_collect_data, git_info, + recent_files, cx, ) })?