diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 171c923154..e78957ec49 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -149,6 +149,22 @@ pub struct PredictEditsBody { pub can_collect_data: bool, #[serde(skip_serializing_if = "Option::is_none", default)] pub diagnostic_groups: Option>, + /// Info about the git repository state, only present when can_collect_data is true. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub git_info: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PredictEditsGitInfo { + /// SHA of git HEAD commit at time of prediction. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub head_sha: Option, + /// URL of the remote called `origin`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub remote_origin_url: Option, + /// URL of the remote called `upstream`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub remote_upstream_url: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index c9f0fc7959..01fc987816 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -246,6 +246,8 @@ pub struct RepositorySnapshot { pub head_commit: Option, pub scan_id: u64, pub merge: MergeDetails, + pub remote_origin_url: Option, + pub remote_upstream_url: Option, } type JobId = u64; @@ -2673,6 +2675,8 @@ impl RepositorySnapshot { head_commit: None, scan_id: 0, merge: Default::default(), + remote_origin_url: None, + remote_upstream_url: None, } } @@ -4818,6 +4822,10 @@ async fn compute_snapshot( None => None, }; + // Used by edit prediction data collection + let remote_origin_url = backend.remote_url("origin"); + let remote_upstream_url = backend.remote_url("upstream"); + let snapshot = RepositorySnapshot { id, statuses_by_path, @@ -4826,6 +4834,8 @@ async fn compute_snapshot( branch, head_commit, merge: merge_details, + remote_origin_url, + remote_upstream_url, }; Ok((snapshot, events)) diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 1cd8e8d17f..b1bd737dbf 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -19,7 +19,7 @@ use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, - PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME, + PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, ZED_VERSION_HEADER_NAME, }; use collections::{HashMap, HashSet, VecDeque}; use futures::AsyncReadExt; @@ -34,7 +34,7 @@ use language::{ }; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use postage::watch; -use project::Project; +use project::{Project, ProjectPath}; use release_channel::AppVersion; use settings::WorktreeId; use std::str::FromStr; @@ -400,6 +400,14 @@ impl Zeta { let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); + let git_info = if let (true, Some(project), Some(file)) = + (can_collect_data, project, snapshot.file()) + { + git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) + } else { + None + }; + let full_path: Arc = snapshot .file() .map(|f| Arc::from(f.full_path(cx).as_path())) @@ -415,6 +423,7 @@ impl Zeta { cursor_point, make_events_prompt, can_collect_data, + git_info, cx, ); @@ -1155,6 +1164,35 @@ fn common_prefix, T2: Iterator>(a: T1, b: .sum() } +fn git_info_for_file( + project: &Entity, + project_path: &ProjectPath, + cx: &App, +) -> Option { + let git_store = project.read(cx).git_store().read(cx); + if let Some((repository, _repo_path)) = + git_store.repository_and_path_for_project_path(project_path, cx) + { + let repository = repository.read(cx); + let head_sha = repository + .head_commit + .as_ref() + .map(|head_commit| head_commit.sha.to_string()); + let remote_origin_url = repository.remote_origin_url.clone(); + let remote_upstream_url = repository.remote_upstream_url.clone(); + if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() { + return None; + } + Some(PredictEditsGitInfo { + head_sha, + remote_origin_url, + remote_upstream_url, + }) + } else { + None + } +} + pub struct GatherContextOutput { pub body: PredictEditsBody, pub editable_range: Range, @@ -1167,6 +1205,7 @@ pub fn gather_context( cursor_point: language::Point, make_events_prompt: impl FnOnce() -> String + Send + 'static, can_collect_data: bool, + git_info: Option, cx: &App, ) -> Task> { let local_lsp_store = @@ -1216,6 +1255,7 @@ pub fn gather_context( outline: Some(input_outline), can_collect_data, diagnostic_groups, + git_info, }; Ok(GatherContextOutput { diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index c5374b56c9..adf7683152 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -172,6 +172,7 @@ async fn get_context( None => String::new(), }; let can_collect_data = false; + let git_info = None; cx.update(|cx| { gather_context( project.as_ref(), @@ -180,6 +181,7 @@ async fn get_context( clipped_cursor, move || events, can_collect_data, + git_info, cx, ) })?