Compare commits

...
Sign in to create a new pull request.

1 commit

Author SHA1 Message Date
Conrad Irwin
c985789d84 acp-native-rewind 2025-08-25 20:58:03 -06:00
13 changed files with 202 additions and 168 deletions

4
Cargo.lock generated
View file

@ -191,9 +191,7 @@ dependencies = [
[[package]] [[package]]
name = "agent-client-protocol" name = "agent-client-protocol"
version = "0.0.31" version = "0.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "289eb34ee17213dadcca47eedadd386a5e7678094095414e475965d1bcca2860"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-broadcast", "async-broadcast",

View file

@ -426,7 +426,7 @@ zlog_settings = { path = "crates/zlog_settings" }
# External crates # External crates
# #
agent-client-protocol = "0.0.31" agent-client-protocol = { path = "../agent-client-protocol"}
aho-corasick = "1.1" aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14" any_vec = "0.14"

View file

@ -32,10 +32,15 @@ use std::time::{Duration, Instant};
use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
use ui::App; use ui::App;
use util::ResultExt; use util::ResultExt;
use uuid::Uuid;
pub fn new_prompt_id() -> acp::PromptId {
acp::PromptId(Uuid::new_v4().to_string().into())
}
#[derive(Debug)] #[derive(Debug)]
pub struct UserMessage { pub struct UserMessage {
pub id: Option<UserMessageId>, pub prompt_id: Option<acp::PromptId>,
pub content: ContentBlock, pub content: ContentBlock,
pub chunks: Vec<acp::ContentBlock>, pub chunks: Vec<acp::ContentBlock>,
pub checkpoint: Option<Checkpoint>, pub checkpoint: Option<Checkpoint>,
@ -962,7 +967,7 @@ impl AcpThread {
pub fn push_user_content_block( pub fn push_user_content_block(
&mut self, &mut self,
message_id: Option<UserMessageId>, prompt_id: Option<acp::PromptId>,
chunk: acp::ContentBlock, chunk: acp::ContentBlock,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) { ) {
@ -971,13 +976,13 @@ impl AcpThread {
if let Some(last_entry) = self.entries.last_mut() if let Some(last_entry) = self.entries.last_mut()
&& let AgentThreadEntry::UserMessage(UserMessage { && let AgentThreadEntry::UserMessage(UserMessage {
id, prompt_id: id,
content, content,
chunks, chunks,
.. ..
}) = last_entry }) = last_entry
{ {
*id = message_id.or(id.take()); *id = prompt_id.or(id.take());
content.append(chunk.clone(), &language_registry, cx); content.append(chunk.clone(), &language_registry, cx);
chunks.push(chunk); chunks.push(chunk);
let idx = entries_len - 1; let idx = entries_len - 1;
@ -986,7 +991,7 @@ impl AcpThread {
let content = ContentBlock::new(chunk.clone(), &language_registry, cx); let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
self.push_entry( self.push_entry(
AgentThreadEntry::UserMessage(UserMessage { AgentThreadEntry::UserMessage(UserMessage {
id: message_id, prompt_id,
content, content,
chunks: vec![chunk], chunks: vec![chunk],
checkpoint: None, checkpoint: None,
@ -1336,6 +1341,7 @@ impl AcpThread {
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> { ) -> BoxFuture<'static, Result<()>> {
self.send( self.send(
new_prompt_id(),
vec![acp::ContentBlock::Text(acp::TextContent { vec![acp::ContentBlock::Text(acp::TextContent {
text: message.to_string(), text: message.to_string(),
annotations: None, annotations: None,
@ -1346,6 +1352,7 @@ impl AcpThread {
pub fn send( pub fn send(
&mut self, &mut self,
prompt_id: acp::PromptId,
message: Vec<acp::ContentBlock>, message: Vec<acp::ContentBlock>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> BoxFuture<'static, Result<()>> { ) -> BoxFuture<'static, Result<()>> {
@ -1355,22 +1362,17 @@ impl AcpThread {
cx, cx,
); );
let request = acp::PromptRequest { let request = acp::PromptRequest {
prompt_id: Some(prompt_id.clone()),
prompt: message.clone(), prompt: message.clone(),
session_id: self.session_id.clone(), session_id: self.session_id.clone(),
}; };
let git_store = self.project.read(cx).git_store().clone(); let git_store = self.project.read(cx).git_store().clone();
let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
Some(UserMessageId::new())
} else {
None
};
self.run_turn(cx, async move |this, cx| { self.run_turn(cx, async move |this, cx| {
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.push_entry( this.push_entry(
AgentThreadEntry::UserMessage(UserMessage { AgentThreadEntry::UserMessage(UserMessage {
id: message_id.clone(), prompt_id: Some(prompt_id),
content: block, content: block,
chunks: message, chunks: message,
checkpoint: None, checkpoint: None,
@ -1392,7 +1394,7 @@ impl AcpThread {
show: false, show: false,
}); });
} }
this.connection.prompt(message_id, request, cx) this.connection.prompt(request, cx)
})? })?
.await .await
}) })
@ -1509,8 +1511,8 @@ impl AcpThread {
/// Rewinds this thread to before the entry at `index`, removing it and all /// Rewinds this thread to before the entry at `index`, removing it and all
/// subsequent entries while reverting any changes made from that point. /// subsequent entries while reverting any changes made from that point.
pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> { pub fn rewind(&mut self, id: acp::PromptId, cx: &mut Context<Self>) -> Task<Result<()>> {
let Some(truncate) = self.connection.truncate(&self.session_id, cx) else { let Some(rewind) = self.connection.rewind(&self.session_id, cx) else {
return Task::ready(Err(anyhow!("not supported"))); return Task::ready(Err(anyhow!("not supported")));
}; };
let Some(message) = self.user_message(&id) else { let Some(message) = self.user_message(&id) else {
@ -1530,7 +1532,7 @@ impl AcpThread {
.await?; .await?;
} }
cx.update(|cx| truncate.run(id.clone(), cx))?.await?; cx.update(|cx| rewind.rewind(id.clone(), cx))?.await?;
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
if let Some((ix, _)) = this.user_message_mut(&id) { if let Some((ix, _)) = this.user_message_mut(&id) {
let range = ix..this.entries.len(); let range = ix..this.entries.len();
@ -1594,10 +1596,10 @@ impl AcpThread {
}) })
} }
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { fn user_message(&self, id: &acp::PromptId) -> Option<&UserMessage> {
self.entries.iter().find_map(|entry| { self.entries.iter().find_map(|entry| {
if let AgentThreadEntry::UserMessage(message) = entry { if let AgentThreadEntry::UserMessage(message) = entry {
if message.id.as_ref() == Some(id) { if message.prompt_id.as_ref() == Some(id) {
Some(message) Some(message)
} else { } else {
None None
@ -1608,10 +1610,10 @@ impl AcpThread {
}) })
} }
fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> { fn user_message_mut(&mut self, id: &acp::PromptId) -> Option<(usize, &mut UserMessage)> {
self.entries.iter_mut().enumerate().find_map(|(ix, entry)| { self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
if let AgentThreadEntry::UserMessage(message) = entry { if let AgentThreadEntry::UserMessage(message) = entry {
if message.id.as_ref() == Some(id) { if message.prompt_id.as_ref() == Some(id) {
Some((ix, message)) Some((ix, message))
} else { } else {
None None
@ -1905,7 +1907,7 @@ mod tests {
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 1); assert_eq!(thread.entries.len(), 1);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
assert_eq!(user_msg.id, None); assert_eq!(user_msg.prompt_id, None);
assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
} else { } else {
panic!("Expected UserMessage"); panic!("Expected UserMessage");
@ -1913,7 +1915,7 @@ mod tests {
}); });
// Test appending to existing user message // Test appending to existing user message
let message_1_id = UserMessageId::new(); let message_1_id = new_prompt_id();
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.push_user_content_block( thread.push_user_content_block(
Some(message_1_id.clone()), Some(message_1_id.clone()),
@ -1928,7 +1930,7 @@ mod tests {
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 1); assert_eq!(thread.entries.len(), 1);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
assert_eq!(user_msg.id, Some(message_1_id)); assert_eq!(user_msg.prompt_id, Some(message_1_id));
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
} else { } else {
panic!("Expected UserMessage"); panic!("Expected UserMessage");
@ -1947,7 +1949,7 @@ mod tests {
); );
}); });
let message_2_id = UserMessageId::new(); let message_2_id = new_prompt_id();
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
thread.push_user_content_block( thread.push_user_content_block(
Some(message_2_id.clone()), Some(message_2_id.clone()),
@ -1962,7 +1964,7 @@ mod tests {
thread.update(cx, |thread, cx| { thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 3); assert_eq!(thread.entries.len(), 3);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
assert_eq!(user_msg.id, Some(message_2_id)); assert_eq!(user_msg.prompt_id, Some(message_2_id));
assert_eq!(user_msg.content.to_markdown(cx), "New user message"); assert_eq!(user_msg.content.to_markdown(cx), "New user message");
} else { } else {
panic!("Expected UserMessage at index 2"); panic!("Expected UserMessage at index 2");
@ -2259,9 +2261,13 @@ mod tests {
.await .await
.unwrap(); .unwrap();
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) cx.update(|cx| {
.await thread.update(cx, |thread, cx| {
.unwrap(); thread.send(new_prompt_id(), vec!["Hi".into()], cx)
})
})
.await
.unwrap();
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
} }
@ -2320,9 +2326,13 @@ mod tests {
.await .await
.unwrap(); .unwrap();
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx))) cx.update(|cx| {
.await thread.update(cx, |thread, cx| {
.unwrap(); thread.send(new_prompt_id(), vec!["Lorem".into()], cx)
})
})
.await
.unwrap();
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, cx| {
assert_eq!( assert_eq!(
thread.to_markdown(cx), thread.to_markdown(cx),
@ -2340,9 +2350,13 @@ mod tests {
}); });
assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]); assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx))) cx.update(|cx| {
.await thread.update(cx, |thread, cx| {
.unwrap(); thread.send(new_prompt_id(), vec!["ipsum".into()], cx)
})
})
.await
.unwrap();
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, cx| {
assert_eq!( assert_eq!(
thread.to_markdown(cx), thread.to_markdown(cx),
@ -2376,9 +2390,13 @@ mod tests {
// Checkpoint isn't stored when there are no changes. // Checkpoint isn't stored when there are no changes.
simulate_changes.store(false, SeqCst); simulate_changes.store(false, SeqCst);
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx))) cx.update(|cx| {
.await thread.update(cx, |thread, cx| {
.unwrap(); thread.send(new_prompt_id(), vec!["dolor".into()], cx)
})
})
.await
.unwrap();
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, cx| {
assert_eq!( assert_eq!(
thread.to_markdown(cx), thread.to_markdown(cx),
@ -2424,7 +2442,7 @@ mod tests {
let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else { let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
panic!("unexpected entries {:?}", thread.entries) panic!("unexpected entries {:?}", thread.entries)
}; };
thread.rewind(message.id.clone().unwrap(), cx) thread.rewind(message.prompt_id.clone().unwrap(), cx)
}) })
.await .await
.unwrap(); .unwrap();
@ -2490,9 +2508,13 @@ mod tests {
.await .await
.unwrap(); .unwrap();
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx))) cx.update(|cx| {
.await thread.update(cx, |thread, cx| {
.unwrap(); thread.send(new_prompt_id(), vec!["hello".into()], cx)
})
})
.await
.unwrap();
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, cx| {
assert_eq!( assert_eq!(
thread.to_markdown(cx), thread.to_markdown(cx),
@ -2512,9 +2534,13 @@ mod tests {
// Simulate refusing the second message, ensuring the conversation gets // Simulate refusing the second message, ensuring the conversation gets
// truncated to before sending it. // truncated to before sending it.
refuse_next.store(true, SeqCst); refuse_next.store(true, SeqCst);
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx))) cx.update(|cx| {
.await thread.update(cx, |thread, cx| {
.unwrap(); thread.send(new_prompt_id(), vec!["world".into()], cx)
})
})
.await
.unwrap();
thread.read_with(cx, |thread, cx| { thread.read_with(cx, |thread, cx| {
assert_eq!( assert_eq!(
thread.to_markdown(cx), thread.to_markdown(cx),
@ -2653,7 +2679,6 @@ mod tests {
fn prompt( fn prompt(
&self, &self,
_id: Option<UserMessageId>,
params: acp::PromptRequest, params: acp::PromptRequest,
cx: &mut App, cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> { ) -> Task<gpui::Result<acp::PromptResponse>> {
@ -2683,11 +2708,11 @@ mod tests {
.detach(); .detach();
} }
fn truncate( fn rewind(
&self, &self,
session_id: &acp::SessionId, session_id: &acp::SessionId,
_cx: &App, _cx: &App,
) -> Option<Rc<dyn AgentSessionTruncate>> { ) -> Option<Rc<dyn AgentSessionRewind>> {
Some(Rc::new(FakeAgentSessionEditor { Some(Rc::new(FakeAgentSessionEditor {
_session_id: session_id.clone(), _session_id: session_id.clone(),
})) }))
@ -2702,8 +2727,8 @@ mod tests {
_session_id: acp::SessionId, _session_id: acp::SessionId,
} }
impl AgentSessionTruncate for FakeAgentSessionEditor { impl AgentSessionRewind for FakeAgentSessionEditor {
fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> { fn rewind(&self, _message_id: acp::PromptId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Ok(())) Task::ready(Ok(()))
} }
} }

View file

@ -5,19 +5,8 @@ use collections::IndexMap;
use gpui::{Entity, SharedString, Task}; use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId; use language_model::LanguageModelProviderId;
use project::Project; use project::Project;
use serde::{Deserialize, Serialize}; use std::{any::Any, error::Error, fmt, path::Path, rc::Rc};
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName}; use ui::{App, IconName};
use uuid::Uuid;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct UserMessageId(Arc<str>);
impl UserMessageId {
pub fn new() -> Self {
Self(Uuid::new_v4().to_string().into())
}
}
pub trait AgentConnection { pub trait AgentConnection {
fn new_thread( fn new_thread(
@ -31,12 +20,8 @@ pub trait AgentConnection {
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>; fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt( fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
&self, -> Task<Result<acp::PromptResponse>>;
user_message_id: Option<UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
fn resume( fn resume(
&self, &self,
@ -48,11 +33,11 @@ pub trait AgentConnection {
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
fn truncate( fn rewind(
&self, &self,
_session_id: &acp::SessionId, _session_id: &acp::SessionId,
_cx: &App, _cx: &App,
) -> Option<Rc<dyn AgentSessionTruncate>> { ) -> Option<Rc<dyn AgentSessionRewind>> {
None None
} }
@ -85,8 +70,8 @@ impl dyn AgentConnection {
} }
} }
pub trait AgentSessionTruncate { pub trait AgentSessionRewind {
fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>; fn rewind(&self, message_id: acp::PromptId, cx: &mut App) -> Task<Result<()>>;
} }
pub trait AgentSessionResume { pub trait AgentSessionResume {
@ -362,7 +347,6 @@ mod test_support {
fn prompt( fn prompt(
&self, &self,
_id: Option<UserMessageId>,
params: acp::PromptRequest, params: acp::PromptRequest,
cx: &mut App, cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> { ) -> Task<gpui::Result<acp::PromptResponse>> {
@ -432,11 +416,11 @@ mod test_support {
} }
} }
fn truncate( fn rewind(
&self, &self,
_session_id: &agent_client_protocol::SessionId, _session_id: &agent_client_protocol::SessionId,
_cx: &App, _cx: &App,
) -> Option<Rc<dyn AgentSessionTruncate>> { ) -> Option<Rc<dyn AgentSessionRewind>> {
Some(Rc::new(StubAgentSessionEditor)) Some(Rc::new(StubAgentSessionEditor))
} }
@ -447,8 +431,8 @@ mod test_support {
struct StubAgentSessionEditor; struct StubAgentSessionEditor;
impl AgentSessionTruncate for StubAgentSessionEditor { impl AgentSessionRewind for StubAgentSessionEditor {
fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> { fn rewind(&self, _: acp::PromptId, _: &mut App) -> Task<Result<()>> {
Task::ready(Ok(())) Task::ready(Ok(()))
} }
} }

View file

@ -905,11 +905,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn prompt( fn prompt(
&self, &self,
id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest, params: acp::PromptRequest,
cx: &mut App, cx: &mut App,
) -> Task<Result<acp::PromptResponse>> { ) -> Task<Result<acp::PromptResponse>> {
let id = id.expect("UserMessageId is required"); let id = params.prompt_id.expect("UserMessageId is required");
let session_id = params.session_id.clone(); let session_id = params.session_id.clone();
log::info!("Received prompt request for session: {}", session_id); log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len()); log::debug!("Prompt blocks count: {}", params.prompt.len());
@ -948,11 +947,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}); });
} }
fn truncate( fn rewind(
&self, &self,
session_id: &agent_client_protocol::SessionId, session_id: &agent_client_protocol::SessionId,
cx: &App, cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> { ) -> Option<Rc<dyn acp_thread::AgentSessionRewind>> {
self.0.read_with(cx, |agent, _cx| { self.0.read_with(cx, |agent, _cx| {
agent.sessions.get(session_id).map(|session| { agent.sessions.get(session_id).map(|session| {
Rc::new(NativeAgentSessionEditor { Rc::new(NativeAgentSessionEditor {
@ -1009,10 +1008,10 @@ struct NativeAgentSessionEditor {
acp_thread: WeakEntity<AcpThread>, acp_thread: WeakEntity<AcpThread>,
} }
impl acp_thread::AgentSessionTruncate for NativeAgentSessionEditor { impl acp_thread::AgentSessionRewind for NativeAgentSessionEditor {
fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> { fn rewind(&self, message_id: acp::PromptId, cx: &mut App) -> Task<Result<()>> {
match self.thread.update(cx, |thread, cx| { match self.thread.update(cx, |thread, cx| {
thread.truncate(message_id.clone(), cx)?; thread.rewind(message_id.clone(), cx)?;
Ok(thread.latest_token_usage()) Ok(thread.latest_token_usage())
}) { }) {
Ok(usage) => { Ok(usage) => {
@ -1065,6 +1064,7 @@ mod tests {
use super::*; use super::*;
use acp_thread::{ use acp_thread::{
AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri, AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
new_prompt_id,
}; };
use fs::FakeFs; use fs::FakeFs;
use gpui::TestAppContext; use gpui::TestAppContext;
@ -1311,6 +1311,7 @@ mod tests {
let send = acp_thread.update(cx, |thread, cx| { let send = acp_thread.update(cx, |thread, cx| {
thread.send( thread.send(
new_prompt_id(),
vec![ vec![
"What does ".into(), "What does ".into(),
acp::ContentBlock::ResourceLink(acp::ResourceLink { acp::ContentBlock::ResourceLink(acp::ResourceLink {

View file

@ -1,5 +1,5 @@
use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent};
use acp_thread::UserMessageId; use acp_thread::new_prompt_id;
use agent::{thread::DetailedSummaryState, thread_store}; use agent::{thread::DetailedSummaryState, thread_store};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, CompletionMode}; use agent_settings::{AgentProfileId, CompletionMode};
@ -43,7 +43,7 @@ pub struct DbThread {
#[serde(default)] #[serde(default)]
pub cumulative_token_usage: language_model::TokenUsage, pub cumulative_token_usage: language_model::TokenUsage,
#[serde(default)] #[serde(default)]
pub request_token_usage: HashMap<acp_thread::UserMessageId, language_model::TokenUsage>, pub request_token_usage: HashMap<acp::PromptId, language_model::TokenUsage>,
#[serde(default)] #[serde(default)]
pub model: Option<DbLanguageModel>, pub model: Option<DbLanguageModel>,
#[serde(default)] #[serde(default)]
@ -97,7 +97,7 @@ impl DbThread {
content.push(UserMessageContent::Text(msg.context)); content.push(UserMessageContent::Text(msg.context));
} }
let id = UserMessageId::new(); let id = new_prompt_id();
last_user_message_id = Some(id.clone()); last_user_message_id = Some(id.clone());
crate::Message::User(UserMessage { crate::Message::User(UserMessage {

View file

@ -1,5 +1,5 @@
use super::*; use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId}; use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, new_prompt_id};
use agent_client_protocol::{self as acp}; use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId; use agent_settings::AgentProfileId;
use anyhow::Result; use anyhow::Result;
@ -48,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) {
let events = thread let events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx) thread.send(new_prompt_id(), ["Testing: Reply with 'Hello'"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -79,7 +79,7 @@ async fn test_thinking(cx: &mut TestAppContext) {
let events = thread let events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send( thread.send(
UserMessageId::new(), new_prompt_id(),
[indoc! {" [indoc! {"
Testing: Testing:
@ -130,9 +130,7 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
}); });
thread.update(cx, |thread, _| thread.add_tool(EchoTool)); thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx))
thread.send(UserMessageId::new(), ["abc"], cx)
})
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions(); let mut pending_completions = fake_model.pending_completions();
@ -168,7 +166,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
// Send initial user message and verify it's cached // Send initial user message and verify it's cached
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Message 1"], cx) thread.send(new_prompt_id(), ["Message 1"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -191,7 +189,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
// Send another user message and verify only the latest is cached // Send another user message and verify only the latest is cached
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Message 2"], cx) thread.send(new_prompt_id(), ["Message 2"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -227,7 +225,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) {
thread.update(cx, |thread, _| thread.add_tool(EchoTool)); thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Use the echo tool"], cx) thread.send(new_prompt_id(), ["Use the echo tool"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -304,7 +302,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.add_tool(EchoTool); thread.add_tool(EchoTool);
thread.send( thread.send(
UserMessageId::new(), new_prompt_id(),
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."], ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
cx, cx,
) )
@ -320,7 +318,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
thread.remove_tool(&EchoTool::name()); thread.remove_tool(&EchoTool::name());
thread.add_tool(DelayTool); thread.add_tool(DelayTool);
thread.send( thread.send(
UserMessageId::new(), new_prompt_id(),
[ [
"Now call the delay tool with 200ms.", "Now call the delay tool with 200ms.",
"When the timer goes off, then you echo the output of the tool.", "When the timer goes off, then you echo the output of the tool.",
@ -363,7 +361,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.add_tool(WordListTool); thread.add_tool(WordListTool);
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) thread.send(new_prompt_id(), ["Test the word_list tool."], cx)
}) })
.unwrap(); .unwrap();
@ -414,7 +412,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission); thread.add_tool(ToolRequiringPermission);
thread.send(UserMessageId::new(), ["abc"], cx) thread.send(new_prompt_id(), ["abc"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -544,9 +542,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
let fake_model = model.as_fake(); let fake_model = model.as_fake();
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx))
thread.send(UserMessageId::new(), ["abc"], cx)
})
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@ -575,7 +571,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
let events = thread let events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.add_tool(EchoTool); thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["abc"], cx) thread.send(new_prompt_id(), ["abc"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -684,7 +680,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
let events = thread let events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.add_tool(EchoTool); thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["abc"], cx) thread.send(new_prompt_id(), ["abc"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -718,7 +714,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), vec!["ghi"], cx) thread.send(new_prompt_id(), vec!["ghi"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -818,7 +814,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.add_tool(DelayTool); thread.add_tool(DelayTool);
thread.send( thread.send(
UserMessageId::new(), new_prompt_id(),
[ [
"Call the delay tool twice in the same message.", "Call the delay tool twice in the same message.",
"Once with 100ms. Once with 300ms.", "Once with 100ms. Once with 300ms.",
@ -898,7 +894,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_profile(AgentProfileId("test-1".into())); thread.set_profile(AgentProfileId("test-1".into()));
thread.send(UserMessageId::new(), ["test"], cx) thread.send(new_prompt_id(), ["test"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -918,7 +914,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_profile(AgentProfileId("test-2".into())); thread.set_profile(AgentProfileId("test-2".into()));
thread.send(UserMessageId::new(), ["test2"], cx) thread.send(new_prompt_id(), ["test2"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -986,7 +982,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
); );
let events = thread.update(cx, |thread, cx| { let events = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hey"], cx).unwrap() thread.send(new_prompt_id(), ["Hey"], cx).unwrap()
}); });
cx.run_until_parked(); cx.run_until_parked();
@ -1028,7 +1024,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) {
// Send again after adding the echo tool, ensuring the name collision is resolved. // Send again after adding the echo tool, ensuring the name collision is resolved.
let events = thread.update(cx, |thread, cx| { let events = thread.update(cx, |thread, cx| {
thread.add_tool(EchoTool); thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["Go"], cx).unwrap() thread.send(new_prompt_id(), ["Go"], cx).unwrap()
}); });
cx.run_until_parked(); cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap(); let completion = fake_model.pending_completions().pop().unwrap();
@ -1235,9 +1231,7 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
); );
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Go"], cx))
thread.send(UserMessageId::new(), ["Go"], cx)
})
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
let completion = fake_model.pending_completions().pop().unwrap(); let completion = fake_model.pending_completions().pop().unwrap();
@ -1271,7 +1265,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
thread.add_tool(InfiniteTool); thread.add_tool(InfiniteTool);
thread.add_tool(EchoTool); thread.add_tool(EchoTool);
thread.send( thread.send(
UserMessageId::new(), new_prompt_id(),
["Call the echo tool, then call the infinite tool, then explain their output"], ["Call the echo tool, then call the infinite tool, then explain their output"],
cx, cx,
) )
@ -1327,7 +1321,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
let events = thread let events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send( thread.send(
UserMessageId::new(), new_prompt_id(),
["Testing: reply with 'Hello' then stop."], ["Testing: reply with 'Hello' then stop."],
cx, cx,
) )
@ -1353,7 +1347,7 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
let events_1 = thread let events_1 = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello 1"], cx) thread.send(new_prompt_id(), ["Hello 1"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1362,7 +1356,7 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
let events_2 = thread let events_2 = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello 2"], cx) thread.send(new_prompt_id(), ["Hello 2"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1384,7 +1378,7 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
let events_1 = thread let events_1 = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello 1"], cx) thread.send(new_prompt_id(), ["Hello 1"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1396,7 +1390,7 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
let events_2 = thread let events_2 = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello 2"], cx) thread.send(new_prompt_id(), ["Hello 2"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1416,9 +1410,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
let fake_model = model.as_fake(); let fake_model = model.as_fake();
let events = thread let events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx))
thread.send(UserMessageId::new(), ["Hello"], cx)
})
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
@ -1464,7 +1456,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake(); let fake_model = model.as_fake();
let message_id = UserMessageId::new(); let message_id = new_prompt_id();
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(message_id.clone(), ["Hello"], cx) thread.send(message_id.clone(), ["Hello"], cx)
@ -1516,7 +1508,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
}); });
thread thread
.update(cx, |thread, cx| thread.truncate(message_id, cx)) .update(cx, |thread, cx| thread.rewind(message_id, cx))
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
thread.read_with(cx, |thread, _| { thread.read_with(cx, |thread, _| {
@ -1526,9 +1518,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
// Ensure we can still send a new message after truncation. // Ensure we can still send a new message after truncation.
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hi"], cx))
thread.send(UserMessageId::new(), ["Hi"], cx)
})
.unwrap(); .unwrap();
thread.update(cx, |thread, _cx| { thread.update(cx, |thread, _cx| {
assert_eq!( assert_eq!(
@ -1582,7 +1572,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Message 1"], cx) thread.send(new_prompt_id(), ["Message 1"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1625,7 +1615,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
assert_first_message_state(cx); assert_first_message_state(cx);
let second_message_id = UserMessageId::new(); let second_message_id = new_prompt_id();
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(second_message_id.clone(), ["Message 2"], cx) thread.send(second_message_id.clone(), ["Message 2"], cx)
@ -1677,7 +1667,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
}); });
thread thread
.update(cx, |thread, cx| thread.truncate(second_message_id, cx)) .update(cx, |thread, cx| thread.rewind(second_message_id, cx))
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1696,9 +1686,7 @@ async fn test_title_generation(cx: &mut TestAppContext) {
}); });
let send = thread let send = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx))
thread.send(UserMessageId::new(), ["Hello"], cx)
})
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1719,7 +1707,7 @@ async fn test_title_generation(cx: &mut TestAppContext) {
// Send another message, ensuring no title is generated this time. // Send another message, ensuring no title is generated this time.
let send = thread let send = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello again"], cx) thread.send(new_prompt_id(), ["Hello again"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1740,7 +1728,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission); thread.add_tool(ToolRequiringPermission);
thread.add_tool(EchoTool); thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["Hey!"], cx) thread.send(new_prompt_id(), ["Hey!"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -1892,7 +1880,9 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let model = model.as_fake(); let model = model.as_fake();
assert_eq!(model.id().0, "fake", "should return default model"); assert_eq!(model.id().0, "fake", "should return default model");
let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx)); let request = acp_thread.update(cx, |thread, cx| {
thread.send(new_prompt_id(), vec!["abc".into()], cx)
});
cx.run_until_parked(); cx.run_until_parked();
model.send_last_completion_stream_text_chunk("def"); model.send_last_completion_stream_text_chunk("def");
cx.run_until_parked(); cx.run_until_parked();
@ -1922,8 +1912,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let result = cx let result = cx
.update(|cx| { .update(|cx| {
connection.prompt( connection.prompt(
Some(acp_thread::UserMessageId::new()),
acp::PromptRequest { acp::PromptRequest {
prompt_id: Some(new_prompt_id()),
session_id: session_id.clone(), session_id: session_id.clone(),
prompt: vec!["ghi".into()], prompt: vec!["ghi".into()],
}, },
@ -1946,9 +1936,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
let fake_model = model.as_fake(); let fake_model = model.as_fake();
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Think"], cx))
thread.send(UserMessageId::new(), ["Think"], cx)
})
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -2049,7 +2037,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx) thread.send(new_prompt_id(), ["Hello!"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -2093,7 +2081,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx) thread.send(new_prompt_id(), ["Hello!"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -2159,7 +2147,7 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.add_tool(EchoTool); thread.add_tool(EchoTool);
thread.send(UserMessageId::new(), ["Call the echo tool!"], cx) thread.send(new_prompt_id(), ["Call the echo tool!"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
@ -2234,7 +2222,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
let mut events = thread let mut events = thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
thread.send(UserMessageId::new(), ["Hello!"], cx) thread.send(new_prompt_id(), ["Hello!"], cx)
}) })
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();

View file

@ -4,7 +4,7 @@ use crate::{
ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate,
Template, Templates, TerminalTool, ThinkingTool, WebSearchTool, Template, Templates, TerminalTool, ThinkingTool, WebSearchTool,
}; };
use acp_thread::{MentionUri, UserMessageId}; use acp_thread::MentionUri;
use action_log::ActionLog; use action_log::ActionLog;
use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot}; use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot};
use agent_client_protocol as acp; use agent_client_protocol as acp;
@ -137,7 +137,7 @@ impl Message {
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UserMessage { pub struct UserMessage {
pub id: UserMessageId, pub id: acp::PromptId,
pub content: Vec<UserMessageContent>, pub content: Vec<UserMessageContent>,
} }
@ -564,7 +564,7 @@ pub struct Thread {
pending_message: Option<AgentMessage>, pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>, tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool, tool_use_limit_reached: bool,
request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>, request_token_usage: HashMap<acp::PromptId, language_model::TokenUsage>,
#[allow(unused)] #[allow(unused)]
cumulative_token_usage: TokenUsage, cumulative_token_usage: TokenUsage,
#[allow(unused)] #[allow(unused)]
@ -1070,7 +1070,7 @@ impl Thread {
cx.notify(); cx.notify();
} }
pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> { pub fn rewind(&mut self, message_id: acp::PromptId, cx: &mut Context<Self>) -> Result<()> {
self.cancel(cx); self.cancel(cx);
let Some(position) = self.messages.iter().position( let Some(position) = self.messages.iter().position(
|msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
@ -1118,7 +1118,7 @@ impl Thread {
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
pub fn send<T>( pub fn send<T>(
&mut self, &mut self,
id: UserMessageId, id: acp::PromptId,
content: impl IntoIterator<Item = T>, content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>

View file

@ -29,6 +29,7 @@ pub struct AcpConnection {
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>, sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>, auth_methods: Vec<acp::AuthMethod>,
prompt_capabilities: acp::PromptCapabilities, prompt_capabilities: acp::PromptCapabilities,
supports_rewind: bool,
_io_task: Task<Result<()>>, _io_task: Task<Result<()>>,
} }
@ -147,6 +148,7 @@ impl AcpConnection {
server_name, server_name,
sessions, sessions,
prompt_capabilities: response.agent_capabilities.prompt_capabilities, prompt_capabilities: response.agent_capabilities.prompt_capabilities,
supports_rewind: response.agent_capabilities.rewind_session,
_io_task: io_task, _io_task: io_task,
}) })
} }
@ -225,9 +227,22 @@ impl AgentConnection for AcpConnection {
}) })
} }
fn rewind(
&self,
session_id: &agent_client_protocol::SessionId,
_cx: &App,
) -> Option<Rc<dyn acp_thread::AgentSessionRewind>> {
if !self.supports_rewind {
return None;
}
Some(Rc::new(AcpRewinder {
connection: self.connection.clone(),
session_id: session_id.clone(),
}) as _)
}
fn prompt( fn prompt(
&self, &self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest, params: acp::PromptRequest,
cx: &mut App, cx: &mut App,
) -> Task<Result<acp::PromptResponse>> { ) -> Task<Result<acp::PromptResponse>> {
@ -302,6 +317,25 @@ impl AgentConnection for AcpConnection {
} }
} }
struct AcpRewinder {
connection: Rc<acp::ClientSideConnection>,
session_id: acp::SessionId,
}
impl acp_thread::AgentSessionRewind for AcpRewinder {
fn rewind(&self, prompt_id: acp::PromptId, cx: &mut App) -> Task<Result<()>> {
let conn = self.connection.clone();
let params = acp::RewindRequest {
session_id: self.session_id.clone(),
prompt_id,
};
cx.foreground_executor().spawn(async move {
conn.rewind(params).await?;
Ok(())
})
}
}
struct ClientDelegate { struct ClientDelegate {
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>, sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
cx: AsyncApp, cx: AsyncApp,

View file

@ -294,7 +294,6 @@ impl AgentConnection for ClaudeAgentConnection {
fn prompt( fn prompt(
&self, &self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest, params: acp::PromptRequest,
cx: &mut App, cx: &mut App,
) -> Task<Result<acp::PromptResponse>> { ) -> Task<Result<acp::PromptResponse>> {

View file

@ -1,5 +1,5 @@
use crate::AgentServer; use crate::AgentServer;
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus, new_prompt_id};
use agent_client_protocol as acp; use agent_client_protocol as acp;
use futures::{FutureExt, StreamExt, channel::mpsc, select}; use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::{AppContext, Entity, TestAppContext}; use gpui::{AppContext, Entity, TestAppContext};
@ -77,6 +77,7 @@ where
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.send( thread.send(
new_prompt_id(),
vec![ vec![
acp::ContentBlock::Text(acp::TextContent { acp::ContentBlock::Text(acp::TextContent {
text: "Read the file ".into(), text: "Read the file ".into(),

View file

@ -67,7 +67,7 @@ impl EntryViewState {
match thread_entry { match thread_entry {
AgentThreadEntry::UserMessage(message) => { AgentThreadEntry::UserMessage(message) => {
let has_id = message.id.is_some(); let has_id = message.prompt_id.is_some();
let chunks = message.chunks.clone(); let chunks = message.chunks.clone();
if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) { if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
if !editor.focus_handle(cx).is_focused(window) { if !editor.focus_handle(cx).is_focused(window) {

View file

@ -1,7 +1,7 @@
use acp_thread::{ use acp_thread::{
AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk,
AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent, AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent,
ToolCallStatus, UserMessageId, ToolCallStatus, new_prompt_id,
}; };
use acp_thread::{AgentConnection, Plan}; use acp_thread::{AgentConnection, Plan};
use action_log::ActionLog; use action_log::ActionLog;
@ -791,7 +791,7 @@ impl AcpThreadView {
if let Some(thread) = self.thread() if let Some(thread) = self.thread()
&& let Some(AgentThreadEntry::UserMessage(user_message)) = && let Some(AgentThreadEntry::UserMessage(user_message)) =
thread.read(cx).entries().get(event.entry_index) thread.read(cx).entries().get(event.entry_index)
&& user_message.id.is_some() && user_message.prompt_id.is_some()
{ {
self.editing_message = Some(event.entry_index); self.editing_message = Some(event.entry_index);
cx.notify(); cx.notify();
@ -801,7 +801,7 @@ impl AcpThreadView {
if let Some(thread) = self.thread() if let Some(thread) = self.thread()
&& let Some(AgentThreadEntry::UserMessage(user_message)) = && let Some(AgentThreadEntry::UserMessage(user_message)) =
thread.read(cx).entries().get(event.entry_index) thread.read(cx).entries().get(event.entry_index)
&& user_message.id.is_some() && user_message.prompt_id.is_some()
{ {
if editor.read(cx).text(cx).as_str() == user_message.content.to_markdown(cx) { if editor.read(cx).text(cx).as_str() == user_message.content.to_markdown(cx) {
self.editing_message = None; self.editing_message = None;
@ -942,7 +942,7 @@ impl AcpThreadView {
telemetry::event!("Agent Message Sent", agent = agent_telemetry_id); telemetry::event!("Agent Message Sent", agent = agent_telemetry_id);
thread.send(contents, cx) thread.send(new_prompt_id(), contents, cx)
})?; })?;
send.await send.await
}); });
@ -1011,7 +1011,12 @@ impl AcpThreadView {
} }
let Some(user_message_id) = thread.update(cx, |thread, _| { let Some(user_message_id) = thread.update(cx, |thread, _| {
thread.entries().get(entry_ix)?.user_message()?.id.clone() thread
.entries()
.get(entry_ix)?
.user_message()?
.prompt_id
.clone()
}) else { }) else {
return; return;
}; };
@ -1316,7 +1321,7 @@ impl AcpThreadView {
cx.notify(); cx.notify();
} }
fn rewind(&mut self, message_id: &UserMessageId, cx: &mut Context<Self>) { fn rewind(&mut self, message_id: &acp::PromptId, cx: &mut Context<Self>) {
let Some(thread) = self.thread() else { let Some(thread) = self.thread() else {
return; return;
}; };
@ -1379,7 +1384,7 @@ impl AcpThreadView {
.gap_1p5() .gap_1p5()
.w_full() .w_full()
.children(rules_item) .children(rules_item)
.children(message.id.clone().and_then(|message_id| { .children(message.prompt_id.clone().and_then(|message_id| {
message.checkpoint.as_ref()?.show.then(|| { message.checkpoint.as_ref()?.show.then(|| {
h_flex() h_flex()
.px_3() .px_3()
@ -1416,7 +1421,7 @@ impl AcpThreadView {
.map(|this|{ .map(|this|{
if editing && editor_focus { if editing && editor_focus {
this.border_color(focus_border) this.border_color(focus_border)
} else if message.id.is_some() { } else if message.prompt_id.is_some() {
this.hover(|s| s.border_color(focus_border.opacity(0.8))) this.hover(|s| s.border_color(focus_border.opacity(0.8)))
} else { } else {
this this
@ -1437,7 +1442,7 @@ impl AcpThreadView {
.bg(cx.theme().colors().editor_background) .bg(cx.theme().colors().editor_background)
.overflow_hidden(); .overflow_hidden();
if message.id.is_some() { if message.prompt_id.is_some() {
this.child( this.child(
base_container base_container
.child( .child(
@ -5398,7 +5403,6 @@ pub(crate) mod tests {
fn prompt( fn prompt(
&self, &self,
_id: Option<acp_thread::UserMessageId>,
_params: acp::PromptRequest, _params: acp::PromptRequest,
_cx: &mut App, _cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> { ) -> Task<gpui::Result<acp::PromptResponse>> {
@ -5544,7 +5548,7 @@ pub(crate) mod tests {
let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else { let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else {
panic!(); panic!();
}; };
user_message.id.clone().unwrap() user_message.prompt_id.clone().unwrap()
}); });
thread_view.read_with(cx, |view, cx| { thread_view.read_with(cx, |view, cx| {