Cleanup + only record git info if current file may be in repo

This commit is contained in:
Michael Sloan 2025-08-25 23:52:23 -06:00
parent 87609557f0
commit b9cd8f5d2a
No known key found for this signature in database
3 changed files with 80 additions and 72 deletions

View file

@ -152,8 +152,6 @@ pub struct PredictEditsBody {
/// Info about the git repository state, only present when can_collect_data is true. /// Info about the git repository state, only present when can_collect_data is true.
#[serde(skip_serializing_if = "Option::is_none", default)] #[serde(skip_serializing_if = "Option::is_none", default)]
pub git_info: Option<PredictEditsGitInfo>, pub git_info: Option<PredictEditsGitInfo>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub recent_files: Option<Vec<PredictEditsRecentFile>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -167,6 +165,8 @@ pub struct PredictEditsGitInfo {
/// URL of the remote called `upstream`. /// URL of the remote called `upstream`.
#[serde(skip_serializing_if = "Option::is_none", default)] #[serde(skip_serializing_if = "Option::is_none", default)]
pub remote_upstream_url: Option<String>, pub remote_upstream_url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub recent_files: Option<Vec<PredictEditsRecentFile>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -404,7 +404,7 @@ impl Zeta {
project: Option<&Entity<Project>>, project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>, buffer: &Entity<Buffer>,
cursor: language::Anchor, cursor: language::Anchor,
can_collect_data: bool, can_collect_data: CanCollectData,
cx: &mut Context<Self>, cx: &mut Context<Self>,
perform_predict_edits: F, perform_predict_edits: F,
) -> Task<Result<Option<EditPrediction>>> ) -> Task<Result<Option<EditPrediction>>>
@ -423,17 +423,10 @@ impl Zeta {
let llm_token = self.llm_token.clone(); let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx); let app_version = AppVersion::global(cx);
let (git_info, recent_files) = if let (true, Some(project), Some(file)) = let git_info = if matches!(can_collect_data, CanCollectData(true)) {
(can_collect_data, project, snapshot.file()) self.gather_git_info(project.clone(), &buffer_snapshotted_at, &snapshot, cx)
&& let Some(repository) =
git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
{
let repository = repository.read(cx);
let git_info = make_predict_edits_git_info(repository);
let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
(git_info, Some(recent_files))
} else { } else {
(None, None) None
}; };
let full_path: Arc<Path> = snapshot let full_path: Arc<Path> = snapshot
@ -452,7 +445,6 @@ impl Zeta {
make_events_prompt, make_events_prompt,
can_collect_data, can_collect_data,
git_info, git_info,
recent_files,
cx, cx,
); );
@ -725,9 +717,15 @@ and then another
) -> Task<Result<Option<EditPrediction>>> { ) -> Task<Result<Option<EditPrediction>>> {
use std::future::ready; use std::future::ready;
self.request_completion_impl(None, project, buffer, position, false, cx, |_params| { self.request_completion_impl(
ready(Ok((response, None))) None,
}) project,
buffer,
position,
CanCollectData(false),
cx,
|_params| ready(Ok((response, None))),
)
} }
pub fn request_completion( pub fn request_completion(
@ -735,7 +733,7 @@ and then another
project: Option<&Entity<Project>>, project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>, buffer: &Entity<Buffer>,
position: language::Anchor, position: language::Anchor,
can_collect_data: bool, can_collect_data: CanCollectData,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> { ) -> Task<Result<Option<EditPrediction>>> {
self.request_completion_impl( self.request_completion_impl(
@ -1132,6 +1130,46 @@ and then another
} }
} }
/// Builds the `PredictEditsGitInfo` for the repository containing the snapshotted buffer's file.
///
/// Returns `None` when any of the following holds:
/// - no project was supplied, or the snapshot has no associated file;
/// - the project has no entry for the file's path;
/// - the entry fails `worktree_entry_eligible_for_collection`;
/// - the git store has no repository for the file's path;
/// - the repository has no head commit and no `origin` / `upstream` remote URL.
///
/// Otherwise the returned info also carries the result of `self.recent_files` for that
/// repository.
fn gather_git_info(
    &mut self,
    project: Option<&Entity<Project>>,
    buffer_snapshotted_at: &Instant,
    snapshot: &BufferSnapshot,
    cx: &Context<Self>,
) -> Option<PredictEditsGitInfo> {
    let project = project?.read(cx);
    let file = snapshot.file()?;
    let project_path = ProjectPath::from_file(file.as_ref(), cx);

    // Only record git info when the current file's worktree entry is eligible.
    let entry = project.entry_for_path(&project_path, cx)?;
    if !worktree_entry_eligible_for_collection(&entry) {
        return None;
    }

    let git_store = project.git_store().read(cx);
    let (repository, _repo_path) =
        git_store.repository_and_path_for_project_path(&project_path, cx)?;
    let repository = repository.read(cx);

    let head_sha = repository
        .head_commit
        .as_ref()
        .map(|head_commit| head_commit.sha.to_string());
    let remote_origin_url = repository.remote_origin_url.clone();
    let remote_upstream_url = repository.remote_upstream_url.clone();
    // Nothing identifies the repository, so there is nothing worth reporting.
    if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
        return None;
    }

    // `buffer_snapshotted_at` is already a reference; pass it through without re-borrowing.
    let recent_files = self.recent_files(buffer_snapshotted_at, repository, cx);

    Some(PredictEditsGitInfo {
        head_sha,
        remote_origin_url,
        remote_upstream_url,
        recent_files: Some(recent_files),
    })
}
fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) { fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
let now = Instant::now(); let now = Instant::now();
if let Some(existing_ix) = self if let Some(existing_ix) = self
@ -1166,12 +1204,7 @@ and then another
if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx) if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
&& let worktree = worktree.read(cx) && let worktree = worktree.read(cx)
&& let Some(entry) = worktree.entry_for_id(*entry_id) && let Some(entry) = worktree.entry_for_id(*entry_id)
&& entry.is_file() && worktree_entry_eligible_for_collection(entry)
&& entry.is_created()
&& !entry.is_ignored
&& !entry.is_private
&& !entry.is_external
&& !entry.is_fifo
{ {
let project_path = ProjectPath { let project_path = ProjectPath {
worktree_id: worktree.id(), worktree_id: worktree.id(),
@ -1191,12 +1224,6 @@ and then another
self.recent_project_entries.remove(ix); self.recent_project_entries.remove(ix);
continue; continue;
} }
if let Some(file_status) = repository.status_for_path(&repo_path) {
if file_status.is_ignored() || file_status.is_untracked() {
// entry not removed because it may belong to a nested repository
continue;
}
}
let Ok(active_to_now_ms) = let Ok(active_to_now_ms) =
now.duration_since(*last_active_at).as_millis().try_into() now.duration_since(*last_active_at).as_millis().try_into()
else { else {
@ -1215,6 +1242,15 @@ and then another
} }
} }
/// Reports whether a worktree entry passes every collection filter.
///
/// The entry must satisfy `is_file()` and `is_created()`, and none of its `is_ignored`,
/// `is_private`, `is_external`, or `is_fifo` flags may be set.
fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
    if !entry.is_file() || !entry.is_created() {
        return false;
    }
    let excluded = entry.is_ignored || entry.is_private || entry.is_external || entry.is_fifo;
    !excluded
}
pub struct PerformPredictEditsParams { pub struct PerformPredictEditsParams {
pub client: Arc<Client>, pub client: Arc<Client>,
pub llm_token: LlmApiToken, pub llm_token: LlmApiToken,
@ -1237,34 +1273,6 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
.sum() .sum()
} }
/// Looks up the git repository that the project's git store associates with `project_path`.
///
/// Returns `None` when the git store has no repository for that path. The repository-relative
/// path that the lookup also yields is discarded.
fn git_repository_for_file(
    project: &Entity<Project>,
    project_path: &ProjectPath,
    cx: &App,
) -> Option<Entity<Repository>> {
    let (repository, _) = project
        .read(cx)
        .git_store()
        .read(cx)
        .repository_and_path_for_project_path(project_path, cx)?;
    Some(repository)
}
/// Builds a `PredictEditsGitInfo` from the repository's head commit SHA and its
/// `remote_origin_url` / `remote_upstream_url` values.
///
/// Returns `None` when all three pieces of information are absent.
fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
    let sha = repository
        .head_commit
        .as_ref()
        .map(|commit| commit.sha.to_string());
    let origin = repository.remote_origin_url.clone();
    let upstream = repository.remote_upstream_url.clone();

    match (sha, origin, upstream) {
        // Nothing identifies the repository.
        (None, None, None) => None,
        (head_sha, remote_origin_url, remote_upstream_url) => Some(PredictEditsGitInfo {
            head_sha,
            remote_origin_url,
            remote_upstream_url,
        }),
    }
}
pub struct GatherContextOutput { pub struct GatherContextOutput {
pub body: PredictEditsBody, pub body: PredictEditsBody,
pub editable_range: Range<usize>, pub editable_range: Range<usize>,
@ -1276,15 +1284,16 @@ pub fn gather_context(
snapshot: &BufferSnapshot, snapshot: &BufferSnapshot,
cursor_point: language::Point, cursor_point: language::Point,
make_events_prompt: impl FnOnce() -> String + Send + 'static, make_events_prompt: impl FnOnce() -> String + Send + 'static,
can_collect_data: bool, can_collect_data: CanCollectData,
git_info: Option<PredictEditsGitInfo>, git_info: Option<PredictEditsGitInfo>,
recent_files: Option<Vec<PredictEditsRecentFile>>,
cx: &App, cx: &App,
) -> Task<Result<GatherContextOutput>> { ) -> Task<Result<GatherContextOutput>> {
let local_lsp_store = let local_lsp_store =
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
let diagnostic_groups: Vec<(String, serde_json::Value)> = let diagnostic_groups: Vec<(String, serde_json::Value)> =
if can_collect_data && let Some(local_lsp_store) = local_lsp_store { if matches!(can_collect_data, CanCollectData(true))
&& let Some(local_lsp_store) = local_lsp_store
{
snapshot snapshot
.diagnostic_groups(None) .diagnostic_groups(None)
.into_iter() .into_iter()
@ -1325,12 +1334,11 @@ pub fn gather_context(
let body = PredictEditsBody { let body = PredictEditsBody {
input_events, input_events,
input_excerpt: input_excerpt.prompt, input_excerpt: input_excerpt.prompt,
can_collect_data, can_collect_data: can_collect_data.0,
diagnostic_groups, diagnostic_groups,
git_info, git_info,
outline: None, outline: None,
speculated_output: None, speculated_output: None,
recent_files,
}; };
Ok(GatherContextOutput { Ok(GatherContextOutput {
@ -1491,6 +1499,9 @@ pub struct ProviderDataCollection {
license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>, license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
} }
/// Newtype flag used in place of a bare `bool` for `can_collect_data` parameters.
///
/// `ProviderDataCollection::can_collect_data` produces it; the wrapped value is public and can
/// be read with `.0`. `PartialEq`/`Eq` are derived so values can be compared directly
/// (`flag == CanCollectData(true)`) and used with `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CanCollectData(pub bool);
impl ProviderDataCollection { impl ProviderDataCollection {
pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self { pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
let choice_and_watcher = buffer.and_then(|buffer| { let choice_and_watcher = buffer.and_then(|buffer| {
@ -1524,8 +1535,8 @@ impl ProviderDataCollection {
} }
} }
pub fn can_collect_data(&self, cx: &App) -> bool { pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
self.is_data_collection_enabled(cx) && self.is_project_open_source() CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
} }
pub fn is_data_collection_enabled(&self, cx: &App) -> bool { pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
@ -2149,7 +2160,7 @@ mod tests {
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
let completion_task = zeta.update(cx, |zeta, cx| { let completion_task = zeta.update(cx, |zeta, cx| {
zeta.request_completion(None, &buffer, cursor, false, cx) zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
}); });
let completion = completion_task.await.unwrap().unwrap(); let completion = completion_task.await.unwrap().unwrap();
@ -2214,7 +2225,7 @@ mod tests {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
let completion_task = zeta.update(cx, |zeta, cx| { let completion_task = zeta.update(cx, |zeta, cx| {
zeta.request_completion(None, &buffer, cursor, false, cx) zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
}); });
let completion = completion_task.await.unwrap().unwrap(); let completion = completion_task.await.unwrap().unwrap();

View file

@ -18,7 +18,7 @@ use std::process::exit;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context}; use zeta::{CanCollectData, GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
use crate::headless::ZetaCliAppState; use crate::headless::ZetaCliAppState;
@ -172,9 +172,7 @@ async fn get_context(
None => String::new(), None => String::new(),
}; };
// Enable gathering extra data not currently needed for edit predictions // Enable gathering extra data not currently needed for edit predictions
let can_collect_data = true;
let git_info = None; let git_info = None;
let recent_files = None;
let mut gather_context_output = cx let mut gather_context_output = cx
.update(|cx| { .update(|cx| {
gather_context( gather_context(
@ -183,9 +181,8 @@ async fn get_context(
&snapshot, &snapshot,
clipped_cursor, clipped_cursor,
move || events, move || events,
can_collect_data, CanCollectData(true),
git_info, git_info,
recent_files,
cx, cx,
) )
})? })?