acp-native-rewind
This commit is contained in:
parent
1460573dd4
commit
c985789d84
13 changed files with 202 additions and 168 deletions
|
@ -32,10 +32,15 @@ use std::time::{Duration, Instant};
|
|||
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
||||
use ui::App;
|
||||
use util::ResultExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn new_prompt_id() -> acp::PromptId {
|
||||
acp::PromptId(Uuid::new_v4().to_string().into())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UserMessage {
|
||||
pub id: Option<UserMessageId>,
|
||||
pub prompt_id: Option<acp::PromptId>,
|
||||
pub content: ContentBlock,
|
||||
pub chunks: Vec<acp::ContentBlock>,
|
||||
pub checkpoint: Option<Checkpoint>,
|
||||
|
@ -962,7 +967,7 @@ impl AcpThread {
|
|||
|
||||
pub fn push_user_content_block(
|
||||
&mut self,
|
||||
message_id: Option<UserMessageId>,
|
||||
prompt_id: Option<acp::PromptId>,
|
||||
chunk: acp::ContentBlock,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
|
@ -971,13 +976,13 @@ impl AcpThread {
|
|||
|
||||
if let Some(last_entry) = self.entries.last_mut()
|
||||
&& let AgentThreadEntry::UserMessage(UserMessage {
|
||||
id,
|
||||
prompt_id: id,
|
||||
content,
|
||||
chunks,
|
||||
..
|
||||
}) = last_entry
|
||||
{
|
||||
*id = message_id.or(id.take());
|
||||
*id = prompt_id.or(id.take());
|
||||
content.append(chunk.clone(), &language_registry, cx);
|
||||
chunks.push(chunk);
|
||||
let idx = entries_len - 1;
|
||||
|
@ -986,7 +991,7 @@ impl AcpThread {
|
|||
let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
|
||||
self.push_entry(
|
||||
AgentThreadEntry::UserMessage(UserMessage {
|
||||
id: message_id,
|
||||
prompt_id,
|
||||
content,
|
||||
chunks: vec![chunk],
|
||||
checkpoint: None,
|
||||
|
@ -1336,6 +1341,7 @@ impl AcpThread {
|
|||
cx: &mut Context<Self>,
|
||||
) -> BoxFuture<'static, Result<()>> {
|
||||
self.send(
|
||||
new_prompt_id(),
|
||||
vec![acp::ContentBlock::Text(acp::TextContent {
|
||||
text: message.to_string(),
|
||||
annotations: None,
|
||||
|
@ -1346,6 +1352,7 @@ impl AcpThread {
|
|||
|
||||
pub fn send(
|
||||
&mut self,
|
||||
prompt_id: acp::PromptId,
|
||||
message: Vec<acp::ContentBlock>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> BoxFuture<'static, Result<()>> {
|
||||
|
@ -1355,22 +1362,17 @@ impl AcpThread {
|
|||
cx,
|
||||
);
|
||||
let request = acp::PromptRequest {
|
||||
prompt_id: Some(prompt_id.clone()),
|
||||
prompt: message.clone(),
|
||||
session_id: self.session_id.clone(),
|
||||
};
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
|
||||
let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
|
||||
Some(UserMessageId::new())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
self.run_turn(cx, async move |this, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.push_entry(
|
||||
AgentThreadEntry::UserMessage(UserMessage {
|
||||
id: message_id.clone(),
|
||||
prompt_id: Some(prompt_id),
|
||||
content: block,
|
||||
chunks: message,
|
||||
checkpoint: None,
|
||||
|
@ -1392,7 +1394,7 @@ impl AcpThread {
|
|||
show: false,
|
||||
});
|
||||
}
|
||||
this.connection.prompt(message_id, request, cx)
|
||||
this.connection.prompt(request, cx)
|
||||
})?
|
||||
.await
|
||||
})
|
||||
|
@ -1509,8 +1511,8 @@ impl AcpThread {
|
|||
|
||||
/// Rewinds this thread to before the entry at `index`, removing it and all
|
||||
/// subsequent entries while reverting any changes made from that point.
|
||||
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
|
||||
pub fn rewind(&mut self, id: acp::PromptId, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let Some(rewind) = self.connection.rewind(&self.session_id, cx) else {
|
||||
return Task::ready(Err(anyhow!("not supported")));
|
||||
};
|
||||
let Some(message) = self.user_message(&id) else {
|
||||
|
@ -1530,7 +1532,7 @@ impl AcpThread {
|
|||
.await?;
|
||||
}
|
||||
|
||||
cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
|
||||
cx.update(|cx| rewind.rewind(id.clone(), cx))?.await?;
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some((ix, _)) = this.user_message_mut(&id) {
|
||||
let range = ix..this.entries.len();
|
||||
|
@ -1594,10 +1596,10 @@ impl AcpThread {
|
|||
})
|
||||
}
|
||||
|
||||
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
|
||||
fn user_message(&self, id: &acp::PromptId) -> Option<&UserMessage> {
|
||||
self.entries.iter().find_map(|entry| {
|
||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||
if message.id.as_ref() == Some(id) {
|
||||
if message.prompt_id.as_ref() == Some(id) {
|
||||
Some(message)
|
||||
} else {
|
||||
None
|
||||
|
@ -1608,10 +1610,10 @@ impl AcpThread {
|
|||
})
|
||||
}
|
||||
|
||||
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
|
||||
fn user_message_mut(&mut self, id: &acp::PromptId) -> Option<(usize, &mut UserMessage)> {
|
||||
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
|
||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||
if message.id.as_ref() == Some(id) {
|
||||
if message.prompt_id.as_ref() == Some(id) {
|
||||
Some((ix, message))
|
||||
} else {
|
||||
None
|
||||
|
@ -1905,7 +1907,7 @@ mod tests {
|
|||
thread.update(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries.len(), 1);
|
||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
||||
assert_eq!(user_msg.id, None);
|
||||
assert_eq!(user_msg.prompt_id, None);
|
||||
assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
|
||||
} else {
|
||||
panic!("Expected UserMessage");
|
||||
|
@ -1913,7 +1915,7 @@ mod tests {
|
|||
});
|
||||
|
||||
// Test appending to existing user message
|
||||
let message_1_id = UserMessageId::new();
|
||||
let message_1_id = new_prompt_id();
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(
|
||||
Some(message_1_id.clone()),
|
||||
|
@ -1928,7 +1930,7 @@ mod tests {
|
|||
thread.update(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries.len(), 1);
|
||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
|
||||
assert_eq!(user_msg.id, Some(message_1_id));
|
||||
assert_eq!(user_msg.prompt_id, Some(message_1_id));
|
||||
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
|
||||
} else {
|
||||
panic!("Expected UserMessage");
|
||||
|
@ -1947,7 +1949,7 @@ mod tests {
|
|||
);
|
||||
});
|
||||
|
||||
let message_2_id = UserMessageId::new();
|
||||
let message_2_id = new_prompt_id();
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(
|
||||
Some(message_2_id.clone()),
|
||||
|
@ -1962,7 +1964,7 @@ mod tests {
|
|||
thread.update(cx, |thread, cx| {
|
||||
assert_eq!(thread.entries.len(), 3);
|
||||
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
|
||||
assert_eq!(user_msg.id, Some(message_2_id));
|
||||
assert_eq!(user_msg.prompt_id, Some(message_2_id));
|
||||
assert_eq!(user_msg.content.to_markdown(cx), "New user message");
|
||||
} else {
|
||||
panic!("Expected UserMessage at index 2");
|
||||
|
@ -2259,9 +2261,13 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(new_prompt_id(), vec!["Hi".into()], cx)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
|
||||
}
|
||||
|
@ -2320,9 +2326,13 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(new_prompt_id(), vec!["Lorem".into()], cx)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
|
@ -2340,9 +2350,13 @@ mod tests {
|
|||
});
|
||||
assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
|
||||
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(new_prompt_id(), vec!["ipsum".into()], cx)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
|
@ -2376,9 +2390,13 @@ mod tests {
|
|||
|
||||
// Checkpoint isn't stored when there are no changes.
|
||||
simulate_changes.store(false, SeqCst);
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(new_prompt_id(), vec!["dolor".into()], cx)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
|
@ -2424,7 +2442,7 @@ mod tests {
|
|||
let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
|
||||
panic!("unexpected entries {:?}", thread.entries)
|
||||
};
|
||||
thread.rewind(message.id.clone().unwrap(), cx)
|
||||
thread.rewind(message.prompt_id.clone().unwrap(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -2490,9 +2508,13 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(new_prompt_id(), vec!["hello".into()], cx)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
|
@ -2512,9 +2534,13 @@ mod tests {
|
|||
// Simulate refusing the second message, ensuring the conversation gets
|
||||
// truncated to before sending it.
|
||||
refuse_next.store(true, SeqCst);
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(new_prompt_id(), vec!["world".into()], cx)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
|
@ -2653,7 +2679,6 @@ mod tests {
|
|||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||
|
@ -2683,11 +2708,11 @@ mod tests {
|
|||
.detach();
|
||||
}
|
||||
|
||||
fn truncate(
|
||||
fn rewind(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
_cx: &App,
|
||||
) -> Option<Rc<dyn AgentSessionTruncate>> {
|
||||
) -> Option<Rc<dyn AgentSessionRewind>> {
|
||||
Some(Rc::new(FakeAgentSessionEditor {
|
||||
_session_id: session_id.clone(),
|
||||
}))
|
||||
|
@ -2702,8 +2727,8 @@ mod tests {
|
|||
_session_id: acp::SessionId,
|
||||
}
|
||||
|
||||
impl AgentSessionTruncate for FakeAgentSessionEditor {
|
||||
fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
|
||||
impl AgentSessionRewind for FakeAgentSessionEditor {
|
||||
fn rewind(&self, _message_id: acp::PromptId, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue