agent2: Initial infra for checkpoints and message editing (#36120)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Ben Brandt 2025-08-13 17:46:28 +02:00 committed by GitHub
parent f4b0332f78
commit 23cd5b59b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1374 additions and 582 deletions

View file

@ -36,6 +36,7 @@ terminal.workspace = true
ui.workspace = true
url.workspace = true
util.workspace = true
uuid.workspace = true
watch.workspace = true
workspace-hack.workspace = true

View file

@ -9,18 +9,19 @@ pub use mention::*;
pub use terminal::*;
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use anyhow::{Context as _, Result};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use editor::Bias;
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use itertools::Itertools;
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
use markdown::Markdown;
use project::{AgentLocation, Project};
use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Formatter;
use std::fmt::{Formatter, Write};
use std::ops::Range;
use std::process::ExitStatus;
use std::rc::Rc;
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
@ -29,24 +30,23 @@ use util::ResultExt;
/// A user-authored entry in a thread.
#[derive(Debug)]
pub struct UserMessage {
/// Identifier used to look this message up again (for example by `rewind`).
/// `None` when no id was assigned to the message.
pub id: Option<UserMessageId>,
/// The message body; rendered by `to_markdown`.
pub content: ContentBlock,
/// Git checkpoint associated with this message, if one was recorded.
/// `rewind` restores it before truncating the thread at this message.
pub checkpoint: Option<GitStoreCheckpoint>,
}
impl UserMessage {
pub fn from_acp(
message: impl IntoIterator<Item = acp::ContentBlock>,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Self {
let mut content = ContentBlock::Empty;
for chunk in message {
content.append(chunk, &language_registry, cx)
}
Self { content: content }
}
fn to_markdown(&self, cx: &App) -> String {
format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
let mut markdown = String::new();
if let Some(_) = self.checkpoint {
writeln!(markdown, "## User (checkpoint)").unwrap();
} else {
writeln!(markdown, "## User").unwrap();
}
writeln!(markdown).unwrap();
writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
writeln!(markdown).unwrap();
markdown
}
}
@ -633,6 +633,7 @@ pub struct AcpThread {
pub enum AcpThreadEvent {
NewEntry,
EntryUpdated(usize),
EntriesRemoved(Range<usize>),
ToolAuthorizationRequired,
Stopped,
Error,
@ -772,7 +773,7 @@ impl AcpThread {
) -> Result<()> {
match update {
acp::SessionUpdate::UserMessageChunk { content } => {
self.push_user_content_block(content, cx);
self.push_user_content_block(None, content, cx);
}
acp::SessionUpdate::AgentMessageChunk { content } => {
self.push_assistant_content_block(content, false, cx);
@ -793,18 +794,32 @@ impl AcpThread {
Ok(())
}
pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
pub fn push_user_content_block(
&mut self,
message_id: Option<UserMessageId>,
chunk: acp::ContentBlock,
cx: &mut Context<Self>,
) {
let language_registry = self.project.read(cx).languages().clone();
let entries_len = self.entries.len();
if let Some(last_entry) = self.entries.last_mut()
&& let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
&& let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry
{
*id = message_id.or(id.take());
content.append(chunk, &language_registry, cx);
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
let idx = entries_len - 1;
cx.emit(AcpThreadEvent::EntryUpdated(idx));
} else {
let content = ContentBlock::new(chunk, &language_registry, cx);
self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
self.push_entry(
AgentThreadEntry::UserMessage(UserMessage {
id: message_id,
content,
checkpoint: None,
}),
cx,
);
}
}
@ -819,7 +834,8 @@ impl AcpThread {
if let Some(last_entry) = self.entries.last_mut()
&& let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
{
cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
let idx = entries_len - 1;
cx.emit(AcpThreadEvent::EntryUpdated(idx));
match (chunks.last_mut(), is_thought) {
(Some(AssistantMessageChunk::Message { block }), false)
| (Some(AssistantMessageChunk::Thought { block }), true) => {
@ -1118,69 +1134,113 @@ impl AcpThread {
self.project.read(cx).languages().clone(),
cx,
);
let git_store = self.project.read(cx).git_store().clone();
let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
let message_id = if self
.connection
.session_editor(&self.session_id, cx)
.is_some()
{
Some(UserMessageId::new())
} else {
None
};
self.push_entry(
AgentThreadEntry::UserMessage(UserMessage { content: block }),
AgentThreadEntry::UserMessage(UserMessage {
id: message_id.clone(),
content: block,
checkpoint: None,
}),
cx,
);
self.clear_completed_plan_entries(cx);
let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
let (tx, rx) = oneshot::channel();
let cancel_task = self.cancel(cx);
let request = acp::PromptRequest {
prompt: message,
session_id: self.session_id.clone(),
};
self.send_task = Some(cx.spawn(async move |this, cx| {
async {
self.send_task = Some(cx.spawn({
let message_id = message_id.clone();
async move |this, cx| {
cancel_task.await;
let result = this
.update(cx, |this, cx| {
this.connection.prompt(
acp::PromptRequest {
prompt: message,
session_id: this.session_id.clone(),
},
cx,
)
})?
.await;
tx.send(result).log_err();
anyhow::Ok(())
old_checkpoint_tx.send(old_checkpoint.await).ok();
if let Ok(result) = this.update(cx, |this, cx| {
this.connection.prompt(message_id, request, cx)
}) {
tx.send(result.await).log_err();
}
}
.await
.log_err();
}));
cx.spawn(async move |this, cx| match rx.await {
Ok(Err(e)) => {
this.update(cx, |this, cx| {
this.send_task.take();
cx.emit(AcpThreadEvent::Error)
})
cx.spawn(async move |this, cx| {
let old_checkpoint = old_checkpoint_rx
.await
.map_err(|_| anyhow!("send canceled"))
.flatten()
.context("failed to get old checkpoint")
.log_err();
Err(e)?
}
result => {
let cancelled = matches!(
result,
Ok(Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled
}))
);
// We only take the task if the current prompt wasn't cancelled.
//
// This prompt may have been cancelled because another one was sent
// while it was still generating. In these cases, dropping `send_task`
// would cause the next generation to be cancelled.
if !cancelled {
this.update(cx, |this, _cx| this.send_task.take()).ok();
}
let response = rx.await;
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
let new_checkpoint = git_store
.update(cx, |git, cx| git.checkpoint(cx))?
.await
.context("failed to get new checkpoint")
.log_err();
Ok(())
if let Some(new_checkpoint) = new_checkpoint {
let equal = git_store
.update(cx, |git, cx| {
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
})?
.await
.unwrap_or(true);
if !equal {
this.update(cx, |this, cx| {
if let Some((ix, message)) = this.user_message_mut(&message_id) {
message.checkpoint = Some(old_checkpoint);
cx.emit(AcpThreadEvent::EntryUpdated(ix));
}
})?;
}
}
}
this.update(cx, |this, cx| {
match response {
Ok(Err(e)) => {
this.send_task.take();
cx.emit(AcpThreadEvent::Error);
Err(e)
}
result => {
let cancelled = matches!(
result,
Ok(Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled
}))
);
// We only take the task if the current prompt wasn't cancelled.
//
// This prompt may have been cancelled because another one was sent
// while it was still generating. In these cases, dropping `send_task`
// would cause the next generation to be cancelled.
if !cancelled {
this.send_task.take();
}
cx.emit(AcpThreadEvent::Stopped);
Ok(())
}
}
})?
})
.boxed()
}
@ -1212,6 +1272,66 @@ impl AcpThread {
cx.foreground_executor().spawn(send_task)
}
/// Rewinds this thread to just before the user message identified by `id`,
/// removing that message and every later entry, and restoring the git
/// checkpoint recorded on that message (if it has one).
///
/// The spawned task runs these steps in order: restore the checkpoint, ask the
/// connection's session editor to truncate at `id`, then drop the local entries
/// and emit [`AcpThreadEvent::EntriesRemoved`] with the removed index range.
///
/// # Errors
///
/// The returned task fails with "not supported" when the connection provides no
/// session editor for this session, with "message not found" when no user
/// message carries `id`, or with any error propagated from restoring the
/// checkpoint, truncating the session, or updating the entities involved.
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
    let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
        return Task::ready(Err(anyhow!("not supported")));
    };
    let Some(message) = self.user_message(&id) else {
        return Task::ready(Err(anyhow!("message not found")));
    };

    let checkpoint = message.checkpoint.clone();
    let git_store = self.project.read(cx).git_store().clone();

    cx.spawn(async move |this, cx| {
        // Restore files first; a failure here returns before the session is truncated.
        if let Some(checkpoint) = checkpoint {
            git_store
                .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                .await?;
        }

        cx.update(|cx| session_editor.truncate(id.clone(), cx))?
            .await?;

        this.update(cx, |this, cx| {
            // The index is resolved here, after the awaits, rather than captured up front.
            let Some((ix, _)) = this.user_message_mut(&id) else {
                return;
            };
            let removed = ix..this.entries.len();
            this.entries.truncate(ix);
            cx.emit(AcpThreadEvent::EntriesRemoved(removed));
        })
    })
}
/// Returns the first user message in `entries` whose id equals `id`, or `None`
/// when no user message carries that id.
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
    self.entries.iter().find_map(|entry| match entry {
        AgentThreadEntry::UserMessage(message) if message.id.as_ref() == Some(id) => {
            Some(message)
        }
        _ => None,
    })
}
/// Returns the position in `entries` of the first user message whose id equals
/// `id`, together with a mutable reference to it, or `None` when there is no
/// such message.
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
    self.entries
        .iter_mut()
        .enumerate()
        .find_map(|(ix, entry)| match entry {
            AgentThreadEntry::UserMessage(message) if message.id.as_ref() == Some(id) => {
                Some((ix, message))
            }
            _ => None,
        })
}
pub fn read_text_file(
&self,
path: PathBuf,
@ -1414,13 +1534,18 @@ mod tests {
use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
use project::FakeFs;
use project::{FakeFs, Fs};
use rand::Rng as _;
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
use std::{
cell::RefCell,
path::Path,
rc::Rc,
sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
time::Duration,
};
use util::path;
fn init_test(cx: &mut TestAppContext) {
@ -1452,6 +1577,7 @@ mod tests {
// Test creating a new user message
thread.update(cx, |thread, cx| {
thread.push_user_content_block(
None,
acp::ContentBlock::Text(acp::TextContent {
annotations: None,
text: "Hello, ".to_string(),
@ -1463,6 +1589,7 @@ mod tests {
thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 1);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
assert_eq!(user_msg.id, None);
assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
} else {
panic!("Expected UserMessage");
@ -1470,8 +1597,10 @@ mod tests {
});
// Test appending to existing user message
let message_1_id = UserMessageId::new();
thread.update(cx, |thread, cx| {
thread.push_user_content_block(
Some(message_1_id.clone()),
acp::ContentBlock::Text(acp::TextContent {
annotations: None,
text: "world!".to_string(),
@ -1483,6 +1612,7 @@ mod tests {
thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 1);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
assert_eq!(user_msg.id, Some(message_1_id));
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
} else {
panic!("Expected UserMessage");
@ -1501,8 +1631,10 @@ mod tests {
);
});
let message_2_id = UserMessageId::new();
thread.update(cx, |thread, cx| {
thread.push_user_content_block(
Some(message_2_id.clone()),
acp::ContentBlock::Text(acp::TextContent {
annotations: None,
text: "New user message".to_string(),
@ -1514,6 +1646,7 @@ mod tests {
thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 3);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
assert_eq!(user_msg.id, Some(message_2_id));
assert_eq!(user_msg.content.to_markdown(cx), "New user message");
} else {
panic!("Expected UserMessage at index 2");
@ -1830,6 +1963,180 @@ mod tests {
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
}
// Exercises checkpoint recording in `send` and checkpoint restoration in
// `rewind`, using a fake filesystem and a fake agent connection.
#[gpui::test(iterations = 10)]
async fn test_checkpoints(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(
path!("/test"),
json!({
".git": {}
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
// Shared switches read by the fake agent: whether to touch the filesystem on
// each prompt, and the numeric suffix of the next file to create.
let simulate_changes = Arc::new(AtomicBool::new(true));
let next_filename = Arc::new(AtomicUsize::new(0));
// The fake agent writes a new empty file per prompt while `simulate_changes`
// is set, then replies with the prompt text upper-cased.
let connection = Rc::new(FakeAgentConnection::new().on_user_message({
let simulate_changes = simulate_changes.clone();
let next_filename = next_filename.clone();
let fs = fs.clone();
move |request, thread, mut cx| {
let fs = fs.clone();
let simulate_changes = simulate_changes.clone();
let next_filename = next_filename.clone();
async move {
if simulate_changes.load(SeqCst) {
let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
fs.write(Path::new(&filename), b"").await?;
}
let acp::ContentBlock::Text(content) = &request.prompt[0] else {
panic!("expected text content block");
};
thread.update(&mut cx, |thread, cx| {
thread
.handle_session_update(
acp::SessionUpdate::AgentMessageChunk {
content: content.text.to_uppercase().into(),
},
cx,
)
.unwrap();
})?;
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
}
.boxed_local()
}
}));
let thread = connection
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
.await
.unwrap();
// First prompt: the agent creates file-0, so the user message is expected to
// render with the "(checkpoint)" marker.
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
.await
.unwrap();
thread.read_with(cx, |thread, cx| {
assert_eq!(
thread.to_markdown(cx),
indoc! {"
## User (checkpoint)
Lorem
## Assistant
LOREM
"}
);
});
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
// Second prompt: file-1 is created, so this message gets a checkpoint too.
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
.await
.unwrap();
thread.read_with(cx, |thread, cx| {
assert_eq!(
thread.to_markdown(cx),
indoc! {"
## User (checkpoint)
Lorem
## Assistant
LOREM
## User (checkpoint)
ipsum
## Assistant
IPSUM
"}
);
});
assert_eq!(
fs.files(),
vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
);
// Checkpoint isn't stored when there are no changes.
simulate_changes.store(false, SeqCst);
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
.await
.unwrap();
thread.read_with(cx, |thread, cx| {
assert_eq!(
thread.to_markdown(cx),
indoc! {"
## User (checkpoint)
Lorem
## Assistant
LOREM
## User (checkpoint)
ipsum
## Assistant
IPSUM
## User
dolor
## Assistant
DOLOR
"}
);
});
assert_eq!(
fs.files(),
vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
);
// Rewinding the conversation truncates the history and restores the checkpoint.
// `entries[2]` is the second user message ("ipsum"), so it and everything
// after it should disappear, along with file-1.
thread
.update(cx, |thread, cx| {
let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
panic!("unexpected entries {:?}", thread.entries)
};
thread.rewind(message.id.clone().unwrap(), cx)
})
.await
.unwrap();
thread.read_with(cx, |thread, cx| {
assert_eq!(
thread.to_markdown(cx),
indoc! {"
## User (checkpoint)
Lorem
## Assistant
LOREM
"}
);
});
assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
}
async fn run_until_first_tool_call(
thread: &Entity<AcpThread>,
cx: &mut TestAppContext,
@ -1938,6 +2245,7 @@ mod tests {
fn prompt(
&self,
_id: Option<UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
@ -1966,5 +2274,25 @@ mod tests {
})
.detach();
}
/// Always reports editing support by handing back a fake editor bound to
/// `session_id`.
fn session_editor(
    &self,
    session_id: &acp::SessionId,
    _cx: &mut App,
) -> Option<Rc<dyn AgentSessionEditor>> {
    let editor: Rc<dyn AgentSessionEditor> = Rc::new(FakeAgentSessionEditor {
        _session_id: session_id.clone(),
    });
    Some(editor)
}
}
/// Test double for [`AgentSessionEditor`] returned by the fake connection's
/// `session_editor`.
struct FakeAgentSessionEditor {
// Stored only so the editor is tied to a session; never read.
_session_id: acp::SessionId,
}
impl AgentSessionEditor for FakeAgentSessionEditor {
// No-op: ignores the message id and resolves immediately with `Ok(())`.
fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Ok(()))
}
}
}

View file

@ -1,13 +1,21 @@
use std::{error::Error, fmt, path::Path, rc::Rc};
use crate::AcpThread;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use collections::IndexMap;
use gpui::{AsyncApp, Entity, SharedString, Task};
use project::Project;
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
use crate::AcpThread;
/// Identifier for a user message, stored as a shared string.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UserMessageId(Arc<str>);

impl UserMessageId {
    /// Creates a new id from the string form of a freshly generated version 4 UUID.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}

impl Default for UserMessageId {
    /// Same as [`UserMessageId::new`]: generates a fresh id on every call.
    fn default() -> Self {
        Self::new()
    }
}
pub trait AgentConnection {
fn new_thread(
@ -21,11 +29,23 @@ pub trait AgentConnection {
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
-> Task<Result<acp::PromptResponse>>;
fn prompt(
&self,
user_message_id: Option<UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
/// Returns an [`AgentSessionEditor`] for the given session if this connection
/// supports editing it.
///
/// The default implementation returns [None].
fn session_editor(
&self,
_session_id: &acp::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn AgentSessionEditor>> {
None
}
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
@ -35,6 +55,10 @@ pub trait AgentConnection {
}
}
/// Editing operations on an existing agent session, obtained from
/// `AgentConnection::session_editor`.
pub trait AgentSessionEditor {
/// Truncates the session at the user message identified by `message_id`.
///
/// NOTE(review): presumably the message itself and everything after it are
/// discarded (that is what `AcpThread::rewind` does to its local entries after
/// calling this) — confirm against the concrete implementations.
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
#[derive(Debug)]
pub struct AuthRequired;