diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index e796a0aa1b..4f3e346918 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -12,6 +12,7 @@ use gpui::{ }; use language_model::LanguageModelRegistry; use language_model_selector::ToggleModelSelector; +use project::Project; use rope::Point; use settings::Settings; use std::time::Duration; @@ -37,6 +38,7 @@ pub struct MessageEditor { editor: Entity, #[allow(dead_code)] workspace: WeakEntity, + project: Entity, context_store: Entity, context_strip: Entity, context_picker_menu_handle: PopoverMenuHandle, @@ -107,6 +109,7 @@ impl MessageEditor { Self { editor: editor.clone(), + project: thread.read(cx).project().clone(), thread, workspace, context_store, @@ -205,7 +208,9 @@ impl MessageEditor { let thread = self.thread.clone(); let context_store = self.context_store.clone(); + let checkpoint = self.project.read(cx).git_store().read(cx).checkpoint(cx); cx.spawn(async move |_, cx| { + let checkpoint = checkpoint.await.ok(); refresh_task.await; let (system_prompt_context, load_error) = system_prompt_context_task.await; thread @@ -219,7 +224,7 @@ impl MessageEditor { thread .update(cx, |thread, cx| { let context = context_store.read(cx).snapshot(cx).collect::>(); - thread.insert_user_message(user_message, context, cx); + thread.insert_user_message(user_message, context, checkpoint, cx); thread.send_to_model(model, request_kind, cx); }) .ok(); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 2f6918d206..41834a6e5b 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,6 +1,5 @@ use std::fmt::Write as _; use std::io::Write; -use std::mem; use std::sync::Arc; use anyhow::{Context as _, Result}; @@ -186,8 +185,7 @@ pub struct Thread { tool_use: ToolUseState, action_log: Entity, last_restore_checkpoint: Option, - pending_checkpoint: Option>>, - checkpoint_on_next_user_message: bool, + pending_checkpoint: Option, scripting_session: Entity, scripting_tool_use: ToolUseState, initial_project_snapshot: Shared>>>, @@ -220,7 +218,6 @@ impl Thread { tools: tools.clone(), last_restore_checkpoint: None, pending_checkpoint: None, - checkpoint_on_next_user_message: true, tool_use: ToolUseState::new(tools.clone()), scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)), scripting_tool_use: ToolUseState::new(tools), @@ -293,7 +290,6 @@ impl Thread { pending_completions: Vec::new(), last_restore_checkpoint: None, pending_checkpoint: None, - checkpoint_on_next_user_message: true, project, prompt_builder, tools, @@ -385,11 +381,8 @@ impl Thread { } else { this.truncate(checkpoint.message_id, cx); this.last_restore_checkpoint = None; - this.pending_checkpoint = Some(Task::ready(Ok(ThreadCheckpoint { - message_id: this.next_message_id, - git_checkpoint: checkpoint.git_checkpoint, - }))); } + this.pending_checkpoint = None; cx.emit(ThreadEvent::CheckpointChanged); cx.notify(); })?; @@ -397,46 +390,62 @@ impl Thread { }) } - fn checkpoint(&mut self, cx: &mut Context) { - if self.is_generating() { + fn finalize_pending_checkpoint(&mut self, cx: &mut Context) { + let pending_checkpoint = if self.is_generating() { return; - } + } else if let Some(checkpoint) = self.pending_checkpoint.take() { + checkpoint + } else { + return; + }; let git_store = self.project.read(cx).git_store().clone(); - let new_checkpoint = git_store.read(cx).checkpoint(cx); - let old_checkpoint = self.pending_checkpoint.take(); - let next_user_message_id = self.next_message_id; - self.pending_checkpoint = Some(cx.spawn(async move |this, cx| { - let new_checkpoint = new_checkpoint.await?; + let final_checkpoint = git_store.read(cx).checkpoint(cx); + cx.spawn(async move |this, cx| match final_checkpoint.await { + Ok(final_checkpoint) => { + let equal = git_store + .read_with(cx, |store, cx| { + store.compare_checkpoints( + pending_checkpoint.git_checkpoint.clone(), + final_checkpoint.clone(), + cx, + ) + })? + .await + .unwrap_or(false); - if let Some(old_checkpoint) = old_checkpoint { - if let Ok(old_checkpoint) = old_checkpoint.await { - let equal = git_store + if equal { + git_store .read_with(cx, |store, cx| { - store.compare_checkpoints( - old_checkpoint.git_checkpoint.clone(), - new_checkpoint.clone(), - cx, - ) + store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx) })? - .await; - - if equal.ok() != Some(true) { - this.update(cx, |this, cx| { - this.checkpoints_by_message - .insert(old_checkpoint.message_id, old_checkpoint); - cx.emit(ThreadEvent::CheckpointChanged); - cx.notify(); - })?; - } + .detach(); + } else { + this.update(cx, |this, cx| { + this.insert_checkpoint(pending_checkpoint, cx) + })?; } - } - Ok(ThreadCheckpoint { - message_id: next_user_message_id, - git_checkpoint: new_checkpoint, - }) - })); + git_store + .read_with(cx, |store, cx| { + store.delete_checkpoint(final_checkpoint, cx) + })? + .detach(); + + Ok(()) + } + Err(_) => this.update(cx, |this, cx| { + this.insert_checkpoint(pending_checkpoint, cx) + }), + }) + .detach(); + } + + fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context) { + self.checkpoints_by_message + .insert(checkpoint.message_id, checkpoint); + cx.emit(ThreadEvent::CheckpointChanged); + cx.notify(); } pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> { @@ -517,18 +526,21 @@ impl Thread { &mut self, text: impl Into, context: Vec, + git_checkpoint: Option, cx: &mut Context, ) -> MessageId { - if mem::take(&mut self.checkpoint_on_next_user_message) { - self.checkpoint(cx); - } - let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx); let context_ids = context.iter().map(|context| context.id).collect::>(); self.context .extend(context.into_iter().map(|context| (context.id, context))); self.context_by_message.insert(message_id, context_ids); + if let Some(git_checkpoint) = git_checkpoint { + self.pending_checkpoint = Some(ThreadCheckpoint { + message_id, + git_checkpoint, + }); + } message_id } @@ -1050,7 +1062,7 @@ impl Thread { thread .update(cx, |thread, cx| { - thread.checkpoint(cx); + thread.finalize_pending_checkpoint(cx); match result.as_ref() { Ok(stop_reason) => match stop_reason { StopReason::ToolUse => { @@ -1319,6 +1331,7 @@ impl Thread { // so for now we provide some text to keep the model on track. "Here are the tool results.", Vec::new(), + None, cx, ); } @@ -1341,7 +1354,7 @@ impl Thread { } canceled }; - self.checkpoint(cx); + self.finalize_pending_checkpoint(cx); canceled } diff --git a/crates/assistant_eval/src/eval.rs b/crates/assistant_eval/src/eval.rs index f8fa743293..8801472e4d 100644 --- a/crates/assistant_eval/src/eval.rs +++ b/crates/assistant_eval/src/eval.rs @@ -96,7 +96,7 @@ impl Eval { assistant.update(cx, |assistant, cx| { assistant.thread.update(cx, |thread, cx| { let context = vec![]; - thread.insert_user_message(self.user_prompt.clone(), context, cx); + thread.insert_user_message(self.user_prompt.clone(), context, None, cx); thread.set_system_prompt_context(system_prompt_context); thread.send_to_model(model, RequestKind::Chat, cx); }); diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 3206a6853c..1b472e6f3a 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -429,4 +429,12 @@ impl GitRepository for FakeGitRepository { ) -> BoxFuture> { unimplemented!() } + + fn delete_checkpoint( + &self, + _checkpoint: GitRepositoryCheckpoint, + _cx: AsyncApp, + ) -> BoxFuture> { + unimplemented!() + } } diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index f466ae4f69..70e370adf9 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -308,6 +308,13 @@ pub trait GitRepository: Send + Sync { right: GitRepositoryCheckpoint, cx: AsyncApp, ) -> BoxFuture>; + + /// Deletes a previously-created checkpoint. + fn delete_checkpoint( + &self, + checkpoint: GitRepositoryCheckpoint, + cx: AsyncApp, + ) -> BoxFuture>; } pub enum DiffType { @@ -351,10 +358,11 @@ impl RealGitRepository { } } -#[derive(Copy, Clone)] +#[derive(Clone, Debug)] pub struct GitRepositoryCheckpoint { + ref_name: String, head_sha: Option, - sha: Oid, + commit_sha: Oid, } // https://git-scm.com/book/en/v2/Git-Internals-Git-Objects @@ -1071,21 +1079,17 @@ impl GitRepository for RealGitRepository { } else { git.run(&["commit-tree", &tree, "-m", "Checkpoint"]).await? }; - let ref_name = Uuid::new_v4().to_string(); - git.run(&[ - "update-ref", - &format!("refs/zed/{ref_name}"), - &checkpoint_sha, - ]) - .await?; + let ref_name = format!("refs/zed/{}", Uuid::new_v4()); + git.run(&["update-ref", &ref_name, &checkpoint_sha]).await?; Ok(GitRepositoryCheckpoint { + ref_name, head_sha: if let Some(head_sha) = head_sha { Some(head_sha.parse()?) } else { None }, - sha: checkpoint_sha.parse()?, + commit_sha: checkpoint_sha.parse()?, }) }) .await @@ -1109,14 +1113,15 @@ impl GitRepository for RealGitRepository { git.run(&[ "restore", "--source", - &checkpoint.sha.to_string(), + &checkpoint.commit_sha.to_string(), "--worktree", ".", ]) .await?; git.with_temp_index(async move |git| { - git.run(&["read-tree", &checkpoint.sha.to_string()]).await?; + git.run(&["read-tree", &checkpoint.commit_sha.to_string()]) + .await?; git.run(&["clean", "-d", "--force"]).await }) .await?; @@ -1154,8 +1159,8 @@ impl GitRepository for RealGitRepository { .run(&[ "diff-tree", "--quiet", - &left.sha.to_string(), - &right.sha.to_string(), + &left.commit_sha.to_string(), + &right.commit_sha.to_string(), ]) .await; match result { @@ -1175,6 +1180,24 @@ impl GitRepository for RealGitRepository { }) .boxed() } + + fn delete_checkpoint( + &self, + checkpoint: GitRepositoryCheckpoint, + cx: AsyncApp, + ) -> BoxFuture> { + let working_directory = self.working_directory(); + let git_binary_path = self.git_binary_path.clone(); + + let executor = cx.background_executor().clone(); + cx.background_spawn(async move { + let working_directory = working_directory?; + let git = GitBinary::new(git_binary_path, working_directory, executor); + git.run(&["update-ref", "-d", &checkpoint.ref_name]).await?; + Ok(()) + }) + .boxed() + } } struct GitBinary { @@ -1574,7 +1597,9 @@ mod tests { .await .unwrap(); - repo.restore_checkpoint(checkpoint, cx.to_async()) + // Ensure checkpoint stays alive even after a Git GC. + repo.gc(cx.to_async()).await.unwrap(); + repo.restore_checkpoint(checkpoint.clone(), cx.to_async()) .await .unwrap(); @@ -1595,6 +1620,15 @@ mod tests { .ok(), None ); + + // Garbage collecting after deleting a checkpoint makes it unreachable. + repo.delete_checkpoint(checkpoint.clone(), cx.to_async()) + .await + .unwrap(); + repo.gc(cx.to_async()).await.unwrap(); + repo.restore_checkpoint(checkpoint.clone(), cx.to_async()) + .await + .unwrap_err(); } #[gpui::test] @@ -1737,7 +1771,7 @@ mod tests { let checkpoint2 = repo.checkpoint(cx.to_async()).await.unwrap(); assert!(!repo - .compare_checkpoints(checkpoint1, checkpoint2, cx.to_async()) + .compare_checkpoints(checkpoint1, checkpoint2.clone(), cx.to_async()) .await .unwrap()); @@ -1774,4 +1808,21 @@ mod tests { }] ) } + + impl RealGitRepository { + /// Force a Git garbage collection on the repository. + fn gc(&self, cx: AsyncApp) -> BoxFuture> { + let working_directory = self.working_directory(); + let git_binary_path = self.git_binary_path.clone(); + let executor = cx.background_executor().clone(); + cx.background_spawn(async move { + let git_binary_path = git_binary_path.clone(); + let working_directory = working_directory?; + let git = GitBinary::new(git_binary_path, working_directory, executor); + git.run(&["gc", "--prune=now"]).await?; + Ok(()) + }) + .boxed() + } + } } diff --git a/crates/project/src/git_store.rs b/crates/project/src/git_store.rs index 4307953c71..4776b00104 100644 --- a/crates/project/src/git_store.rs +++ b/crates/project/src/git_store.rs @@ -549,8 +549,7 @@ impl GitStore { } cx.background_executor().spawn(async move { - let checkpoints: Vec = - future::try_join_all(checkpoints).await?; + let checkpoints = future::try_join_all(checkpoints).await?; Ok(GitStoreCheckpoint { checkpoints_by_dot_git_abs_path: dot_git_abs_paths .into_iter() @@ -617,6 +616,26 @@ impl GitStore { }) } + pub fn delete_checkpoint(&self, checkpoint: GitStoreCheckpoint, cx: &App) -> Task> { + let repositories_by_dot_git_abs_path = self + .repositories + .values() + .map(|repo| (repo.read(cx).dot_git_abs_path.clone(), repo)) + .collect::>(); + + let mut tasks = Vec::new(); + for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path { + if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) { + let delete = repository.read(cx).delete_checkpoint(checkpoint); + tasks.push(async move { delete.await? }); + } + } + cx.background_spawn(async move { + future::try_join_all(tasks).await?; + Ok(()) + }) + } + /// Blames a buffer. pub fn blame_buffer( &self, @@ -3319,6 +3338,20 @@ impl Repository { } }) } + + pub fn delete_checkpoint( + &self, + checkpoint: GitRepositoryCheckpoint, + ) -> oneshot::Receiver> { + self.send_job(move |repo, cx| async move { + match repo { + RepositoryState::Local(git_repository) => { + git_repository.delete_checkpoint(checkpoint, cx).await + } + RepositoryState::Remote { .. } => Err(anyhow!("not implemented yet")), + } + }) + } } fn get_permalink_in_rust_registry_src(