Cleanup + only record git info if current file may be in repo

This commit is contained in:
Michael Sloan 2025-08-25 23:52:23 -06:00
parent 87609557f0
commit b9cd8f5d2a
No known key found for this signature in database
3 changed files with 80 additions and 72 deletions

View file

@ -152,8 +152,6 @@ pub struct PredictEditsBody {
/// Info about the git repository state, only present when can_collect_data is true.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub git_info: Option<PredictEditsGitInfo>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub recent_files: Option<Vec<PredictEditsRecentFile>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -167,6 +165,8 @@ pub struct PredictEditsGitInfo {
/// URL of the remote called `upstream`.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub remote_upstream_url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub recent_files: Option<Vec<PredictEditsRecentFile>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -404,7 +404,7 @@ impl Zeta {
project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
can_collect_data: bool,
can_collect_data: CanCollectData,
cx: &mut Context<Self>,
perform_predict_edits: F,
) -> Task<Result<Option<EditPrediction>>>
@ -423,17 +423,10 @@ impl Zeta {
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, recent_files) = if let (true, Some(project), Some(file)) =
(can_collect_data, project, snapshot.file())
&& let Some(repository) =
git_repository_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
{
let repository = repository.read(cx);
let git_info = make_predict_edits_git_info(repository);
let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
(git_info, Some(recent_files))
let git_info = if matches!(can_collect_data, CanCollectData(true)) {
self.gather_git_info(project.clone(), &buffer_snapshotted_at, &snapshot, cx)
} else {
(None, None)
None
};
let full_path: Arc<Path> = snapshot
@ -452,7 +445,6 @@ impl Zeta {
make_events_prompt,
can_collect_data,
git_info,
recent_files,
cx,
);
@ -725,9 +717,15 @@ and then another
) -> Task<Result<Option<EditPrediction>>> {
use std::future::ready;
self.request_completion_impl(None, project, buffer, position, false, cx, |_params| {
ready(Ok((response, None)))
})
self.request_completion_impl(
None,
project,
buffer,
position,
CanCollectData(false),
cx,
|_params| ready(Ok((response, None))),
)
}
pub fn request_completion(
@ -735,7 +733,7 @@ and then another
project: Option<&Entity<Project>>,
buffer: &Entity<Buffer>,
position: language::Anchor,
can_collect_data: bool,
can_collect_data: CanCollectData,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
self.request_completion_impl(
@ -1132,6 +1130,46 @@ and then another
}
}
fn gather_git_info(
&mut self,
project: Option<&Entity<Project>>,
buffer_snapshotted_at: &Instant,
snapshot: &BufferSnapshot,
cx: &Context<Self>,
) -> Option<PredictEditsGitInfo> {
let project = project?.read(cx);
let file = snapshot.file()?;
let project_path = ProjectPath::from_file(file.as_ref(), cx);
let entry = project.entry_for_path(&project_path, cx)?;
if !worktree_entry_eligible_for_collection(&entry) {
return None;
}
let git_store = project.git_store().read(cx);
let (repository, _repo_path) =
git_store.repository_and_path_for_project_path(&project_path, cx)?;
let repository = repository.read(cx);
let head_sha = repository
.head_commit
.as_ref()
.map(|head_commit| head_commit.sha.to_string());
let remote_origin_url = repository.remote_origin_url.clone();
let remote_upstream_url = repository.remote_upstream_url.clone();
if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
return None;
}
let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
Some(PredictEditsGitInfo {
head_sha,
remote_origin_url,
remote_upstream_url,
recent_files: Some(recent_files),
})
}
fn push_recent_project_entry(&mut self, project_entry_id: ProjectEntryId) {
let now = Instant::now();
if let Some(existing_ix) = self
@ -1166,12 +1204,7 @@ and then another
if let Some(worktree) = project.read(cx).worktree_for_entry(*entry_id, cx)
&& let worktree = worktree.read(cx)
&& let Some(entry) = worktree.entry_for_id(*entry_id)
&& entry.is_file()
&& entry.is_created()
&& !entry.is_ignored
&& !entry.is_private
&& !entry.is_external
&& !entry.is_fifo
&& worktree_entry_eligible_for_collection(entry)
{
let project_path = ProjectPath {
worktree_id: worktree.id(),
@ -1191,12 +1224,6 @@ and then another
self.recent_project_entries.remove(ix);
continue;
}
if let Some(file_status) = repository.status_for_path(&repo_path) {
if file_status.is_ignored() || file_status.is_untracked() {
// entry not removed because it may belong to a nested repository
continue;
}
}
let Ok(active_to_now_ms) =
now.duration_since(*last_active_at).as_millis().try_into()
else {
@ -1215,6 +1242,15 @@ and then another
}
}
fn worktree_entry_eligible_for_collection(entry: &worktree::Entry) -> bool {
entry.is_file()
&& entry.is_created()
&& !entry.is_ignored
&& !entry.is_private
&& !entry.is_external
&& !entry.is_fifo
}
pub struct PerformPredictEditsParams {
pub client: Arc<Client>,
pub llm_token: LlmApiToken,
@ -1237,34 +1273,6 @@ fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b:
.sum()
}
fn git_repository_for_file(
project: &Entity<Project>,
project_path: &ProjectPath,
cx: &App,
) -> Option<Entity<Repository>> {
let git_store = project.read(cx).git_store().read(cx);
git_store
.repository_and_path_for_project_path(project_path, cx)
.map(|(repo, _repo_path)| repo)
}
fn make_predict_edits_git_info(repository: &Repository) -> Option<PredictEditsGitInfo> {
let head_sha = repository
.head_commit
.as_ref()
.map(|head_commit| head_commit.sha.to_string());
let remote_origin_url = repository.remote_origin_url.clone();
let remote_upstream_url = repository.remote_upstream_url.clone();
if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
return None;
}
Some(PredictEditsGitInfo {
head_sha,
remote_origin_url,
remote_upstream_url,
})
}
pub struct GatherContextOutput {
pub body: PredictEditsBody,
pub editable_range: Range<usize>,
@ -1276,15 +1284,16 @@ pub fn gather_context(
snapshot: &BufferSnapshot,
cursor_point: language::Point,
make_events_prompt: impl FnOnce() -> String + Send + 'static,
can_collect_data: bool,
can_collect_data: CanCollectData,
git_info: Option<PredictEditsGitInfo>,
recent_files: Option<Vec<PredictEditsRecentFile>>,
cx: &App,
) -> Task<Result<GatherContextOutput>> {
let local_lsp_store =
project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
let diagnostic_groups: Vec<(String, serde_json::Value)> =
if can_collect_data && let Some(local_lsp_store) = local_lsp_store {
if matches!(can_collect_data, CanCollectData(true))
&& let Some(local_lsp_store) = local_lsp_store
{
snapshot
.diagnostic_groups(None)
.into_iter()
@ -1325,12 +1334,11 @@ pub fn gather_context(
let body = PredictEditsBody {
input_events,
input_excerpt: input_excerpt.prompt,
can_collect_data,
can_collect_data: can_collect_data.0,
diagnostic_groups,
git_info,
outline: None,
speculated_output: None,
recent_files,
};
Ok(GatherContextOutput {
@ -1491,6 +1499,9 @@ pub struct ProviderDataCollection {
license_detection_watcher: Option<Rc<LicenseDetectionWatcher>>,
}
#[derive(Debug, Clone, Copy)]
pub struct CanCollectData(pub bool);
impl ProviderDataCollection {
pub fn new(zeta: Entity<Zeta>, buffer: Option<Entity<Buffer>>, cx: &mut App) -> Self {
let choice_and_watcher = buffer.and_then(|buffer| {
@ -1524,8 +1535,8 @@ impl ProviderDataCollection {
}
}
pub fn can_collect_data(&self, cx: &App) -> bool {
self.is_data_collection_enabled(cx) && self.is_project_open_source()
pub fn can_collect_data(&self, cx: &App) -> CanCollectData {
CanCollectData(self.is_data_collection_enabled(cx) && self.is_project_open_source())
}
pub fn is_data_collection_enabled(&self, cx: &App) -> bool {
@ -2149,7 +2160,7 @@ mod tests {
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
let completion_task = zeta.update(cx, |zeta, cx| {
zeta.request_completion(None, &buffer, cursor, false, cx)
zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
});
let completion = completion_task.await.unwrap().unwrap();
@ -2214,7 +2225,7 @@ mod tests {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
let completion_task = zeta.update(cx, |zeta, cx| {
zeta.request_completion(None, &buffer, cursor, false, cx)
zeta.request_completion(None, &buffer, cursor, CanCollectData(false), cx)
});
let completion = completion_task.await.unwrap().unwrap();

View file

@ -18,7 +18,7 @@ use std::process::exit;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
use zeta::{CanCollectData, GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
use crate::headless::ZetaCliAppState;
@ -172,9 +172,7 @@ async fn get_context(
None => String::new(),
};
// Enable gathering extra data not currently needed for edit predictions
let can_collect_data = true;
let git_info = None;
let recent_files = None;
let mut gather_context_output = cx
.update(|cx| {
gather_context(
@ -183,9 +181,8 @@ async fn get_context(
&snapshot,
clipped_cursor,
move || events,
can_collect_data,
CanCollectData(true),
git_info,
recent_files,
cx,
)
})?