Start on a Git-based review flow (#27103)

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2025-03-19 20:00:21 +01:00 committed by GitHub
parent 68262fe7e4
commit 33faa66e35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 396 additions and 131 deletions

3
Cargo.lock generated
View file

@ -467,6 +467,7 @@ dependencies = [
"futures 0.3.31", "futures 0.3.31",
"fuzzy", "fuzzy",
"git", "git",
"git_ui",
"gpui", "gpui",
"heed", "heed",
"html_to_markdown", "html_to_markdown",
@ -5601,11 +5602,13 @@ dependencies = [
"serde_json", "serde_json",
"smol", "smol",
"sum_tree", "sum_tree",
"tempfile",
"text", "text",
"time", "time",
"unindent", "unindent",
"url", "url",
"util", "util",
"uuid",
] ]
[[package]] [[package]]

View file

@ -39,6 +39,7 @@ fs.workspace = true
futures.workspace = true futures.workspace = true
fuzzy.workspace = true fuzzy.workspace = true
git.workspace = true git.workspace = true
git_ui.workspace = true
gpui.workspace = true gpui.workspace = true
heed.workspace = true heed.workspace = true
html_to_markdown.workspace = true html_to_markdown.workspace = true

View file

@ -550,6 +550,7 @@ impl ActiveThread {
let thread = self.thread.read(cx); let thread = self.thread.read(cx);
// Get all the data we need from thread before we start using it in closures // Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id);
let context = thread.context_for_message(message_id); let context = thread.context_for_message(message_id);
let tool_uses = thread.tool_uses_for_message(message_id); let tool_uses = thread.tool_uses_for_message(message_id);
let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id); let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
@ -734,7 +735,25 @@ impl ActiveThread {
), ),
}; };
styled_message.into_any() v_flex()
.when_some(checkpoint, |parent, checkpoint| {
parent.child(
h_flex().pl_2().child(
Button::new("restore-checkpoint", "Restore Checkpoint")
.icon(IconName::Undo)
.size(ButtonSize::Compact)
.on_click(cx.listener(move |this, _, _window, cx| {
this.thread.update(cx, |thread, cx| {
thread
.restore_checkpoint(checkpoint.clone(), cx)
.detach_and_log_err(cx);
});
})),
),
)
})
.child(styled_message)
.into_any()
} }
fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement { fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {

View file

@ -3,23 +3,25 @@ use std::sync::Arc;
use collections::HashSet; use collections::HashSet;
use editor::actions::MoveUp; use editor::actions::MoveUp;
use editor::{Editor, EditorElement, EditorEvent, EditorStyle}; use editor::{Editor, EditorElement, EditorEvent, EditorStyle};
use file_icons::FileIcons;
use fs::Fs; use fs::Fs;
use git::ExpandCommitEditor;
use git_ui::git_panel;
use gpui::{ use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle, Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle,
WeakEntity, WeakEntity,
}; };
use language_model::LanguageModelRegistry; use language_model::LanguageModelRegistry;
use language_model_selector::ToggleModelSelector; use language_model_selector::ToggleModelSelector;
use project::Project;
use rope::Point; use rope::Point;
use settings::Settings; use settings::Settings;
use std::time::Duration; use std::time::Duration;
use text::Bias; use text::Bias;
use theme::ThemeSettings; use theme::ThemeSettings;
use ui::{ use ui::{
prelude::*, ButtonLike, Disclosure, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, prelude::*, ButtonLike, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, Tooltip,
Tooltip,
}; };
use util::ResultExt;
use vim_mode_setting::VimModeSetting; use vim_mode_setting::VimModeSetting;
use workspace::notifications::{NotificationId, NotifyTaskExt}; use workspace::notifications::{NotificationId, NotifyTaskExt};
use workspace::{Toast, Workspace}; use workspace::{Toast, Workspace};
@ -37,6 +39,7 @@ pub struct MessageEditor {
thread: Entity<Thread>, thread: Entity<Thread>,
editor: Entity<Editor>, editor: Entity<Editor>,
workspace: WeakEntity<Workspace>, workspace: WeakEntity<Workspace>,
project: Entity<Project>,
context_store: Entity<ContextStore>, context_store: Entity<ContextStore>,
context_strip: Entity<ContextStrip>, context_strip: Entity<ContextStrip>,
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>, context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
@ -44,7 +47,6 @@ pub struct MessageEditor {
inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>, inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>, model_selector: Entity<AssistantModelSelector>,
tool_selector: Entity<ToolSelector>, tool_selector: Entity<ToolSelector>,
edits_expanded: bool,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
@ -107,8 +109,9 @@ impl MessageEditor {
]; ];
Self { Self {
thread,
editor: editor.clone(), editor: editor.clone(),
project: thread.read(cx).project().clone(),
thread,
workspace, workspace,
context_store, context_store,
context_strip, context_strip,
@ -125,7 +128,6 @@ impl MessageEditor {
) )
}), }),
tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)), tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)),
edits_expanded: false,
_subscriptions: subscriptions, _subscriptions: subscriptions,
} }
} }
@ -206,12 +208,15 @@ impl MessageEditor {
let thread = self.thread.clone(); let thread = self.thread.clone();
let context_store = self.context_store.clone(); let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store();
let checkpoint = git_store.read(cx).checkpoint(cx);
cx.spawn(async move |_, cx| { cx.spawn(async move |_, cx| {
refresh_task.await; refresh_task.await;
let checkpoint = checkpoint.await.log_err();
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>(); let context = context_store.read(cx).snapshot(cx).collect::<Vec<_>>();
thread.insert_user_message(user_message, context, cx); thread.insert_user_message(user_message, context, checkpoint, cx);
thread.send_to_model(model, request_kind, cx); thread.send_to_model(model, request_kind, cx);
}) })
.ok(); .ok();
@ -347,8 +352,12 @@ impl Render for MessageEditor {
px(64.) px(64.)
}; };
let changed_buffers = self.thread.read(cx).scripting_changed_buffers(cx); let project = self.thread.read(cx).project();
let changed_buffers_count = changed_buffers.len(); let changed_files = if let Some(repository) = project.read(cx).active_repository(cx) {
repository.read(cx).status().count()
} else {
0
};
v_flex() v_flex()
.size_full() .size_full()
@ -410,7 +419,7 @@ impl Render for MessageEditor {
), ),
) )
}) })
.when(changed_buffers_count > 0, |parent| { .when(changed_files > 0, |parent| {
parent.child( parent.child(
v_flex() v_flex()
.mx_2() .mx_2()
@ -421,96 +430,60 @@ impl Render for MessageEditor {
.rounded_t_md() .rounded_t_md()
.child( .child(
h_flex() h_flex()
.gap_2() .justify_between()
.p_2() .p_2()
.child( .child(
Disclosure::new("edits-disclosure", self.edits_expanded) h_flex()
.on_click(cx.listener(|this, _ev, _window, cx| { .gap_2()
this.edits_expanded = !this.edits_expanded;
cx.notify();
})),
)
.child( .child(
Label::new("Edits") IconButton::new(
.size(LabelSize::XSmall) "edits-disclosure",
.color(Color::Muted), IconName::GitBranchSmall,
)
.icon_size(IconSize::Small)
.on_click(
|_ev, _window, cx| {
cx.defer(|cx| {
cx.dispatch_action(&git_panel::ToggleFocus)
});
},
),
) )
.child(Label::new("").size(LabelSize::XSmall).color(Color::Muted))
.child( .child(
Label::new(format!( Label::new(format!(
"{} {}", "{} {} changed",
changed_buffers_count, changed_files,
if changed_buffers_count == 1 { if changed_files == 1 { "file" } else { "files" }
"file"
} else {
"files"
}
)) ))
.size(LabelSize::XSmall) .size(LabelSize::XSmall)
.color(Color::Muted), .color(Color::Muted),
), ),
) )
.when(self.edits_expanded, |parent| {
parent.child(
v_flex().bg(cx.theme().colors().editor_background).children(
changed_buffers.enumerate().flat_map(|(index, buffer)| {
let file = buffer.read(cx).file()?;
let path = file.path();
let parent_label = path.parent().and_then(|parent| {
let parent_str = parent.to_string_lossy();
if parent_str.is_empty() {
None
} else {
Some(
Label::new(format!(
"{}{}",
parent_str,
std::path::MAIN_SEPARATOR_STR
))
.color(Color::Muted)
.size(LabelSize::Small),
)
}
});
let name_label = path.file_name().map(|name| {
Label::new(name.to_string_lossy().to_string())
.size(LabelSize::Small)
});
let file_icon = FileIcons::get_icon(&path, cx)
.map(Icon::from_path)
.unwrap_or_else(|| Icon::new(IconName::File));
let element = div()
.p_2()
.when(index + 1 < changed_buffers_count, |parent| {
parent
.border_color(cx.theme().colors().border)
.border_b_1()
})
.child( .child(
h_flex() h_flex()
.gap_2() .gap_2()
.child(file_icon)
.child( .child(
// TODO: handle overflow Button::new("review", "Review")
h_flex() .label_size(LabelSize::XSmall)
.children(parent_label) .on_click(|_event, _window, cx| {
.children(name_label), cx.defer(|cx| {
) cx.dispatch_action(
// TODO: show lines changed &git_ui::project_diff::Diff,
.child(Label::new("+").color(Color::Created))
.child(Label::new("-").color(Color::Deleted)),
); );
});
Some(element) }),
)
.child(
Button::new("commit", "Commit")
.label_size(LabelSize::XSmall)
.on_click(|_event, _window, cx| {
cx.defer(|cx| {
cx.dispatch_action(&ExpandCommitEditor)
});
}), }),
), ),
) ),
}), ),
) )
}) })
.child( .child(

View file

@ -16,6 +16,7 @@ use language_model::{
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, TokenUsage, Role, StopReason, TokenUsage,
}; };
use project::git::GitStoreCheckpoint;
use project::Project; use project::Project;
use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder}; use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
use scripting_tool::{ScriptingSession, ScriptingTool}; use scripting_tool::{ScriptingSession, ScriptingTool};
@ -89,6 +90,12 @@ pub struct GitState {
pub diff: Option<String>, pub diff: Option<String>,
} }
#[derive(Clone)]
pub struct ThreadCheckpoint {
message_id: MessageId,
git_checkpoint: GitStoreCheckpoint,
}
/// A thread of conversation with the LLM. /// A thread of conversation with the LLM.
pub struct Thread { pub struct Thread {
id: ThreadId, id: ThreadId,
@ -99,6 +106,7 @@ pub struct Thread {
next_message_id: MessageId, next_message_id: MessageId,
context: BTreeMap<ContextId, ContextSnapshot>, context: BTreeMap<ContextId, ContextSnapshot>,
context_by_message: HashMap<MessageId, Vec<ContextId>>, context_by_message: HashMap<MessageId, Vec<ContextId>>,
checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
completion_count: usize, completion_count: usize,
pending_completions: Vec<PendingCompletion>, pending_completions: Vec<PendingCompletion>,
project: Entity<Project>, project: Entity<Project>,
@ -128,6 +136,7 @@ impl Thread {
next_message_id: MessageId(0), next_message_id: MessageId(0),
context: BTreeMap::default(), context: BTreeMap::default(),
context_by_message: HashMap::default(), context_by_message: HashMap::default(),
checkpoints_by_message: HashMap::default(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
project: project.clone(), project: project.clone(),
@ -188,6 +197,7 @@ impl Thread {
next_message_id, next_message_id,
context: BTreeMap::default(), context: BTreeMap::default(),
context_by_message: HashMap::default(), context_by_message: HashMap::default(),
checkpoints_by_message: HashMap::default(),
completion_count: 0, completion_count: 0,
pending_completions: Vec::new(), pending_completions: Vec::new(),
project, project,
@ -249,6 +259,45 @@ impl Thread {
&self.tools &self.tools
} }
pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
Some(ThreadCheckpoint {
message_id: id,
git_checkpoint: checkpoint,
})
}
pub fn restore_checkpoint(
&mut self,
checkpoint: ThreadCheckpoint,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let project = self.project.read(cx);
let restore = project
.git_store()
.read(cx)
.restore_checkpoint(checkpoint.git_checkpoint, cx);
cx.spawn(async move |this, cx| {
restore.await?;
this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
})
}
pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
let Some(message_ix) = self
.messages
.iter()
.rposition(|message| message.id == message_id)
else {
return;
};
for deleted_message in self.messages.drain(message_ix..) {
self.context_by_message.remove(&deleted_message.id);
self.checkpoints_by_message.remove(&deleted_message.id);
}
cx.notify();
}
pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> { pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
let context = self.context_by_message.get(&id)?; let context = self.context_by_message.get(&id)?;
Some( Some(
@ -296,13 +345,6 @@ impl Thread {
self.scripting_tool_use.tool_results_for_message(id) self.scripting_tool_use.tool_results_for_message(id)
} }
pub fn scripting_changed_buffers<'a>(
&self,
cx: &'a App,
) -> impl ExactSizeIterator<Item = &'a Entity<language::Buffer>> {
self.scripting_session.read(cx).changed_buffers()
}
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
self.tool_use.message_has_tool_results(message_id) self.tool_use.message_has_tool_results(message_id)
} }
@ -315,6 +357,7 @@ impl Thread {
&mut self, &mut self,
text: impl Into<String>, text: impl Into<String>,
context: Vec<ContextSnapshot>, context: Vec<ContextSnapshot>,
checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> MessageId { ) -> MessageId {
let message_id = self.insert_message(Role::User, text, cx); let message_id = self.insert_message(Role::User, text, cx);
@ -322,6 +365,9 @@ impl Thread {
self.context self.context
.extend(context.into_iter().map(|context| (context.id, context))); .extend(context.into_iter().map(|context| (context.id, context)));
self.context_by_message.insert(message_id, context_ids); self.context_by_message.insert(message_id, context_ids);
if let Some(checkpoint) = checkpoint {
self.checkpoints_by_message.insert(message_id, checkpoint);
}
message_id message_id
} }
@ -941,6 +987,7 @@ impl Thread {
// so for now we provide some text to keep the model on track. // so for now we provide some text to keep the model on track.
"Here are the tool results.", "Here are the tool results.",
Vec::new(), Vec::new(),
None,
cx, cx,
); );
} }
@ -1144,6 +1191,10 @@ impl Thread {
&self.action_log &self.action_log
} }
pub fn project(&self) -> &Entity<Project> {
&self.project
}
pub fn cumulative_token_usage(&self) -> TokenUsage { pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone() self.cumulative_token_usage.clone()
} }

View file

@ -82,7 +82,7 @@ impl Eval {
assistant.update(cx, |assistant, cx| { assistant.update(cx, |assistant, cx| {
assistant.thread.update(cx, |thread, cx| { assistant.thread.update(cx, |thread, cx| {
let context = vec![]; let context = vec![];
thread.insert_user_message(self.user_prompt.clone(), context, cx); thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
thread.send_to_model(model, RequestKind::Chat, cx); thread.send_to_model(model, RequestKind::Chat, cx);
}); });
})?; })?;

View file

@ -408,4 +408,12 @@ impl GitRepository for FakeGitRepository {
) -> BoxFuture<Result<String>> { ) -> BoxFuture<Result<String>> {
unimplemented!() unimplemented!()
} }
fn checkpoint(&self, _cx: AsyncApp) -> BoxFuture<Result<git::Oid>> {
unimplemented!()
}
fn restore_checkpoint(&self, _oid: git::Oid, _cx: AsyncApp) -> BoxFuture<Result<()>> {
unimplemented!()
}
} }

View file

@ -35,6 +35,7 @@ text.workspace = true
time.workspace = true time.workspace = true
url.workspace = true url.workspace = true
util.workspace = true util.workspace = true
uuid.workspace = true
futures.workspace = true futures.workspace = true
[dev-dependencies] [dev-dependencies]
@ -43,3 +44,4 @@ serde_json.workspace = true
text = { workspace = true, features = ["test-support"] } text = { workspace = true, features = ["test-support"] }
unindent.workspace = true unindent.workspace = true
gpui = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] }
tempfile.workspace = true

View file

@ -1,5 +1,5 @@
use crate::status::GitStatus; use crate::status::GitStatus;
use crate::SHORT_SHA_LENGTH; use crate::{Oid, SHORT_SHA_LENGTH};
use anyhow::{anyhow, Context as _, Result}; use anyhow::{anyhow, Context as _, Result};
use collections::HashMap; use collections::HashMap;
use futures::future::BoxFuture; use futures::future::BoxFuture;
@ -22,6 +22,7 @@ use std::{
use sum_tree::MapSeekTarget; use sum_tree::MapSeekTarget;
use util::command::new_smol_command; use util::command::new_smol_command;
use util::ResultExt; use util::ResultExt;
use uuid::Uuid;
pub use askpass::{AskPassResult, AskPassSession}; pub use askpass::{AskPassResult, AskPassSession};
@ -287,6 +288,12 @@ pub trait GitRepository: Send + Sync {
/// Run git diff /// Run git diff
fn diff(&self, diff: DiffType, cx: AsyncApp) -> BoxFuture<Result<String>>; fn diff(&self, diff: DiffType, cx: AsyncApp) -> BoxFuture<Result<String>>;
/// Creates a checkpoint for the repository.
fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<Oid>>;
/// Resets to a previously-created checkpoint.
fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture<Result<()>>;
} }
pub enum DiffType { pub enum DiffType {
@ -1025,6 +1032,89 @@ impl GitRepository for RealGitRepository {
}) })
.boxed() .boxed()
} }
fn checkpoint(&self, cx: AsyncApp) -> BoxFuture<Result<Oid>> {
let working_directory = self.working_directory();
let git_binary_path = self.git_binary_path.clone();
let executor = cx.background_executor().clone();
cx.background_spawn(async move {
let working_directory = working_directory?;
let index_file_path = working_directory.join(".git/index.tmp");
let delete_temp_index = util::defer({
let index_file_path = index_file_path.clone();
|| {
executor
.spawn(async move {
smol::fs::remove_file(index_file_path).await.log_err();
})
.detach();
}
});
let run_git_command = async |args: &[&str]| {
let output = new_smol_command(&git_binary_path)
.current_dir(&working_directory)
.env("GIT_INDEX_FILE", &index_file_path)
.env("GIT_AUTHOR_NAME", "Zed")
.env("GIT_AUTHOR_EMAIL", "hi@zed.dev")
.env("GIT_COMMITTER_NAME", "Zed")
.env("GIT_COMMITTER_EMAIL", "hi@zed.dev")
.args(args)
.output()
.await?;
if output.status.success() {
anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string())
} else {
let error = String::from_utf8_lossy(&output.stderr);
Err(anyhow!("Git command failed: {:?}", error))
}
};
run_git_command(&["add", "--all"]).await?;
let tree = run_git_command(&["write-tree"]).await?;
let commit_sha = run_git_command(&["commit-tree", &tree, "-m", "Checkpoint"]).await?;
let ref_name = Uuid::new_v4().to_string();
run_git_command(&["update-ref", &format!("refs/heads/{ref_name}"), &commit_sha])
.await?;
smol::fs::remove_file(index_file_path).await.ok();
delete_temp_index.abort();
commit_sha.parse()
})
.boxed()
}
fn restore_checkpoint(&self, oid: Oid, cx: AsyncApp) -> BoxFuture<Result<()>> {
let working_directory = self.working_directory();
let git_binary_path = self.git_binary_path.clone();
cx.background_spawn(async move {
let working_directory = working_directory?;
let index_file_path = working_directory.join(".git/index.tmp");
let run_git_command = async |args: &[&str]| {
let output = new_smol_command(&git_binary_path)
.current_dir(&working_directory)
.env("GIT_INDEX_FILE", &index_file_path)
.args(args)
.output()
.await?;
if output.status.success() {
anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string())
} else {
let error = String::from_utf8_lossy(&output.stderr);
Err(anyhow!("Git command failed: {:?}", error))
}
};
run_git_command(&["restore", "--source", &oid.to_string(), "--worktree", "."]).await?;
run_git_command(&["read-tree", &oid.to_string()]).await?;
run_git_command(&["clean", "-d", "--force"]).await?;
Ok(())
})
.boxed()
}
} }
async fn run_remote_command( async fn run_remote_command(
@ -1260,6 +1350,48 @@ fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> {
} }
} }
#[cfg(test)]
mod tests {
use gpui::TestAppContext;
use super::*;
#[gpui::test]
async fn test_checkpoint(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let repo_dir = tempfile::tempdir().unwrap();
git2::Repository::init(repo_dir.path()).unwrap();
let repo = RealGitRepository::new(&repo_dir.path().join(".git"), None).unwrap();
smol::fs::write(repo_dir.path().join("foo"), "foo")
.await
.unwrap();
let checkpoint_sha = repo.checkpoint(cx.to_async()).await.unwrap();
smol::fs::write(repo_dir.path().join("foo"), "bar")
.await
.unwrap();
smol::fs::write(repo_dir.path().join("baz"), "qux")
.await
.unwrap();
repo.restore_checkpoint(checkpoint_sha, cx.to_async())
.await
.unwrap();
assert_eq!(
smol::fs::read_to_string(repo_dir.path().join("foo"))
.await
.unwrap(),
"foo"
);
assert_eq!(
smol::fs::read_to_string(repo_dir.path().join("baz"))
.await
.ok(),
None
);
}
#[test] #[test]
fn test_branches_parsing() { fn test_branches_parsing() {
// suppress "help: octal escapes are not supported, `\0` is always null" // suppress "help: octal escapes are not supported, `\0` is always null"
@ -1286,3 +1418,4 @@ fn test_branches_parsing() {
}] }]
) )
} }
}

View file

@ -11,10 +11,10 @@ use collections::HashMap;
use fs::Fs; use fs::Fs;
use futures::{ use futures::{
channel::{mpsc, oneshot}, channel::{mpsc, oneshot},
future::{OptionFuture, Shared}, future::{self, OptionFuture, Shared},
FutureExt as _, StreamExt as _, FutureExt as _, StreamExt as _,
}; };
use git::repository::DiffType; use git::{repository::DiffType, Oid};
use git::{ use git::{
repository::{ repository::{
Branch, CommitDetails, GitRepository, PushOptions, Remote, RemoteCommandOutput, RepoPath, Branch, CommitDetails, GitRepository, PushOptions, Remote, RemoteCommandOutput, RepoPath,
@ -117,6 +117,16 @@ enum GitStoreState {
}, },
} }
#[derive(Clone)]
pub struct GitStoreCheckpoint {
checkpoints_by_dot_git_abs_path: HashMap<PathBuf, RepositoryCheckpoint>,
}
#[derive(Copy, Clone)]
pub struct RepositoryCheckpoint {
sha: Oid,
}
pub struct Repository { pub struct Repository {
commit_message_buffer: Option<Entity<Buffer>>, commit_message_buffer: Option<Entity<Buffer>>,
git_store: WeakEntity<GitStore>, git_store: WeakEntity<GitStore>,
@ -506,6 +516,45 @@ impl GitStore {
diff_state.read(cx).uncommitted_diff.as_ref()?.upgrade() diff_state.read(cx).uncommitted_diff.as_ref()?.upgrade()
} }
pub fn checkpoint(&self, cx: &App) -> Task<Result<GitStoreCheckpoint>> {
let mut dot_git_abs_paths = Vec::new();
let mut checkpoints = Vec::new();
for repository in self.repositories.values() {
let repository = repository.read(cx);
dot_git_abs_paths.push(repository.dot_git_abs_path.clone());
checkpoints.push(repository.checkpoint().map(|checkpoint| checkpoint?));
}
cx.background_executor().spawn(async move {
let checkpoints: Vec<RepositoryCheckpoint> = future::try_join_all(checkpoints).await?;
Ok(GitStoreCheckpoint {
checkpoints_by_dot_git_abs_path: dot_git_abs_paths
.into_iter()
.zip(checkpoints)
.collect(),
})
})
}
pub fn restore_checkpoint(&self, checkpoint: GitStoreCheckpoint, cx: &App) -> Task<Result<()>> {
let repositories_by_dot_git_abs_path = self
.repositories
.values()
.map(|repo| (repo.read(cx).dot_git_abs_path.clone(), repo))
.collect::<HashMap<_, _>>();
let mut tasks = Vec::new();
for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path {
if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) {
tasks.push(repository.read(cx).restore_checkpoint(checkpoint));
}
}
cx.background_spawn(async move {
future::try_join_all(tasks).await?;
Ok(())
})
}
fn downstream_client(&self) -> Option<(AnyProtoClient, ProjectId)> { fn downstream_client(&self) -> Option<(AnyProtoClient, ProjectId)> {
match &self.state { match &self.state {
GitStoreState::Local { GitStoreState::Local {
@ -2922,4 +2971,30 @@ impl Repository {
} }
}) })
} }
pub fn checkpoint(&self) -> oneshot::Receiver<Result<RepositoryCheckpoint>> {
self.send_job(|repo, cx| async move {
match repo {
GitRepo::Local(git_repository) => {
let sha = git_repository.checkpoint(cx).await?;
Ok(RepositoryCheckpoint { sha })
}
GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")),
}
})
}
pub fn restore_checkpoint(
&self,
checkpoint: RepositoryCheckpoint,
) -> oneshot::Receiver<Result<()>> {
self.send_job(move |repo, cx| async move {
match repo {
GitRepo::Local(git_repository) => {
git_repository.restore_checkpoint(checkpoint.sha, cx).await
}
GitRepo::Remote { .. } => Err(anyhow!("not implemented yet")),
}
})
}
} }