diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 8e89e4c79f..4cc89784fa 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -1,14 +1,16 @@ -use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent}; +use crate::thread::{ + LastRestoreCheckpoint, MessageId, RequestKind, Thread, ThreadError, ThreadEvent, +}; use crate::thread_store::ThreadStore; use crate::tool_use::{ToolUse, ToolUseStatus}; use crate::ui::ContextPill; use collections::HashMap; use editor::{Editor, MultiBuffer}; use gpui::{ - list, percentage, AbsoluteLength, Animation, AnimationExt, AnyElement, App, ClickEvent, - DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset, - ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, Transformation, - UnderlineStyle, WeakEntity, + list, percentage, pulsating_between, AbsoluteLength, Animation, AnimationExt, AnyElement, App, + ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, + ListOffset, ListState, StyleRefinement, Subscription, Task, TextStyleRefinement, + Transformation, UnderlineStyle, WeakEntity, }; use language::{Buffer, LanguageRegistry}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; @@ -18,7 +20,7 @@ use settings::Settings as _; use std::sync::Arc; use std::time::Duration; use theme::ThemeSettings; -use ui::{prelude::*, Disclosure, KeyBinding}; +use ui::{prelude::*, Disclosure, KeyBinding, Tooltip}; use util::ResultExt as _; use workspace::{OpenOptions, Workspace}; @@ -401,7 +403,6 @@ impl ActiveThread { window, cx, ); - self.render_scripting_tool_use_markdown( tool_use.id.clone(), tool_use.name.as_ref(), @@ -463,6 +464,7 @@ impl ActiveThread { } } } + ThreadEvent::CheckpointChanged => cx.notify(), } } @@ -789,20 +791,62 @@ impl ActiveThread { v_flex() .when(ix == 0, |parent| parent.child(self.render_rules_item(cx))) .when_some(checkpoint, |parent, checkpoint| { - parent.child( - h_flex().pl_2().child( - Button::new(("restore-checkpoint", ix), "Restore Checkpoint") - .icon(IconName::Undo) - .size(ButtonSize::Compact) - .on_click(cx.listener(move |this, _, _window, cx| { - this.thread.update(cx, |thread, cx| { - thread - .restore_checkpoint(checkpoint.clone(), cx) - .detach_and_log_err(cx); - }); - })), - ), - ) + let mut is_pending = false; + let mut error = None; + if let Some(last_restore_checkpoint) = + self.thread.read(cx).last_restore_checkpoint() + { + if last_restore_checkpoint.message_id() == message_id { + match last_restore_checkpoint { + LastRestoreCheckpoint::Pending { .. } => is_pending = true, + LastRestoreCheckpoint::Error { error: err, .. } => { + error = Some(err.clone()); + } + } + } + } + + let restore_checkpoint_button = + Button::new(("restore-checkpoint", ix), "Restore Checkpoint") + .icon(if error.is_some() { + IconName::XCircle + } else { + IconName::Undo + }) + .size(ButtonSize::Compact) + .disabled(is_pending) + .icon_color(if error.is_some() { + Some(Color::Error) + } else { + None + }) + .on_click(cx.listener(move |this, _, _window, cx| { + this.thread.update(cx, |thread, cx| { + thread + .restore_checkpoint(checkpoint.clone(), cx) + .detach_and_log_err(cx); + }); + })); + + let restore_checkpoint_button = if is_pending { + restore_checkpoint_button + .with_animation( + ("pulsating-restore-checkpoint-button", ix), + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.6, 1.)), + |label, delta| label.alpha(delta), + ) + .into_any_element() + } else if let Some(error) = error { + restore_checkpoint_button + .tooltip(Tooltip::text(error.to_string())) + .into_any_element() + } else { + restore_checkpoint_button.into_any_element() + }; + + parent.child(h_flex().pl_2().child(restore_checkpoint_button)) }) .child(styled_message) .into_any() diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index bbab5858ed..e861a66f59 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -99,6 +99,25 @@ pub struct ThreadCheckpoint { git_checkpoint: GitStoreCheckpoint, } +pub enum LastRestoreCheckpoint { + Pending { + message_id: MessageId, + }, + Error { + message_id: MessageId, + error: String, + }, +} + +impl LastRestoreCheckpoint { + pub fn message_id(&self) -> MessageId { + match self { + LastRestoreCheckpoint::Pending { message_id } => *message_id, + LastRestoreCheckpoint::Error { message_id, .. } => *message_id, + } + } +} + /// A thread of conversation with the LLM. pub struct Thread { id: ThreadId, @@ -118,6 +137,7 @@ pub struct Thread { tools: Arc, tool_use: ToolUseState, action_log: Entity, + last_restore_checkpoint: Option, scripting_session: Entity, scripting_tool_use: ToolUseState, initial_project_snapshot: Shared>>>, @@ -147,6 +167,7 @@ impl Thread { project: project.clone(), prompt_builder, tools: tools.clone(), + last_restore_checkpoint: None, tool_use: ToolUseState::new(tools.clone()), scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)), scripting_tool_use: ToolUseState::new(tools), @@ -207,6 +228,7 @@ impl Thread { checkpoints_by_message: HashMap::default(), completion_count: 0, pending_completions: Vec::new(), + last_restore_checkpoint: None, project, prompt_builder, tools, @@ -279,17 +301,38 @@ impl Thread { checkpoint: ThreadCheckpoint, cx: &mut Context, ) -> Task> { + self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending { + message_id: checkpoint.message_id, + }); + cx.emit(ThreadEvent::CheckpointChanged); + let project = self.project.read(cx); let restore = project .git_store() .read(cx) .restore_checkpoint(checkpoint.git_checkpoint, cx); cx.spawn(async move |this, cx| { - restore.await?; - this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx)) + let result = restore.await; + this.update(cx, |this, cx| { + if let Err(err) = result.as_ref() { + this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error { + message_id: checkpoint.message_id, + error: err.to_string(), + }); + } else { + this.last_restore_checkpoint = None; + this.truncate(checkpoint.message_id, cx); + } + cx.emit(ThreadEvent::CheckpointChanged); + })?; + result }) } + pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> { + self.last_restore_checkpoint.as_ref() + } + pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context) { let Some(message_ix) = self .messages @@ -1361,6 +1404,7 @@ pub enum ThreadEvent { /// Whether the tool was canceled by the user. canceled: bool, }, + CheckpointChanged, } impl EventEmitter for Thread {} diff --git a/crates/project/src/git.rs b/crates/project/src/git.rs index 0fd923b63b..bb802e9edf 100644 --- a/crates/project/src/git.rs +++ b/crates/project/src/git.rs @@ -542,7 +542,8 @@ impl GitStore { let mut tasks = Vec::new(); for (dot_git_abs_path, checkpoint) in checkpoint.checkpoints_by_dot_git_abs_path { if let Some(repository) = repositories_by_dot_git_abs_path.get(&dot_git_abs_path) { - tasks.push(repository.read(cx).restore_checkpoint(checkpoint)); + let restore = repository.read(cx).restore_checkpoint(checkpoint); + tasks.push(async move { restore.await? }); } } cx.background_spawn(async move {