From 48fed866e60f1951bd8aa6ccec000670ce839b7f Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 18 Aug 2025 12:34:27 -0300 Subject: [PATCH] acp: Have `AcpThread` handle all interrupting (#36417) The view was cancelling the generation, but `AcpThread` already handles that, so we removed that extra code and fixed a bug where an update from the first user message would appear after the second one. Release Notes: - N/A Co-authored-by: Danilo --- crates/acp_thread/src/acp_thread.rs | 22 ++-- crates/acp_thread/src/connection.rs | 27 +++-- crates/agent_ui/src/acp/thread_view.rs | 135 ++++++++++++++++++++++++- 3 files changed, 164 insertions(+), 20 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 3762c553cc..e104c40bf2 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1200,17 +1200,21 @@ impl AcpThread { } else { None }; - self.push_entry( - AgentThreadEntry::UserMessage(UserMessage { - id: message_id.clone(), - content: block, - chunks: message, - checkpoint: None, - }), - cx, - ); self.run_turn(cx, async move |this, cx| { + this.update(cx, |this, cx| { + this.push_entry( + AgentThreadEntry::UserMessage(UserMessage { + id: message_id.clone(), + content: block, + chunks: message, + checkpoint: None, + }), + cx, + ); + }) + .ok(); + let old_checkpoint = git_store .update(cx, |git, cx| git.checkpoint(cx))? .await diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 48310f07ce..a328499bbc 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -201,7 +201,7 @@ mod test_support { struct Session { thread: WeakEntity, - response_tx: Option>, + response_tx: Option>, } impl StubAgentConnection { @@ -242,12 +242,12 @@ mod test_support { .unwrap() .thread .update(cx, |thread, cx| { - thread.handle_session_update(update.clone(), cx).unwrap(); + thread.handle_session_update(update, cx).unwrap(); }) .unwrap(); } - pub fn end_turn(&self, session_id: acp::SessionId) { + pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) { self.sessions .lock() .get_mut(&session_id) @@ -255,7 +255,7 @@ mod test_support { .response_tx .take() .expect("No pending turn") - .send(()) + .send(stop_reason) .unwrap(); } } @@ -308,10 +308,8 @@ mod test_support { let (tx, rx) = oneshot::channel(); response_tx.replace(tx); cx.spawn(async move |_| { - rx.await?; - Ok(acp::PromptResponse { - stop_reason: acp::StopReason::EndTurn, - }) + let stop_reason = rx.await?; + Ok(acp::PromptResponse { stop_reason }) }) } else { for update in self.next_prompt_updates.lock().drain(..) { @@ -353,8 +351,17 @@ mod test_support { } } - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { - unimplemented!() + fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { + if let Some(end_turn_tx) = self + .sessions + .lock() + .get_mut(session_id) + .unwrap() + .response_tx + .take() + { + end_turn_tx.send(acp::StopReason::Canceled).unwrap(); + } } fn session_editor( diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 4760677fa1..2c02027c4d 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -4283,7 +4283,7 @@ pub(crate) mod tests { }, cx, ); - connection.end_turn(session_id); + connection.end_turn(session_id, acp::StopReason::EndTurn); }); thread_view.read_with(cx, |view, _cx| { @@ -4302,4 +4302,137 @@ pub(crate) mod tests { ); }); } + + #[gpui::test] + async fn test_interrupt(cx: &mut TestAppContext) { + init_test(cx); + + let connection = StubAgentConnection::new(); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(connection.clone()), cx).await; + add_to_workspace(thread_view.clone(), cx); + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Message 1", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + let (thread, session_id) = thread_view.read_with(cx, |view, cx| { + let thread = view.thread().unwrap(); + + (thread.clone(), thread.read(cx).session_id().clone()) + }); + + cx.run_until_parked(); + + cx.update(|_, cx| { + connection.send_update( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk { + content: "Message 1 resp".into(), + }, + cx, + ); + }); + + cx.run_until_parked(); + + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc::indoc! {" + ## User + + Message 1 + + ## Assistant + + Message 1 resp + + "} + ) + }); + + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Message 2", window, cx); + }); + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.update(|_, cx| { + // Simulate a response sent after beginning to cancel + connection.send_update( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk { + content: "onse".into(), + }, + cx, + ); + }); + + cx.run_until_parked(); + + // Last Message 1 response should appear before Message 2 + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc::indoc! {" + ## User + + Message 1 + + ## Assistant + + Message 1 response + + ## User + + Message 2 + + "} + ) + }); + + cx.update(|_, cx| { + connection.send_update( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk { + content: "Message 2 response".into(), + }, + cx, + ); + connection.end_turn(session_id.clone(), acp::StopReason::EndTurn); + }); + + cx.run_until_parked(); + + thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + indoc::indoc! {" + ## User + + Message 1 + + ## Assistant + + Message 1 response + + ## User + + Message 2 + + ## Assistant + + Message 2 response + + "} + ) + }); + } }