From c985789d8434f235011b57a72ffa05ae0c83fdba Mon Sep 17 00:00:00 2001 From: Conrad Irwin Date: Mon, 25 Aug 2025 20:58:03 -0600 Subject: [PATCH] acp-native-rewind --- Cargo.lock | 4 +- Cargo.toml | 2 +- crates/acp_thread/src/acp_thread.rs | 123 ++++++++++++-------- crates/acp_thread/src/connection.rs | 38 ++---- crates/agent2/src/agent.rs | 15 +-- crates/agent2/src/db.rs | 6 +- crates/agent2/src/tests/mod.rs | 104 ++++++++--------- crates/agent2/src/thread.rs | 10 +- crates/agent_servers/src/acp.rs | 36 +++++- crates/agent_servers/src/claude.rs | 1 - crates/agent_servers/src/e2e_tests.rs | 3 +- crates/agent_ui/src/acp/entry_view_state.rs | 2 +- crates/agent_ui/src/acp/thread_view.rs | 26 +++-- 13 files changed, 202 insertions(+), 168 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 42649b137f..5c24892c8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -191,9 +191,7 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.0.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "289eb34ee17213dadcca47eedadd386a5e7678094095414e475965d1bcca2860" +version = "0.0.32" dependencies = [ "anyhow", "async-broadcast", diff --git a/Cargo.toml b/Cargo.toml index 6ec243a9b9..46f91a8ea3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -426,7 +426,7 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # -agent-client-protocol = "0.0.31" +agent-client-protocol = { path = "../agent-client-protocol"} aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 4ded647a74..17e4922ec2 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -32,10 +32,15 @@ use std::time::{Duration, Instant}; use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use ui::App; use util::ResultExt; +use uuid::Uuid; + +pub fn new_prompt_id() -> acp::PromptId { + acp::PromptId(Uuid::new_v4().to_string().into()) +} #[derive(Debug)] pub struct UserMessage { - pub id: Option, + pub prompt_id: Option, pub content: ContentBlock, pub chunks: Vec, pub checkpoint: Option, @@ -962,7 +967,7 @@ impl AcpThread { pub fn push_user_content_block( &mut self, - message_id: Option, + prompt_id: Option, chunk: acp::ContentBlock, cx: &mut Context, ) { @@ -971,13 +976,13 @@ impl AcpThread { if let Some(last_entry) = self.entries.last_mut() && let AgentThreadEntry::UserMessage(UserMessage { - id, + prompt_id: id, content, chunks, .. }) = last_entry { - *id = message_id.or(id.take()); + *id = prompt_id.or(id.take()); content.append(chunk.clone(), &language_registry, cx); chunks.push(chunk); let idx = entries_len - 1; @@ -986,7 +991,7 @@ impl AcpThread { let content = ContentBlock::new(chunk.clone(), &language_registry, cx); self.push_entry( AgentThreadEntry::UserMessage(UserMessage { - id: message_id, + prompt_id, content, chunks: vec![chunk], checkpoint: None, @@ -1336,6 +1341,7 @@ impl AcpThread { cx: &mut Context, ) -> BoxFuture<'static, Result<()>> { self.send( + new_prompt_id(), vec![acp::ContentBlock::Text(acp::TextContent { text: message.to_string(), annotations: None, @@ -1346,6 +1352,7 @@ impl AcpThread { pub fn send( &mut self, + prompt_id: acp::PromptId, message: Vec, cx: &mut Context, ) -> BoxFuture<'static, Result<()>> { @@ -1355,22 +1362,17 @@ impl AcpThread { cx, ); let request = acp::PromptRequest { + prompt_id: Some(prompt_id.clone()), prompt: message.clone(), session_id: self.session_id.clone(), }; let git_store = self.project.read(cx).git_store().clone(); - let message_id = if self.connection.truncate(&self.session_id, cx).is_some() { - Some(UserMessageId::new()) - } else { - None - }; - self.run_turn(cx, async move |this, cx| { this.update(cx, |this, cx| { this.push_entry( AgentThreadEntry::UserMessage(UserMessage { - id: message_id.clone(), + prompt_id: Some(prompt_id), content: block, chunks: message, checkpoint: None, @@ -1392,7 +1394,7 @@ impl AcpThread { show: false, }); } - this.connection.prompt(message_id, request, cx) + this.connection.prompt(request, cx) })? .await }) @@ -1509,8 +1511,8 @@ impl AcpThread { /// Rewinds this thread to before the entry at `index`, removing it and all /// subsequent entries while reverting any changes made from that point. - pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context) -> Task> { - let Some(truncate) = self.connection.truncate(&self.session_id, cx) else { + pub fn rewind(&mut self, id: acp::PromptId, cx: &mut Context) -> Task> { + let Some(rewind) = self.connection.rewind(&self.session_id, cx) else { return Task::ready(Err(anyhow!("not supported"))); }; let Some(message) = self.user_message(&id) else { @@ -1530,7 +1532,7 @@ impl AcpThread { .await?; } - cx.update(|cx| truncate.run(id.clone(), cx))?.await?; + cx.update(|cx| rewind.rewind(id.clone(), cx))?.await?; this.update(cx, |this, cx| { if let Some((ix, _)) = this.user_message_mut(&id) { let range = ix..this.entries.len(); @@ -1594,10 +1596,10 @@ impl AcpThread { }) } - fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> { + fn user_message(&self, id: &acp::PromptId) -> Option<&UserMessage> { self.entries.iter().find_map(|entry| { if let AgentThreadEntry::UserMessage(message) = entry { - if message.id.as_ref() == Some(id) { + if message.prompt_id.as_ref() == Some(id) { Some(message) } else { None @@ -1608,10 +1610,10 @@ impl AcpThread { }) } - fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> { + fn user_message_mut(&mut self, id: &acp::PromptId) -> Option<(usize, &mut UserMessage)> { self.entries.iter_mut().enumerate().find_map(|(ix, entry)| { if let AgentThreadEntry::UserMessage(message) = entry { - if message.id.as_ref() == Some(id) { + if message.prompt_id.as_ref() == Some(id) { Some((ix, message)) } else { None @@ -1905,7 +1907,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 1); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { - assert_eq!(user_msg.id, None); + assert_eq!(user_msg.prompt_id, None); assert_eq!(user_msg.content.to_markdown(cx), "Hello, "); } else { panic!("Expected UserMessage"); @@ -1913,7 +1915,7 @@ mod tests { }); // Test appending to existing user message - let message_1_id = UserMessageId::new(); + let message_1_id = new_prompt_id(); thread.update(cx, |thread, cx| { thread.push_user_content_block( Some(message_1_id.clone()), @@ -1928,7 +1930,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 1); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] { - assert_eq!(user_msg.id, Some(message_1_id)); + assert_eq!(user_msg.prompt_id, Some(message_1_id)); assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!"); } else { panic!("Expected UserMessage"); @@ -1947,7 +1949,7 @@ mod tests { ); }); - let message_2_id = UserMessageId::new(); + let message_2_id = new_prompt_id(); thread.update(cx, |thread, cx| { thread.push_user_content_block( Some(message_2_id.clone()), @@ -1962,7 +1964,7 @@ mod tests { thread.update(cx, |thread, cx| { assert_eq!(thread.entries.len(), 3); if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] { - assert_eq!(user_msg.id, Some(message_2_id)); + assert_eq!(user_msg.prompt_id, Some(message_2_id)); assert_eq!(user_msg.content.to_markdown(cx), "New user message"); } else { panic!("Expected UserMessage at index 2"); @@ -2259,9 +2261,13 @@ mod tests { .await .unwrap(); - cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx))) - .await - .unwrap(); + cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.send(new_prompt_id(), vec!["Hi".into()], cx) + }) + }) + .await + .unwrap(); assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls())); } @@ -2320,9 +2326,13 @@ mod tests { .await .unwrap(); - cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx))) - .await - .unwrap(); + cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.send(new_prompt_id(), vec!["Lorem".into()], cx) + }) + }) + .await + .unwrap(); thread.read_with(cx, |thread, cx| { assert_eq!( thread.to_markdown(cx), @@ -2340,9 +2350,13 @@ mod tests { }); assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]); - cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx))) - .await - .unwrap(); + cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.send(new_prompt_id(), vec!["ipsum".into()], cx) + }) + }) + .await + .unwrap(); thread.read_with(cx, |thread, cx| { assert_eq!( thread.to_markdown(cx), @@ -2376,9 +2390,13 @@ mod tests { // Checkpoint isn't stored when there are no changes. simulate_changes.store(false, SeqCst); - cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx))) - .await - .unwrap(); + cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.send(new_prompt_id(), vec!["dolor".into()], cx) + }) + }) + .await + .unwrap(); thread.read_with(cx, |thread, cx| { assert_eq!( thread.to_markdown(cx), @@ -2424,7 +2442,7 @@ mod tests { let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else { panic!("unexpected entries {:?}", thread.entries) }; - thread.rewind(message.id.clone().unwrap(), cx) + thread.rewind(message.prompt_id.clone().unwrap(), cx) }) .await .unwrap(); @@ -2490,9 +2508,13 @@ mod tests { .await .unwrap(); - cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx))) - .await - .unwrap(); + cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.send(new_prompt_id(), vec!["hello".into()], cx) + }) + }) + .await + .unwrap(); thread.read_with(cx, |thread, cx| { assert_eq!( thread.to_markdown(cx), @@ -2512,9 +2534,13 @@ mod tests { // Simulate refusing the second message, ensuring the conversation gets // truncated to before sending it. refuse_next.store(true, SeqCst); - cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx))) - .await - .unwrap(); + cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.send(new_prompt_id(), vec!["world".into()], cx) + }) + }) + .await + .unwrap(); thread.read_with(cx, |thread, cx| { assert_eq!( thread.to_markdown(cx), @@ -2653,7 +2679,6 @@ mod tests { fn prompt( &self, - _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -2683,11 +2708,11 @@ mod tests { .detach(); } - fn truncate( + fn rewind( &self, session_id: &acp::SessionId, _cx: &App, - ) -> Option> { + ) -> Option> { Some(Rc::new(FakeAgentSessionEditor { _session_id: session_id.clone(), })) @@ -2702,8 +2727,8 @@ mod tests { _session_id: acp::SessionId, } - impl AgentSessionTruncate for FakeAgentSessionEditor { - fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task> { + impl AgentSessionRewind for FakeAgentSessionEditor { + fn rewind(&self, _message_id: acp::PromptId, _cx: &mut App) -> Task> { Task::ready(Ok(())) } } diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index af229b7545..ef2e24ae8f 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -5,19 +5,8 @@ use collections::IndexMap; use gpui::{Entity, SharedString, Task}; use language_model::LanguageModelProviderId; use project::Project; -use serde::{Deserialize, Serialize}; -use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; +use std::{any::Any, error::Error, fmt, path::Path, rc::Rc}; use ui::{App, IconName}; -use uuid::Uuid; - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] -pub struct UserMessageId(Arc); - -impl UserMessageId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} pub trait AgentConnection { fn new_thread( @@ -31,12 +20,8 @@ pub trait AgentConnection { fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task>; - fn prompt( - &self, - user_message_id: Option, - params: acp::PromptRequest, - cx: &mut App, - ) -> Task>; + fn prompt(&self, params: acp::PromptRequest, cx: &mut App) + -> Task>; fn resume( &self, @@ -48,11 +33,11 @@ pub trait AgentConnection { fn cancel(&self, session_id: &acp::SessionId, cx: &mut App); - fn truncate( + fn rewind( &self, _session_id: &acp::SessionId, _cx: &App, - ) -> Option> { + ) -> Option> { None } @@ -85,8 +70,8 @@ impl dyn AgentConnection { } } -pub trait AgentSessionTruncate { - fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task>; +pub trait AgentSessionRewind { + fn rewind(&self, message_id: acp::PromptId, cx: &mut App) -> Task>; } pub trait AgentSessionResume { @@ -362,7 +347,6 @@ mod test_support { fn prompt( &self, - _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -432,11 +416,11 @@ mod test_support { } } - fn truncate( + fn rewind( &self, _session_id: &agent_client_protocol::SessionId, _cx: &App, - ) -> Option> { + ) -> Option> { Some(Rc::new(StubAgentSessionEditor)) } @@ -447,8 +431,8 @@ mod test_support { struct StubAgentSessionEditor; - impl AgentSessionTruncate for StubAgentSessionEditor { - fn run(&self, _: UserMessageId, _: &mut App) -> Task> { + impl AgentSessionRewind for StubAgentSessionEditor { + fn rewind(&self, _: acp::PromptId, _: &mut App) -> Task> { Task::ready(Ok(())) } } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index ecfaea4b49..cabe6c3a33 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -905,11 +905,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection { fn prompt( &self, - id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { - let id = id.expect("UserMessageId is required"); + let id = params.prompt_id.expect("UserMessageId is required"); let session_id = params.session_id.clone(); log::info!("Received prompt request for session: {}", session_id); log::debug!("Prompt blocks count: {}", params.prompt.len()); @@ -948,11 +947,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { }); } - fn truncate( + fn rewind( &self, session_id: &agent_client_protocol::SessionId, cx: &App, - ) -> Option> { + ) -> Option> { self.0.read_with(cx, |agent, _cx| { agent.sessions.get(session_id).map(|session| { Rc::new(NativeAgentSessionEditor { @@ -1009,10 +1008,10 @@ struct NativeAgentSessionEditor { acp_thread: WeakEntity, } -impl acp_thread::AgentSessionTruncate for NativeAgentSessionEditor { - fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { +impl acp_thread::AgentSessionRewind for NativeAgentSessionEditor { + fn rewind(&self, message_id: acp::PromptId, cx: &mut App) -> Task> { match self.thread.update(cx, |thread, cx| { - thread.truncate(message_id.clone(), cx)?; + thread.rewind(message_id.clone(), cx)?; Ok(thread.latest_token_usage()) }) { Ok(usage) => { @@ -1065,6 +1064,7 @@ mod tests { use super::*; use acp_thread::{ AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri, + new_prompt_id, }; use fs::FakeFs; use gpui::TestAppContext; @@ -1311,6 +1311,7 @@ mod tests { let send = acp_thread.update(cx, |thread, cx| { thread.send( + new_prompt_id(), vec![ "What does ".into(), acp::ContentBlock::ResourceLink(acp::ResourceLink { diff --git a/crates/agent2/src/db.rs b/crates/agent2/src/db.rs index e7d31c0c7a..e7cd2b0591 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent2/src/db.rs @@ -1,5 +1,5 @@ use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; -use acp_thread::UserMessageId; +use acp_thread::new_prompt_id; use agent::{thread::DetailedSummaryState, thread_store}; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; @@ -43,7 +43,7 @@ pub struct DbThread { #[serde(default)] pub cumulative_token_usage: language_model::TokenUsage, #[serde(default)] - pub request_token_usage: HashMap, + pub request_token_usage: HashMap, #[serde(default)] pub model: Option, #[serde(default)] @@ -97,7 +97,7 @@ impl DbThread { content.push(UserMessageContent::Text(msg.context)); } - let id = UserMessageId::new(); + let id = new_prompt_id(); last_user_message_id = Some(id.clone()); crate::Message::User(UserMessage { diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 864fbf8b10..845eb82a99 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,5 +1,5 @@ use super::*; -use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId}; +use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, new_prompt_id}; use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; use anyhow::Result; @@ -48,7 +48,7 @@ async fn test_echo(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx) + thread.send(new_prompt_id(), ["Testing: Reply with 'Hello'"], cx) }) .unwrap(); cx.run_until_parked(); @@ -79,7 +79,7 @@ async fn test_thinking(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { thread.send( - UserMessageId::new(), + new_prompt_id(), [indoc! {" Testing: @@ -130,9 +130,7 @@ async fn test_system_prompt(cx: &mut TestAppContext) { }); thread.update(cx, |thread, _| thread.add_tool(EchoTool)); thread - .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["abc"], cx) - }) + .update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx)) .unwrap(); cx.run_until_parked(); let mut pending_completions = fake_model.pending_completions(); @@ -168,7 +166,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) { // Send initial user message and verify it's cached thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Message 1"], cx) + thread.send(new_prompt_id(), ["Message 1"], cx) }) .unwrap(); cx.run_until_parked(); @@ -191,7 +189,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) { // Send another user message and verify only the latest is cached thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Message 2"], cx) + thread.send(new_prompt_id(), ["Message 2"], cx) }) .unwrap(); cx.run_until_parked(); @@ -227,7 +225,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) { thread.update(cx, |thread, _| thread.add_tool(EchoTool)); thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Use the echo tool"], cx) + thread.send(new_prompt_id(), ["Use the echo tool"], cx) }) .unwrap(); cx.run_until_parked(); @@ -304,7 +302,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { .update(cx, |thread, cx| { thread.add_tool(EchoTool); thread.send( - UserMessageId::new(), + new_prompt_id(), ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."], cx, ) @@ -320,7 +318,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) { thread.remove_tool(&EchoTool::name()); thread.add_tool(DelayTool); thread.send( - UserMessageId::new(), + new_prompt_id(), [ "Now call the delay tool with 200ms.", "When the timer goes off, then you echo the output of the tool.", @@ -363,7 +361,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { thread.add_tool(WordListTool); - thread.send(UserMessageId::new(), ["Test the word_list tool."], cx) + thread.send(new_prompt_id(), ["Test the word_list tool."], cx) }) .unwrap(); @@ -414,7 +412,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { thread.add_tool(ToolRequiringPermission); - thread.send(UserMessageId::new(), ["abc"], cx) + thread.send(new_prompt_id(), ["abc"], cx) }) .unwrap(); cx.run_until_parked(); @@ -544,9 +542,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { let fake_model = model.as_fake(); let mut events = thread - .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["abc"], cx) - }) + .update(cx, |thread, cx| thread.send(new_prompt_id(), ["abc"], cx)) .unwrap(); cx.run_until_parked(); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( @@ -575,7 +571,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { thread.add_tool(EchoTool); - thread.send(UserMessageId::new(), ["abc"], cx) + thread.send(new_prompt_id(), ["abc"], cx) }) .unwrap(); cx.run_until_parked(); @@ -684,7 +680,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { thread.add_tool(EchoTool); - thread.send(UserMessageId::new(), ["abc"], cx) + thread.send(new_prompt_id(), ["abc"], cx) }) .unwrap(); cx.run_until_parked(); @@ -718,7 +714,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), vec!["ghi"], cx) + thread.send(new_prompt_id(), vec!["ghi"], cx) }) .unwrap(); cx.run_until_parked(); @@ -818,7 +814,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { .update(cx, |thread, cx| { thread.add_tool(DelayTool); thread.send( - UserMessageId::new(), + new_prompt_id(), [ "Call the delay tool twice in the same message.", "Once with 100ms. Once with 300ms.", @@ -898,7 +894,7 @@ async fn test_profiles(cx: &mut TestAppContext) { thread .update(cx, |thread, cx| { thread.set_profile(AgentProfileId("test-1".into())); - thread.send(UserMessageId::new(), ["test"], cx) + thread.send(new_prompt_id(), ["test"], cx) }) .unwrap(); cx.run_until_parked(); @@ -918,7 +914,7 @@ async fn test_profiles(cx: &mut TestAppContext) { thread .update(cx, |thread, cx| { thread.set_profile(AgentProfileId("test-2".into())); - thread.send(UserMessageId::new(), ["test2"], cx) + thread.send(new_prompt_id(), ["test2"], cx) }) .unwrap(); cx.run_until_parked(); @@ -986,7 +982,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { ); let events = thread.update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hey"], cx).unwrap() + thread.send(new_prompt_id(), ["Hey"], cx).unwrap() }); cx.run_until_parked(); @@ -1028,7 +1024,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { // Send again after adding the echo tool, ensuring the name collision is resolved. let events = thread.update(cx, |thread, cx| { thread.add_tool(EchoTool); - thread.send(UserMessageId::new(), ["Go"], cx).unwrap() + thread.send(new_prompt_id(), ["Go"], cx).unwrap() }); cx.run_until_parked(); let completion = fake_model.pending_completions().pop().unwrap(); @@ -1235,9 +1231,7 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) { ); thread - .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Go"], cx) - }) + .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Go"], cx)) .unwrap(); cx.run_until_parked(); let completion = fake_model.pending_completions().pop().unwrap(); @@ -1271,7 +1265,7 @@ async fn test_cancellation(cx: &mut TestAppContext) { thread.add_tool(InfiniteTool); thread.add_tool(EchoTool); thread.send( - UserMessageId::new(), + new_prompt_id(), ["Call the echo tool, then call the infinite tool, then explain their output"], cx, ) @@ -1327,7 +1321,7 @@ async fn test_cancellation(cx: &mut TestAppContext) { let events = thread .update(cx, |thread, cx| { thread.send( - UserMessageId::new(), + new_prompt_id(), ["Testing: reply with 'Hello' then stop."], cx, ) @@ -1353,7 +1347,7 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) { let events_1 = thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hello 1"], cx) + thread.send(new_prompt_id(), ["Hello 1"], cx) }) .unwrap(); cx.run_until_parked(); @@ -1362,7 +1356,7 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) { let events_2 = thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hello 2"], cx) + thread.send(new_prompt_id(), ["Hello 2"], cx) }) .unwrap(); cx.run_until_parked(); @@ -1384,7 +1378,7 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) { let events_1 = thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hello 1"], cx) + thread.send(new_prompt_id(), ["Hello 1"], cx) }) .unwrap(); cx.run_until_parked(); @@ -1396,7 +1390,7 @@ async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) { let events_2 = thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hello 2"], cx) + thread.send(new_prompt_id(), ["Hello 2"], cx) }) .unwrap(); cx.run_until_parked(); @@ -1416,9 +1410,7 @@ async fn test_refusal(cx: &mut TestAppContext) { let fake_model = model.as_fake(); let events = thread - .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hello"], cx) - }) + .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx)) .unwrap(); cx.run_until_parked(); thread.read_with(cx, |thread, _| { @@ -1464,7 +1456,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let fake_model = model.as_fake(); - let message_id = UserMessageId::new(); + let message_id = new_prompt_id(); thread .update(cx, |thread, cx| { thread.send(message_id.clone(), ["Hello"], cx) @@ -1516,7 +1508,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) { }); thread - .update(cx, |thread, cx| thread.truncate(message_id, cx)) + .update(cx, |thread, cx| thread.rewind(message_id, cx)) .unwrap(); cx.run_until_parked(); thread.read_with(cx, |thread, _| { @@ -1526,9 +1518,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) { // Ensure we can still send a new message after truncation. thread - .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hi"], cx) - }) + .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hi"], cx)) .unwrap(); thread.update(cx, |thread, _cx| { assert_eq!( @@ -1582,7 +1572,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) { thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Message 1"], cx) + thread.send(new_prompt_id(), ["Message 1"], cx) }) .unwrap(); cx.run_until_parked(); @@ -1625,7 +1615,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) { assert_first_message_state(cx); - let second_message_id = UserMessageId::new(); + let second_message_id = new_prompt_id(); thread .update(cx, |thread, cx| { thread.send(second_message_id.clone(), ["Message 2"], cx) @@ -1677,7 +1667,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) { }); thread - .update(cx, |thread, cx| thread.truncate(second_message_id, cx)) + .update(cx, |thread, cx| thread.rewind(second_message_id, cx)) .unwrap(); cx.run_until_parked(); @@ -1696,9 +1686,7 @@ async fn test_title_generation(cx: &mut TestAppContext) { }); let send = thread - .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hello"], cx) - }) + .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Hello"], cx)) .unwrap(); cx.run_until_parked(); @@ -1719,7 +1707,7 @@ async fn test_title_generation(cx: &mut TestAppContext) { // Send another message, ensuring no title is generated this time. let send = thread .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Hello again"], cx) + thread.send(new_prompt_id(), ["Hello again"], cx) }) .unwrap(); cx.run_until_parked(); @@ -1740,7 +1728,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) { .update(cx, |thread, cx| { thread.add_tool(ToolRequiringPermission); thread.add_tool(EchoTool); - thread.send(UserMessageId::new(), ["Hey!"], cx) + thread.send(new_prompt_id(), ["Hey!"], cx) }) .unwrap(); cx.run_until_parked(); @@ -1892,7 +1880,9 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let model = model.as_fake(); assert_eq!(model.id().0, "fake", "should return default model"); - let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx)); + let request = acp_thread.update(cx, |thread, cx| { + thread.send(new_prompt_id(), vec!["abc".into()], cx) + }); cx.run_until_parked(); model.send_last_completion_stream_text_chunk("def"); cx.run_until_parked(); @@ -1922,8 +1912,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let result = cx .update(|cx| { connection.prompt( - Some(acp_thread::UserMessageId::new()), acp::PromptRequest { + prompt_id: Some(new_prompt_id()), session_id: session_id.clone(), prompt: vec!["ghi".into()], }, @@ -1946,9 +1936,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { let fake_model = model.as_fake(); let mut events = thread - .update(cx, |thread, cx| { - thread.send(UserMessageId::new(), ["Think"], cx) - }) + .update(cx, |thread, cx| thread.send(new_prompt_id(), ["Think"], cx)) .unwrap(); cx.run_until_parked(); @@ -2049,7 +2037,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); - thread.send(UserMessageId::new(), ["Hello!"], cx) + thread.send(new_prompt_id(), ["Hello!"], cx) }) .unwrap(); cx.run_until_parked(); @@ -2093,7 +2081,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); - thread.send(UserMessageId::new(), ["Hello!"], cx) + thread.send(new_prompt_id(), ["Hello!"], cx) }) .unwrap(); cx.run_until_parked(); @@ -2159,7 +2147,7 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) { .update(cx, |thread, cx| { thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); thread.add_tool(EchoTool); - thread.send(UserMessageId::new(), ["Call the echo tool!"], cx) + thread.send(new_prompt_id(), ["Call the echo tool!"], cx) }) .unwrap(); cx.run_until_parked(); @@ -2234,7 +2222,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) { let mut events = thread .update(cx, |thread, cx| { thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx); - thread.send(UserMessageId::new(), ["Hello!"], cx) + thread.send(new_prompt_id(), ["Hello!"], cx) }) .unwrap(); cx.run_until_parked(); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 1b1c014b79..64a84bcbc4 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -4,7 +4,7 @@ use crate::{ ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate, Template, Templates, TerminalTool, ThinkingTool, WebSearchTool, }; -use acp_thread::{MentionUri, UserMessageId}; +use acp_thread::MentionUri; use action_log::ActionLog; use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; @@ -137,7 +137,7 @@ impl Message { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UserMessage { - pub id: UserMessageId, + pub id: acp::PromptId, pub content: Vec, } @@ -564,7 +564,7 @@ pub struct Thread { pending_message: Option, tools: BTreeMap>, tool_use_limit_reached: bool, - request_token_usage: HashMap, + request_token_usage: HashMap, #[allow(unused)] cumulative_token_usage: TokenUsage, #[allow(unused)] @@ -1070,7 +1070,7 @@ impl Thread { cx.notify(); } - pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { + pub fn rewind(&mut self, message_id: acp::PromptId, cx: &mut Context) -> Result<()> { self.cancel(cx); let Some(position) = self.messages.iter().position( |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), @@ -1118,7 +1118,7 @@ impl Thread { /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. pub fn send( &mut self, - id: UserMessageId, + id: acp::PromptId, content: impl IntoIterator, cx: &mut Context, ) -> Result>> diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 9080fc1ab0..a253d81ad2 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -29,6 +29,7 @@ pub struct AcpConnection { sessions: Rc>>, auth_methods: Vec, prompt_capabilities: acp::PromptCapabilities, + supports_rewind: bool, _io_task: Task>, } @@ -147,6 +148,7 @@ impl AcpConnection { server_name, sessions, prompt_capabilities: response.agent_capabilities.prompt_capabilities, + supports_rewind: response.agent_capabilities.rewind_session, _io_task: io_task, }) } @@ -225,9 +227,22 @@ impl AgentConnection for AcpConnection { }) } + fn rewind( + &self, + session_id: &agent_client_protocol::SessionId, + _cx: &App, + ) -> Option> { + if !self.supports_rewind { + return None; + } + Some(Rc::new(AcpRewinder { + connection: self.connection.clone(), + session_id: session_id.clone(), + }) as _) + } + fn prompt( &self, - _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { @@ -302,6 +317,25 @@ impl AgentConnection for AcpConnection { } } +struct AcpRewinder { + connection: Rc, + session_id: acp::SessionId, +} + +impl acp_thread::AgentSessionRewind for AcpRewinder { + fn rewind(&self, prompt_id: acp::PromptId, cx: &mut App) -> Task> { + let conn = self.connection.clone(); + let params = acp::RewindRequest { + session_id: self.session_id.clone(), + prompt_id, + }; + cx.foreground_executor().spawn(async move { + conn.rewind(params).await?; + Ok(()) + }) + } +} + struct ClientDelegate { sessions: Rc>>, cx: AsyncApp, diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 250e564526..332aa3f6b0 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -294,7 +294,6 @@ impl AgentConnection for ClaudeAgentConnection { fn prompt( &self, - _id: Option, params: acp::PromptRequest, cx: &mut App, ) -> Task> { diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 42264b4b4f..1897f44948 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -1,5 +1,5 @@ use crate::AgentServer; -use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; +use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus, new_prompt_id}; use agent_client_protocol as acp; use futures::{FutureExt, StreamExt, channel::mpsc, select}; use gpui::{AppContext, Entity, TestAppContext}; @@ -77,6 +77,7 @@ where thread .update(cx, |thread, cx| { thread.send( + new_prompt_id(), vec![ acp::ContentBlock::Text(acp::TextContent { text: "Read the file ".into(), diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs index 0e4080d689..d65304efba 100644 --- a/crates/agent_ui/src/acp/entry_view_state.rs +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -67,7 +67,7 @@ impl EntryViewState { match thread_entry { AgentThreadEntry::UserMessage(message) => { - let has_id = message.id.is_some(); + let has_id = message.prompt_id.is_some(); let chunks = message.chunks.clone(); if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) { if !editor.focus_handle(cx).is_focused(window) { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 837ce6f90a..0e1d3fcc95 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -1,7 +1,7 @@ use acp_thread::{ AcpThread, AcpThreadEvent, AgentThreadEntry, AssistantMessage, AssistantMessageChunk, AuthRequired, LoadError, MentionUri, RetryStatus, ThreadStatus, ToolCall, ToolCallContent, - ToolCallStatus, UserMessageId, + ToolCallStatus, new_prompt_id, }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; @@ -791,7 +791,7 @@ impl AcpThreadView { if let Some(thread) = self.thread() && let Some(AgentThreadEntry::UserMessage(user_message)) = thread.read(cx).entries().get(event.entry_index) - && user_message.id.is_some() + && user_message.prompt_id.is_some() { self.editing_message = Some(event.entry_index); cx.notify(); @@ -801,7 +801,7 @@ impl AcpThreadView { if let Some(thread) = self.thread() && let Some(AgentThreadEntry::UserMessage(user_message)) = thread.read(cx).entries().get(event.entry_index) - && user_message.id.is_some() + && user_message.prompt_id.is_some() { if editor.read(cx).text(cx).as_str() == user_message.content.to_markdown(cx) { self.editing_message = None; @@ -942,7 +942,7 @@ impl AcpThreadView { telemetry::event!("Agent Message Sent", agent = agent_telemetry_id); - thread.send(contents, cx) + thread.send(new_prompt_id(), contents, cx) })?; send.await }); @@ -1011,7 +1011,12 @@ impl AcpThreadView { } let Some(user_message_id) = thread.update(cx, |thread, _| { - thread.entries().get(entry_ix)?.user_message()?.id.clone() + thread + .entries() + .get(entry_ix)? + .user_message()? + .prompt_id + .clone() }) else { return; }; @@ -1316,7 +1321,7 @@ impl AcpThreadView { cx.notify(); } - fn rewind(&mut self, message_id: &UserMessageId, cx: &mut Context) { + fn rewind(&mut self, message_id: &acp::PromptId, cx: &mut Context) { let Some(thread) = self.thread() else { return; }; @@ -1379,7 +1384,7 @@ impl AcpThreadView { .gap_1p5() .w_full() .children(rules_item) - .children(message.id.clone().and_then(|message_id| { + .children(message.prompt_id.clone().and_then(|message_id| { message.checkpoint.as_ref()?.show.then(|| { h_flex() .px_3() @@ -1416,7 +1421,7 @@ impl AcpThreadView { .map(|this|{ if editing && editor_focus { this.border_color(focus_border) - } else if message.id.is_some() { + } else if message.prompt_id.is_some() { this.hover(|s| s.border_color(focus_border.opacity(0.8))) } else { this @@ -1437,7 +1442,7 @@ impl AcpThreadView { .bg(cx.theme().colors().editor_background) .overflow_hidden(); - if message.id.is_some() { + if message.prompt_id.is_some() { this.child( base_container .child( @@ -5398,7 +5403,6 @@ pub(crate) mod tests { fn prompt( &self, - _id: Option, _params: acp::PromptRequest, _cx: &mut App, ) -> Task> { @@ -5544,7 +5548,7 @@ pub(crate) mod tests { let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else { panic!(); }; - user_message.id.clone().unwrap() + user_message.prompt_id.clone().unwrap() }); thread_view.read_with(cx, |view, cx| {