Delete unused checkpoints (#27260)

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2025-03-21 17:39:01 +01:00 committed by GitHub
parent a52e2f9553
commit 0e9e2d70cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 177 additions and 67 deletions

View file

@ -1,6 +1,5 @@
use std::fmt::Write as _;
use std::io::Write;
use std::mem;
use std::sync::Arc;
use anyhow::{Context as _, Result};
@ -186,8 +185,7 @@ pub struct Thread {
tool_use: ToolUseState,
action_log: Entity<ActionLog>,
last_restore_checkpoint: Option<LastRestoreCheckpoint>,
pending_checkpoint: Option<Task<Result<ThreadCheckpoint>>>,
checkpoint_on_next_user_message: bool,
pending_checkpoint: Option<ThreadCheckpoint>,
scripting_session: Entity<ScriptingSession>,
scripting_tool_use: ToolUseState,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
@ -220,7 +218,6 @@ impl Thread {
tools: tools.clone(),
last_restore_checkpoint: None,
pending_checkpoint: None,
checkpoint_on_next_user_message: true,
tool_use: ToolUseState::new(tools.clone()),
scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
scripting_tool_use: ToolUseState::new(tools),
@ -293,7 +290,6 @@ impl Thread {
pending_completions: Vec::new(),
last_restore_checkpoint: None,
pending_checkpoint: None,
checkpoint_on_next_user_message: true,
project,
prompt_builder,
tools,
@ -385,11 +381,8 @@ impl Thread {
} else {
this.truncate(checkpoint.message_id, cx);
this.last_restore_checkpoint = None;
this.pending_checkpoint = Some(Task::ready(Ok(ThreadCheckpoint {
message_id: this.next_message_id,
git_checkpoint: checkpoint.git_checkpoint,
})));
}
this.pending_checkpoint = None;
cx.emit(ThreadEvent::CheckpointChanged);
cx.notify();
})?;
@ -397,46 +390,62 @@ impl Thread {
})
}
fn checkpoint(&mut self, cx: &mut Context<Self>) {
if self.is_generating() {
fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
let pending_checkpoint = if self.is_generating() {
return;
}
} else if let Some(checkpoint) = self.pending_checkpoint.take() {
checkpoint
} else {
return;
};
let git_store = self.project.read(cx).git_store().clone();
let new_checkpoint = git_store.read(cx).checkpoint(cx);
let old_checkpoint = self.pending_checkpoint.take();
let next_user_message_id = self.next_message_id;
self.pending_checkpoint = Some(cx.spawn(async move |this, cx| {
let new_checkpoint = new_checkpoint.await?;
let final_checkpoint = git_store.read(cx).checkpoint(cx);
cx.spawn(async move |this, cx| match final_checkpoint.await {
Ok(final_checkpoint) => {
let equal = git_store
.read_with(cx, |store, cx| {
store.compare_checkpoints(
pending_checkpoint.git_checkpoint.clone(),
final_checkpoint.clone(),
cx,
)
})?
.await
.unwrap_or(false);
if let Some(old_checkpoint) = old_checkpoint {
if let Ok(old_checkpoint) = old_checkpoint.await {
let equal = git_store
if equal {
git_store
.read_with(cx, |store, cx| {
store.compare_checkpoints(
old_checkpoint.git_checkpoint.clone(),
new_checkpoint.clone(),
cx,
)
store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
})?
.await;
if equal.ok() != Some(true) {
this.update(cx, |this, cx| {
this.checkpoints_by_message
.insert(old_checkpoint.message_id, old_checkpoint);
cx.emit(ThreadEvent::CheckpointChanged);
cx.notify();
})?;
}
.detach();
} else {
this.update(cx, |this, cx| {
this.insert_checkpoint(pending_checkpoint, cx)
})?;
}
}
Ok(ThreadCheckpoint {
message_id: next_user_message_id,
git_checkpoint: new_checkpoint,
})
}));
git_store
.read_with(cx, |store, cx| {
store.delete_checkpoint(final_checkpoint, cx)
})?
.detach();
Ok(())
}
Err(_) => this.update(cx, |this, cx| {
this.insert_checkpoint(pending_checkpoint, cx)
}),
})
.detach();
}
fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
self.checkpoints_by_message
.insert(checkpoint.message_id, checkpoint);
cx.emit(ThreadEvent::CheckpointChanged);
cx.notify();
}
pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
@ -517,18 +526,21 @@ impl Thread {
&mut self,
text: impl Into<String>,
context: Vec<ContextSnapshot>,
git_checkpoint: Option<GitStoreCheckpoint>,
cx: &mut Context<Self>,
) -> MessageId {
if mem::take(&mut self.checkpoint_on_next_user_message) {
self.checkpoint(cx);
}
let message_id =
self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
self.context
.extend(context.into_iter().map(|context| (context.id, context)));
self.context_by_message.insert(message_id, context_ids);
if let Some(git_checkpoint) = git_checkpoint {
self.pending_checkpoint = Some(ThreadCheckpoint {
message_id,
git_checkpoint,
});
}
message_id
}
@ -1050,7 +1062,7 @@ impl Thread {
thread
.update(cx, |thread, cx| {
thread.checkpoint(cx);
thread.finalize_pending_checkpoint(cx);
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
@ -1319,6 +1331,7 @@ impl Thread {
// so for now we provide some text to keep the model on track.
"Here are the tool results.",
Vec::new(),
None,
cx,
);
}
@ -1341,7 +1354,7 @@ impl Thread {
}
canceled
};
self.checkpoint(cx);
self.finalize_pending_checkpoint(cx);
canceled
}