diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 7d1548b322..eca725bef1 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -152,8 +152,6 @@ pub struct PredictEditsBody { /// Info about the git repository state, only present when can_collect_data is true. #[serde(skip_serializing_if = "Option::is_none", default)] pub git_info: Option, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub recent_files: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -167,6 +165,8 @@ pub struct PredictEditsGitInfo { /// URL of the remote called `upstream`. #[serde(skip_serializing_if = "Option::is_none", default)] pub remote_upstream_url: Option, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub recent_files: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index a415c95d3b..6dbb8e317c 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -404,7 +404,7 @@ impl Zeta { project: Option<&Entity>, buffer: &Entity, cursor: language::Anchor, - can_collect_data: bool, + can_collect_data: CanCollectData, cx: &mut Context, perform_predict_edits: F, ) -> Task>> @@ -423,17 +423,10 @@ impl Zeta { let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let (git_info, recent_files) = if let (true, Some(project), Some(file)) = - (can_collect_data, project, snapshot.file()) - && let Some(repository) = - git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) - { - let repository = repository.read(cx); - let git_info = make_predict_edits_git_info(repository); - let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx); - (git_info, Some(recent_files)) + let git_info = if matches!(can_collect_data, CanCollectData(true)) { + self.gather_git_info(project.clone(), &buffer_snapshotted_at, &snapshot, cx) } else { - (None, None) + None }; let full_path: Arc = snapshot @@ -452,7 +445,6 @@ impl Zeta { make_events_prompt, can_collect_data, git_info, - recent_files, cx, ); @@ -725,9 +717,15 @@ and then another ) -> Task>> { use std::future::ready; - self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { - ready(Ok((response, None))) - }) + self.request_completion_impl( + None, + project, + buffer, + position, + CanCollectData(false), + cx, + |_params| ready(Ok((response, None))), + ) } pub fn request_completion( @@ -735,7 +733,7 @@ and then another project: Option<&Entity>, buffer: &Entity, position: language::Anchor, - can_collect_data: bool, + can_collect_data: CanCollectData, cx: &mut Context, ) -> Task>> { self.request_completion_impl( @@ -1132,6 +1130,46 @@ and then another } } + fn gather_git_info( + &mut self, + project: Option<&Entity>, + buffer_snapshotted_at: &Instant, + snapshot: &BufferSnapshot, + cx: &Context, + ) -> Option { + let project = project?.read(cx); + let file = snapshot.file()?; + let project_path = ProjectPath::from_file(file.as_ref(), cx); + let entry = project.entry_for_path(&project_path, cx)?; + if !worktree_entry_eligible_for_collection(&entry) { + return None; + } + + let git_store = project.git_store().read(cx); + let (repository, _repo_path) = + git_store.repository_and_path_for_project_path(&project_path, cx)?; + + let repository = repository.read(cx); + let head_sha = repository + .head_commit + .as_ref() + .map(|head_commit| head_commit.sha.to_string()); + let remote_origin_url = repository.remote_origin_url.clone(); + let remote_upstream_url = repository.remote_upstream_url.clone(); + if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { + return None; + } + + let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx); + + Some(PredictEditsGitInfo { + head_sha, + remote_origin_url, + remote_upstream_url, + recent_files: Some(recent_files), + }) + } + fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) { let now = Instant::now(); if let Some(existing_ix) = self @@ -1166,12 +1204,7 @@ and then another if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx) && let worktree = worktree.read(cx) && let Some(entry) = worktree.entry_for_id(*entry_id) - && entry.is_file() - && entry.is_created() - && !entry.is_ignored - && !entry.is_private - && !entry.is_external - && !entry.is_fifo + && worktree_entry_eligible_for_collection(entry) { let project_path = ProjectPath { worktree_id: worktree.id(), @@ -1191,12 +1224,6 @@ and then another self.recent_project_entries.remove(ix); continue; } - if let Some(file_status) = repository.status_for_path(&repo_path) { - if file_status.is_ignored() || file_status.is_untracked() { - // entry not removed because it may belong to a nested repository - continue; - } - } let Ok(active_to_now_ms) = now.duration_since(*last_active_at).as_millis().try_into() else { @@ -1215,6 +1242,15 @@ and then another } } +fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool { + entry.is_file() + && entry.is_created() + && !entry.is_ignored + && !entry.is_private + && !entry.is_external + && !entry.is_fifo +} + pub struct PerformPredictEditsParams { pub client: Arc, pub llm_token: LlmApiToken, @@ -1237,34 +1273,6 @@ fn common_prefix, T2: Iterator>(a: T1, b: .sum() } -fn git_repository_for_file( - project: &Entity, - project_path: &ProjectPath, - cx: &App, -) -> Option> { - let git_store = project.read(cx).git_store().read(cx); - git_store - .repository_and_path_for_project_path(project_path, cx) - .map(|(repo, _repo_path)| repo) -} - -fn make_predict_edits_git_info(repository: &Repository) -> Option { - let head_sha = repository - .head_commit - .as_ref() - .map(|head_commit| head_commit.sha.to_string()); - let remote_origin_url = repository.remote_origin_url.clone(); - let remote_upstream_url = repository.remote_upstream_url.clone(); - if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { - return None; - } - Some(PredictEditsGitInfo { - head_sha, - remote_origin_url, - remote_upstream_url, - }) -} - pub struct GatherContextOutput { pub body: PredictEditsBody, pub editable_range: Range, @@ -1276,15 +1284,16 @@ pub fn gather_context( snapshot: &BufferSnapshot, cursor_point: language::Point, make_events_prompt: impl FnOnce() -> String + Send + 'static, - can_collect_data: bool, + can_collect_data: CanCollectData, git_info: Option, - recent_files: Option>, cx: &App, ) -> Task> { let local_lsp_store = project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); let diagnostic_groups: Vec<(String, serde_json::Value)> = - if can_collect_data && let Some(local_lsp_store) = local_lsp_store { + if matches!(can_collect_data, CanCollectData(true)) + && let Some(local_lsp_store) = local_lsp_store + { snapshot .diagnostic_groups(None) .into_iter() @@ -1325,12 +1334,11 @@ pub fn gather_context( let body = PredictEditsBody { input_events, input_excerpt: input_excerpt.prompt, - can_collect_data, + can_collect_data: can_collect_data.0, diagnostic_groups, git_info, outline: None, speculated_output: None, - recent_files, }; Ok(GatherContextOutput { @@ -1491,6 +1499,9 @@ pub struct ProviderDataCollection { license_detection_watcher: Option>, } +#[derive(Debug, Clone, Copy)] +pub struct CanCollectData(pub bool); + impl ProviderDataCollection { pub fn new(zeta: Entity, buffer: Option>, cx: &mut App) -> Self { let choice_and_watcher = buffer.and_then(|buffer| { @@ -1524,8 +1535,8 @@ impl ProviderDataCollection { } } - pub fn can_collect_data(&self, cx: &App) -> bool { - self.is_data_collection_enabled(cx) && self.is_project_open_source() + pub fn can_collect_data(&self, cx: &App) -> CanCollectData { + CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source()) } pub fn is_data_collection_enabled(&self, cx: &App) -> bool { @@ -2149,7 +2160,7 @@ mod tests { let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(None, &buffer, cursor, false, cx) + zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx) }); let completion = completion_task.await.unwrap().unwrap(); @@ -2214,7 +2225,7 @@ mod tests { let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(None, &buffer, cursor, false, cx) + zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx) }); let completion = completion_task.await.unwrap().unwrap(); diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 76f638057a..7ffbd68898 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -18,7 +18,7 @@ use std::process::exit; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; +use zeta::{CanCollectData, GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; use crate::headless::ZetaCliAppState; @@ -172,9 +172,7 @@ async fn get_context( None => String::new(), }; // Enable gathering extra data not currently needed for edit predictions - let can_collect_data = true; let git_info = None; - let recent_files = None; let mut gather_context_output = cx .update(|cx| { gather_context( @@ -183,9 +181,8 @@ async fn get_context( &snapshot, clipped_cursor, move || events, - can_collect_data, + CanCollectData(true), git_info, - recent_files, cx, ) })?