Show "Restore Checkpoint" only when there were changes (#27243)
Release Notes: - N/A --------- Co-authored-by: Agus Zubiaga <hi@aguz.me> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de> Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
This commit is contained in:
parent
9d965bc98a
commit
e14ebcf267
9 changed files with 350 additions and 112 deletions
|
@ -1,5 +1,6 @@
|
|||
use std::fmt::Write as _;
|
||||
use std::io::Write;
|
||||
use std::mem;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
|
@ -176,7 +177,7 @@ pub struct Thread {
|
|||
context: BTreeMap<ContextId, ContextSnapshot>,
|
||||
context_by_message: HashMap<MessageId, Vec<ContextId>>,
|
||||
system_prompt_context: Option<AssistantSystemPromptContext>,
|
||||
checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
|
||||
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
project: Entity<Project>,
|
||||
|
@ -185,6 +186,8 @@ pub struct Thread {
|
|||
tool_use: ToolUseState,
|
||||
action_log: Entity<ActionLog>,
|
||||
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
|
||||
pending_checkpoint: Option<Task<Result<ThreadCheckpoint>>>,
|
||||
checkpoint_on_next_user_message: bool,
|
||||
scripting_session: Entity<ScriptingSession>,
|
||||
scripting_tool_use: ToolUseState,
|
||||
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
|
||||
|
@ -216,6 +219,8 @@ impl Thread {
|
|||
prompt_builder,
|
||||
tools: tools.clone(),
|
||||
last_restore_checkpoint: None,
|
||||
pending_checkpoint: None,
|
||||
checkpoint_on_next_user_message: true,
|
||||
tool_use: ToolUseState::new(tools.clone()),
|
||||
scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
|
||||
scripting_tool_use: ToolUseState::new(tools),
|
||||
|
@ -287,6 +292,8 @@ impl Thread {
|
|||
completion_count: 0,
|
||||
pending_completions: Vec::new(),
|
||||
last_restore_checkpoint: None,
|
||||
pending_checkpoint: None,
|
||||
checkpoint_on_next_user_message: true,
|
||||
project,
|
||||
prompt_builder,
|
||||
tools,
|
||||
|
@ -348,11 +355,7 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
|
||||
let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
|
||||
Some(ThreadCheckpoint {
|
||||
message_id: id,
|
||||
git_checkpoint: checkpoint,
|
||||
})
|
||||
self.checkpoints_by_message.get(&id).cloned()
|
||||
}
|
||||
|
||||
pub fn restore_checkpoint(
|
||||
|
@ -364,12 +367,13 @@ impl Thread {
|
|||
message_id: checkpoint.message_id,
|
||||
});
|
||||
cx.emit(ThreadEvent::CheckpointChanged);
|
||||
cx.notify();
|
||||
|
||||
let project = self.project.read(cx);
|
||||
let restore = project
|
||||
.git_store()
|
||||
.read(cx)
|
||||
.restore_checkpoint(checkpoint.git_checkpoint, cx);
|
||||
.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = restore.await;
|
||||
this.update(cx, |this, cx| {
|
||||
|
@ -379,15 +383,62 @@ impl Thread {
|
|||
error: err.to_string(),
|
||||
});
|
||||
} else {
|
||||
this.last_restore_checkpoint = None;
|
||||
this.truncate(checkpoint.message_id, cx);
|
||||
this.last_restore_checkpoint = None;
|
||||
this.pending_checkpoint = Some(Task::ready(Ok(ThreadCheckpoint {
|
||||
message_id: this.next_message_id,
|
||||
git_checkpoint: checkpoint.git_checkpoint,
|
||||
})));
|
||||
}
|
||||
cx.emit(ThreadEvent::CheckpointChanged);
|
||||
cx.notify();
|
||||
})?;
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
fn checkpoint(&mut self, cx: &mut Context<Self>) {
|
||||
if self.is_generating() {
|
||||
return;
|
||||
}
|
||||
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
let new_checkpoint = git_store.read(cx).checkpoint(cx);
|
||||
let old_checkpoint = self.pending_checkpoint.take();
|
||||
let next_user_message_id = self.next_message_id;
|
||||
self.pending_checkpoint = Some(cx.spawn(async move |this, cx| {
|
||||
let new_checkpoint = new_checkpoint.await?;
|
||||
|
||||
if let Some(old_checkpoint) = old_checkpoint {
|
||||
if let Ok(old_checkpoint) = old_checkpoint.await {
|
||||
let equal = git_store
|
||||
.read_with(cx, |store, cx| {
|
||||
store.compare_checkpoints(
|
||||
old_checkpoint.git_checkpoint.clone(),
|
||||
new_checkpoint.clone(),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
if equal.ok() != Some(true) {
|
||||
this.update(cx, |this, cx| {
|
||||
this.checkpoints_by_message
|
||||
.insert(old_checkpoint.message_id, old_checkpoint);
|
||||
cx.emit(ThreadEvent::CheckpointChanged);
|
||||
cx.notify();
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ThreadCheckpoint {
|
||||
message_id: next_user_message_id,
|
||||
git_checkpoint: new_checkpoint,
|
||||
})
|
||||
}));
|
||||
}
|
||||
|
||||
pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
|
||||
self.last_restore_checkpoint.as_ref()
|
||||
}
|
||||
|
@ -466,18 +517,18 @@ impl Thread {
|
|||
&mut self,
|
||||
text: impl Into<String>,
|
||||
context: Vec<ContextSnapshot>,
|
||||
checkpoint: Option<GitStoreCheckpoint>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
if mem::take(&mut self.checkpoint_on_next_user_message) {
|
||||
self.checkpoint(cx);
|
||||
}
|
||||
|
||||
let message_id =
|
||||
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
|
||||
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
|
||||
self.context
|
||||
.extend(context.into_iter().map(|context| (context.id, context)));
|
||||
self.context_by_message.insert(message_id, context_ids);
|
||||
if let Some(checkpoint) = checkpoint {
|
||||
self.checkpoints_by_message.insert(message_id, checkpoint);
|
||||
}
|
||||
message_id
|
||||
}
|
||||
|
||||
|
@ -999,6 +1050,7 @@ impl Thread {
|
|||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.checkpoint(cx);
|
||||
match result.as_ref() {
|
||||
Ok(stop_reason) => match stop_reason {
|
||||
StopReason::ToolUse => {
|
||||
|
@ -1267,7 +1319,6 @@ impl Thread {
|
|||
// so for now we provide some text to keep the model on track.
|
||||
"Here are the tool results.",
|
||||
Vec::new(),
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
@ -1276,7 +1327,7 @@ impl Thread {
|
|||
///
|
||||
/// Returns whether a completion was canceled.
|
||||
pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
|
||||
if self.pending_completions.pop().is_some() {
|
||||
let canceled = if self.pending_completions.pop().is_some() {
|
||||
true
|
||||
} else {
|
||||
let mut canceled = false;
|
||||
|
@ -1289,7 +1340,9 @@ impl Thread {
|
|||
});
|
||||
}
|
||||
canceled
|
||||
}
|
||||
};
|
||||
self.checkpoint(cx);
|
||||
canceled
|
||||
}
|
||||
|
||||
/// Returns the feedback given to the thread, if any.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue