From e14ebcf26734d17ec070b6df36a46541debe5845 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 21 Mar 2025 15:10:43 +0100 Subject: [PATCH] Show "Restore Checkpoint" only when there were changes (#27243) Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga Co-authored-by: Bennet Bo Fenner Co-authored-by: Danilo Leal --- Cargo.lock | 1 + crates/assistant2/src/active_thread.rs | 3 +- crates/assistant2/src/message_editor.rs | 9 +- crates/assistant2/src/thread.rs | 83 +++++-- crates/assistant_eval/src/eval.rs | 2 +- crates/fs/src/fake_git_repo.rs | 9 + crates/git/Cargo.toml | 1 + crates/git/src/repository.rs | 301 +++++++++++++++++------- crates/project/src/git_store.rs | 53 +++++ 9 files changed, 350 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1c24bf9325..c1ad88e6d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5594,6 +5594,7 @@ dependencies = [ "sum_tree", "tempfile", "text", + "thiserror 2.0.12", "time", "unindent", "url", diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index f640a663e6..030607d12b 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -1021,8 +1021,7 @@ impl ActiveThread { .when(first_message, |parent| { parent.child(self.render_rules_item(cx)) }) - .when(!first_message && checkpoint.is_some(), |parent| { - let checkpoint = checkpoint.clone().unwrap(); + .when_some(checkpoint, |parent, checkpoint| { let mut is_pending = false; let mut error = None; if let Some(last_restore_checkpoint) = diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 4baa1da3be..e796a0aa1b 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -12,7 +12,6 @@ use gpui::{ }; use language_model::LanguageModelRegistry; use language_model_selector::ToggleModelSelector; -use project::Project; use rope::Point; use settings::Settings; use std::time::Duration; @@ -21,7 +20,6 @@ use theme::ThemeSettings; use ui::{ prelude::*, ButtonLike, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, Tooltip, }; -use util::ResultExt; use vim_mode_setting::VimModeSetting; use workspace::Workspace; @@ -39,7 +37,6 @@ pub struct MessageEditor { editor: Entity, #[allow(dead_code)] workspace: WeakEntity, - project: Entity, context_store: Entity, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, @@ -110,7 +107,6 @@ impl MessageEditor { Self { editor: editor.clone(), - project: thread.read(cx).project().clone(), thread, workspace, context_store, @@ -209,8 +205,6 @@ impl MessageEditor { let thread = self.thread.clone(); let context_store = self.context_store.clone(); - let git_store = self.project.read(cx).git_store(); - let checkpoint = git_store.read(cx).checkpoint(cx); cx.spawn(async move |_, cx| { refresh_task.await; let (system_prompt_context, load_error) = system_prompt_context_task.await; @@ -222,11 +216,10 @@ impl MessageEditor { } }) .ok(); - let checkpoint = checkpoint.await.log_err(); thread .update(cx, |thread, cx| { let context = context_store.read(cx).snapshot(cx).collect::>(); - thread.insert_user_message(user_message, context, checkpoint, cx); + thread.insert_user_message(user_message, context, cx); thread.send_to_model(model, request_kind, cx); }) .ok(); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index f9c2c40c03..2f6918d206 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,5 +1,6 @@ use std::fmt::Write as _; use std::io::Write; +use std::mem; use std::sync::Arc; use anyhow::{Context as _, Result}; @@ -176,7 +177,7 @@ pub struct Thread { context: BTreeMap, context_by_message: HashMap>, system_prompt_context: Option, - checkpoints_by_message: HashMap, + checkpoints_by_message: HashMap, completion_count: usize, pending_completions: Vec, project: Entity, @@ -185,6 +186,8 @@ pub struct Thread { tool_use: ToolUseState, action_log: Entity, last_restore_checkpoint: Option, + pending_checkpoint: Option>>, + checkpoint_on_next_user_message: bool, scripting_session: Entity, scripting_tool_use: ToolUseState, initial_project_snapshot: Shared>>>, @@ -216,6 +219,8 @@ impl Thread { prompt_builder, tools: tools.clone(), last_restore_checkpoint: None, + pending_checkpoint: None, + checkpoint_on_next_user_message: true, tool_use: ToolUseState::new(tools.clone()), scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)), scripting_tool_use: ToolUseState::new(tools), @@ -287,6 +292,8 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), last_restore_checkpoint: None, + pending_checkpoint: None, + checkpoint_on_next_user_message: true, project, prompt_builder, tools, @@ -348,11 +355,7 @@ impl Thread { } pub fn checkpoint_for_message(&self, id: MessageId) -> Option { - let checkpoint = self.checkpoints_by_message.get(&id).cloned()?; - Some(ThreadCheckpoint { - message_id: id, - git_checkpoint: checkpoint, - }) + self.checkpoints_by_message.get(&id).cloned() } pub fn restore_checkpoint( @@ -364,12 +367,13 @@ impl Thread { message_id: checkpoint.message_id, }); cx.emit(ThreadEvent::CheckpointChanged); + cx.notify(); let project = self.project.read(cx); let restore = project .git_store() .read(cx) - .restore_checkpoint(checkpoint.git_checkpoint, cx); + .restore_checkpoint(checkpoint.git_checkpoint.clone(), cx); cx.spawn(async move |this, cx| { let result = restore.await; this.update(cx, |this, cx| { @@ -379,15 +383,62 @@ impl Thread { error: err.to_string(), }); } else { - this.last_restore_checkpoint = None; this.truncate(checkpoint.message_id, cx); + this.last_restore_checkpoint = None; + this.pending_checkpoint = Some(Task::ready(Ok(ThreadCheckpoint { + message_id: this.next_message_id, + git_checkpoint: checkpoint.git_checkpoint, + }))); } cx.emit(ThreadEvent::CheckpointChanged); + cx.notify(); })?; result }) } + fn checkpoint(&mut self, cx: &mut Context) { + if self.is_generating() { + return; + } + + let git_store = self.project.read(cx).git_store().clone(); + let new_checkpoint = git_store.read(cx).checkpoint(cx); + let old_checkpoint = self.pending_checkpoint.take(); + let next_user_message_id = self.next_message_id; + self.pending_checkpoint = Some(cx.spawn(async move |this, cx| { + let new_checkpoint = new_checkpoint.await?; + + if let Some(old_checkpoint) = old_checkpoint { + if let Ok(old_checkpoint) = old_checkpoint.await { + let equal = git_store + .read_with(cx, |store, cx| { + store.compare_checkpoints( + old_checkpoint.git_checkpoint.clone(), + new_checkpoint.clone(), + cx, + ) + })? + .await; + + if equal.ok() != Some(true) { + this.update(cx, |this, cx| { + this.checkpoints_by_message + .insert(old_checkpoint.message_id, old_checkpoint); + cx.emit(ThreadEvent::CheckpointChanged); + cx.notify(); + })?; + } + } + } + + Ok(ThreadCheckpoint { + message_id: next_user_message_id, + git_checkpoint: new_checkpoint, + }) + })); + } + pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> { self.last_restore_checkpoint.as_ref() } @@ -466,18 +517,18 @@ impl Thread { &mut self, text: impl Into, context: Vec, - checkpoint: Option, cx: &mut Context, ) -> MessageId { + if mem::take(&mut self.checkpoint_on_next_user_message) { + self.checkpoint(cx); + } + let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx); let context_ids = context.iter().map(|context| context.id).collect::>(); self.context .extend(context.into_iter().map(|context| (context.id, context))); self.context_by_message.insert(message_id, context_ids); - if let Some(checkpoint) = checkpoint { - self.checkpoints_by_message.insert(message_id, checkpoint); - } message_id } @@ -999,6 +1050,7 @@ impl Thread { thread .update(cx, |thread, cx| { + thread.checkpoint(cx); match result.as_ref() { Ok(stop_reason) => match stop_reason { StopReason::ToolUse => { @@ -1267,7 +1319,6 @@ impl Thread { // so for now we provide some text to keep the model on track. "Here are the tool results.", Vec::new(), - None, cx, ); } @@ -1276,7 +1327,7 @@ impl Thread { /// /// Returns whether a completion was canceled. pub fn cancel_last_completion(&mut self, cx: &mut Context) -> bool { - if self.pending_completions.pop().is_some() { + let canceled = if self.pending_completions.pop().is_some() { true } else { let mut canceled = false; @@ -1289,7 +1340,9 @@ impl Thread { }); } canceled - } + }; + self.checkpoint(cx); + canceled } /// Returns the feedback given to the thread, if any. diff --git a/crates/assistant_eval/src/eval.rs b/crates/assistant_eval/src/eval.rs index 8801472e4d..f8fa743293 100644 --- a/crates/assistant_eval/src/eval.rs +++ b/crates/assistant_eval/src/eval.rs @@ -96,7 +96,7 @@ impl Eval { assistant.update(cx, |assistant, cx| { assistant.thread.update(cx, |thread, cx| { let context = vec![]; - thread.insert_user_message(self.user_prompt.clone(), context, None, cx); + thread.insert_user_message(self.user_prompt.clone(), context, cx); thread.set_system_prompt_context(system_prompt_context); thread.send_to_model(model, RequestKind::Chat, cx); }); diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index afc2524a64..3206a6853c 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -420,4 +420,13 @@ impl GitRepository for FakeGitRepository { ) -> BoxFuture> { unimplemented!() } + + fn compare_checkpoints( + &self, + _left: GitRepositoryCheckpoint, + _right: GitRepositoryCheckpoint, + _cx: AsyncApp, + ) -> BoxFuture> { + unimplemented!() + } } diff --git a/crates/git/Cargo.toml b/crates/git/Cargo.toml index 23e145c0c6..7fe0a87d61 100644 --- a/crates/git/Cargo.toml +++ b/crates/git/Cargo.toml @@ -32,6 +32,7 @@ serde.workspace = true smol.workspace = true sum_tree.workspace = true text.workspace = true +thiserror.workspace = true time.workspace = true url.workspace = true util.workspace = true diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index 6848e0004f..f466ae4f69 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -5,14 +5,15 @@ use collections::HashMap; use futures::future::BoxFuture; use futures::{select_biased, AsyncWriteExt, FutureExt as _}; use git2::BranchType; -use gpui::{AppContext, AsyncApp, SharedString}; +use gpui::{AppContext, AsyncApp, BackgroundExecutor, SharedString}; use parking_lot::Mutex; use rope::Rope; use schemars::JsonSchema; use serde::Deserialize; use std::borrow::Borrow; +use std::future; use std::path::Component; -use std::process::Stdio; +use std::process::{ExitStatus, Stdio}; use std::sync::LazyLock; use std::{ cmp::Ordering, @@ -20,6 +21,7 @@ use std::{ sync::Arc, }; use sum_tree::MapSeekTarget; +use thiserror::Error; use util::command::new_smol_command; use util::ResultExt; use uuid::Uuid; @@ -298,6 +300,14 @@ pub trait GitRepository: Send + Sync { checkpoint: GitRepositoryCheckpoint, cx: AsyncApp, ) -> BoxFuture>; + + /// Compares two checkpoints, returning true if they are equal + fn compare_checkpoints( + &self, + left: GitRepositoryCheckpoint, + right: GitRepositoryCheckpoint, + cx: AsyncApp, + ) -> BoxFuture>; } pub enum DiffType { @@ -1049,62 +1059,36 @@ impl GitRepository for RealGitRepository { let executor = cx.background_executor().clone(); cx.background_spawn(async move { let working_directory = working_directory?; - let index_file_path = working_directory.join(".git/index.tmp"); - - let delete_temp_index = util::defer({ - let index_file_path = index_file_path.clone(); - || { - executor - .spawn(async move { - smol::fs::remove_file(index_file_path).await.log_err(); - }) - .detach(); - } - }); - - let run_git_command = async |args: &[&str]| { - let output = new_smol_command(&git_binary_path) - .current_dir(&working_directory) - .env("GIT_INDEX_FILE", &index_file_path) - .envs(checkpoint_author_envs()) - .args(args) - .output() - .await?; - if output.status.success() { - anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string()) + let mut git = GitBinary::new(git_binary_path, working_directory, executor) + .envs(checkpoint_author_envs()); + git.with_temp_index(async |git| { + let head_sha = git.run(&["rev-parse", "HEAD"]).await.ok(); + git.run(&["add", "--all"]).await?; + let tree = git.run(&["write-tree"]).await?; + let checkpoint_sha = if let Some(head_sha) = head_sha.as_deref() { + git.run(&["commit-tree", &tree, "-p", head_sha, "-m", "Checkpoint"]) + .await? } else { - let error = String::from_utf8_lossy(&output.stderr); - Err(anyhow!("Git command failed: {:?}", error)) - } - }; + git.run(&["commit-tree", &tree, "-m", "Checkpoint"]).await? + }; + let ref_name = Uuid::new_v4().to_string(); + git.run(&[ + "update-ref", + &format!("refs/zed/{ref_name}"), + &checkpoint_sha, + ]) + .await?; - let head_sha = run_git_command(&["rev-parse", "HEAD"]).await.ok(); - run_git_command(&["add", "--all"]).await?; - let tree = run_git_command(&["write-tree"]).await?; - let checkpoint_sha = if let Some(head_sha) = head_sha.as_deref() { - run_git_command(&["commit-tree", &tree, "-p", head_sha, "-m", "Checkpoint"]).await? - } else { - run_git_command(&["commit-tree", &tree, "-m", "Checkpoint"]).await? - }; - let ref_name = Uuid::new_v4().to_string(); - run_git_command(&[ - "update-ref", - &format!("refs/zed/{ref_name}"), - &checkpoint_sha, - ]) - .await?; - - smol::fs::remove_file(index_file_path).await.ok(); - delete_temp_index.abort(); - - Ok(GitRepositoryCheckpoint { - head_sha: if let Some(head_sha) = head_sha { - Some(head_sha.parse()?) - } else { - None - }, - sha: checkpoint_sha.parse()?, + Ok(GitRepositoryCheckpoint { + head_sha: if let Some(head_sha) = head_sha { + Some(head_sha.parse()?) + } else { + None + }, + sha: checkpoint_sha.parse()?, + }) }) + .await }) .boxed() } @@ -1116,50 +1100,165 @@ impl GitRepository for RealGitRepository { ) -> BoxFuture> { let working_directory = self.working_directory(); let git_binary_path = self.git_binary_path.clone(); + + let executor = cx.background_executor().clone(); cx.background_spawn(async move { let working_directory = working_directory?; - let index_file_path = working_directory.join(".git/index.tmp"); - let run_git_command = async |args: &[&str], use_temp_index: bool| { - let mut command = new_smol_command(&git_binary_path); - command.current_dir(&working_directory); - command.args(args); - if use_temp_index { - command.env("GIT_INDEX_FILE", &index_file_path); - } - let output = command.output().await?; - if output.status.success() { - anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string()) - } else { - let error = String::from_utf8_lossy(&output.stderr); - Err(anyhow!("Git command failed: {:?}", error)) - } - }; - - run_git_command( - &[ - "restore", - "--source", - &checkpoint.sha.to_string(), - "--worktree", - ".", - ], - false, - ) + let mut git = GitBinary::new(git_binary_path, working_directory, executor); + git.run(&[ + "restore", + "--source", + &checkpoint.sha.to_string(), + "--worktree", + ".", + ]) + .await?; + + git.with_temp_index(async move |git| { + git.run(&["read-tree", &checkpoint.sha.to_string()]).await?; + git.run(&["clean", "-d", "--force"]).await + }) .await?; - run_git_command(&["read-tree", &checkpoint.sha.to_string()], true).await?; - run_git_command(&["clean", "-d", "--force"], true).await?; if let Some(head_sha) = checkpoint.head_sha { - run_git_command(&["reset", "--mixed", &head_sha.to_string()], false).await?; + git.run(&["reset", "--mixed", &head_sha.to_string()]) + .await?; } else { - run_git_command(&["update-ref", "-d", "HEAD"], false).await?; + git.run(&["update-ref", "-d", "HEAD"]).await?; } Ok(()) }) .boxed() } + + fn compare_checkpoints( + &self, + left: GitRepositoryCheckpoint, + right: GitRepositoryCheckpoint, + cx: AsyncApp, + ) -> BoxFuture> { + if left.head_sha != right.head_sha { + return future::ready(Ok(false)).boxed(); + } + + let working_directory = self.working_directory(); + let git_binary_path = self.git_binary_path.clone(); + + let executor = cx.background_executor().clone(); + cx.background_spawn(async move { + let working_directory = working_directory?; + let git = GitBinary::new(git_binary_path, working_directory, executor); + let result = git + .run(&[ + "diff-tree", + "--quiet", + &left.sha.to_string(), + &right.sha.to_string(), + ]) + .await; + match result { + Ok(_) => Ok(true), + Err(error) => { + if let Some(GitBinaryCommandError { status, .. }) = + error.downcast_ref::() + { + if status.code() == Some(1) { + return Ok(false); + } + } + + Err(error) + } + } + }) + .boxed() + } +} + +struct GitBinary { + git_binary_path: PathBuf, + working_directory: PathBuf, + executor: BackgroundExecutor, + index_file_path: Option, + envs: HashMap, +} + +impl GitBinary { + fn new( + git_binary_path: PathBuf, + working_directory: PathBuf, + executor: BackgroundExecutor, + ) -> Self { + Self { + git_binary_path, + working_directory, + executor, + index_file_path: None, + envs: HashMap::default(), + } + } + + fn envs(mut self, envs: HashMap) -> Self { + self.envs = envs; + self + } + + pub async fn with_temp_index( + &mut self, + f: impl AsyncFnOnce(&Self) -> Result, + ) -> Result { + let index_file_path = self.working_directory.join(".git/index.tmp"); + + let delete_temp_index = util::defer({ + let index_file_path = index_file_path.clone(); + let executor = self.executor.clone(); + move || { + executor + .spawn(async move { + smol::fs::remove_file(index_file_path).await.log_err(); + }) + .detach(); + } + }); + + self.index_file_path = Some(index_file_path.clone()); + let result = f(self).await; + self.index_file_path = None; + let result = result?; + + smol::fs::remove_file(index_file_path).await.ok(); + delete_temp_index.abort(); + + Ok(result) + } + + pub async fn run(&self, args: &[&str]) -> Result { + let mut command = new_smol_command(&self.git_binary_path); + command.current_dir(&self.working_directory); + command.args(args); + if let Some(index_file_path) = self.index_file_path.as_ref() { + command.env("GIT_INDEX_FILE", index_file_path); + } + command.envs(&self.envs); + let output = command.output().await?; + if output.status.success() { + anyhow::Ok(String::from_utf8(output.stdout)?.trim_end().to_string()) + } else { + Err(anyhow!(GitBinaryCommandError { + stdout: String::from_utf8_lossy(&output.stdout).to_string(), + status: output.status, + })) + } + } +} + +#[derive(Error, Debug)] +#[error("Git command failed: {stdout}")] +struct GitBinaryCommandError { + stdout: String, + status: ExitStatus, } async fn run_remote_command( @@ -1619,6 +1718,36 @@ mod tests { ); } + #[gpui::test] + async fn test_compare_checkpoints(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + + let repo_dir = tempfile::tempdir().unwrap(); + git2::Repository::init(repo_dir.path()).unwrap(); + let repo = RealGitRepository::new(&repo_dir.path().join(".git"), None).unwrap(); + + smol::fs::write(repo_dir.path().join("file1"), "content1") + .await + .unwrap(); + let checkpoint1 = repo.checkpoint(cx.to_async()).await.unwrap(); + + smol::fs::write(repo_dir.path().join("file2"), "content2") + .await + .unwrap(); + let checkpoint2 = repo.checkpoint(cx.to_async()).await.unwrap(); + + assert!(!repo + .compare_checkpoints(checkpoint1, checkpoint2, cx.to_async()) + .await + .unwrap()); + + let checkpoint3 = repo.checkpoint(cx.to_async()).await.unwrap(); + assert!(repo + .compare_checkpoints(checkpoint2, checkpoint3, cx.to_async()) + .await + .unwrap()); + } + #[test] fn test_branches_parsing() { // suppress "help: octal escapes are not supported, `\0` is always null" diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index ddd2edf189..4307953c71 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -580,6 +580,44 @@ impl GitStore { }) } + /// Compares two checkpoints, returning true if they are equal. + pub fn compare_checkpoints( + &self, + left: GitStoreCheckpoint, + mut right: GitStoreCheckpoint, + cx: &App, + ) -> Task> { + let repositories_by_dot_git_abs_path = self + .repositories + .values() + .map(|repo| (repo.read(cx).dot_git_abs_path.clone(), repo)) + .collect::>(); + + let mut tasks = Vec::new(); + for (dot_git_abs_path, left_checkpoint) in left.checkpoints_by_dot_git_abs_path { + if let Some(right_checkpoint) = right + .checkpoints_by_dot_git_abs_path + .remove(&dot_git_abs_path) + { + if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) { + let compare = repository + .read(cx) + .compare_checkpoints(left_checkpoint, right_checkpoint); + tasks.push(async move { compare.await? }); + } + } else { + return Task::ready(Ok(false)); + } + } + cx.background_spawn(async move { + Ok(future::try_join_all(tasks) + .await? + .into_iter() + .all(|result| result)) + }) + } + + /// Blames a buffer. pub fn blame_buffer( &self, buffer: &Entity, @@ -3266,6 +3304,21 @@ impl Repository { } }) } + + pub fn compare_checkpoints( + &self, + left: GitRepositoryCheckpoint, + right: GitRepositoryCheckpoint, + ) -> oneshot::Receiver> { + self.send_job(move |repo, cx| async move { + match repo { + RepositoryState::Local(git_repository) => { + git_repository.compare_checkpoints(left, right, cx).await + } + RepositoryState::Remote { .. } => Err(anyhow!("not implemented yet")), + } + }) + } } fn get_permalink_in_rust_registry_src(