From 6912dc8399148dd0caf951ce0bba711de7279f01 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 7 Aug 2025 20:26:19 -0300 Subject: [PATCH] Fix CC tool state on cancel (#35763) When we stop the generation, CC tells us the tool completed, but it was actually cancelled. Release Notes: - N/A --- crates/agent_servers/src/claude.rs | 148 +++++++++++++++----------- crates/agent_servers/src/e2e_tests.rs | 7 +- 2 files changed, 91 insertions(+), 64 deletions(-) diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 09d08fdcf8..c65508f152 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,7 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use std::fmt::Display; use std::path::Path; use std::rc::Rc; @@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection { }) .detach(); - let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None)); + let turn_state = Rc::new(RefCell::new(TurnState::None)); - let end_turn_tx = Rc::new(RefCell::new(None)); let handler_task = cx.spawn({ - let end_turn_tx = end_turn_tx.clone(); + let turn_state = turn_state.clone(); let mut thread_rx = thread_rx.clone(); - let cancellation_state = pending_cancellation.clone(); async move |cx| { while let Some(message) = incoming_message_rx.next().await { ClaudeAgentSession::handle_message( thread_rx.clone(), message, - end_turn_tx.clone(), - cancellation_state.clone(), + turn_state.clone(), cx, ) .await @@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection { let session = ClaudeAgentSession { outgoing_tx, - end_turn_tx, - pending_cancellation, + turn_state, _handler_task: handler_task, _mcp_server: Some(permission_mcp_server), }; @@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection { ))); }; - let (tx, rx) = oneshot::channel(); - session.end_turn_tx.borrow_mut().replace(tx); + let (end_tx, end_rx) = oneshot::channel(); + session.turn_state.replace(TurnState::InProgress { end_tx }); let mut content = String::new(); for chunk in params.prompt { @@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection { return Task::ready(Err(anyhow!(err))); } - let cancellation_state = session.pending_cancellation.clone(); - cx.foreground_executor().spawn(async move { - let result = rx.await??; - cancellation_state.set(PendingCancellation::None); - Ok(result) - }) + cx.foreground_executor().spawn(async move { end_rx.await? }) } fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { @@ -277,7 +268,15 @@ impl AgentConnection for ClaudeAgentConnection { let request_id = new_request_id(); - session.pending_cancellation.set(PendingCancellation::Sent { + let turn_state = session.turn_state.take(); + let TurnState::InProgress { end_tx } = turn_state else { + // Already cancelled or idle, put it back + session.turn_state.replace(turn_state); + return; + }; + + session.turn_state.replace(TurnState::CancelRequested { + end_tx, request_id: request_id.clone(), }); @@ -349,28 +348,56 @@ fn spawn_claude( struct ClaudeAgentSession { outgoing_tx: UnboundedSender, - end_turn_tx: Rc>>>>, - pending_cancellation: Rc>, + turn_state: Rc>, _mcp_server: Option, _handler_task: Task<()>, } -#[derive(Debug, Default, PartialEq)] -enum PendingCancellation { +#[derive(Debug, Default)] +enum TurnState { #[default] None, - Sent { + InProgress { + end_tx: oneshot::Sender>, + }, + CancelRequested { + end_tx: oneshot::Sender>, request_id: String, }, - Confirmed, + CancelConfirmed { + end_tx: oneshot::Sender>, + }, +} + +impl TurnState { + fn is_cancelled(&self) -> bool { + matches!(self, TurnState::CancelConfirmed { .. }) + } + + fn end_tx(self) -> Option>> { + match self { + TurnState::None => None, + TurnState::InProgress { end_tx, .. } => Some(end_tx), + TurnState::CancelRequested { end_tx, .. } => Some(end_tx), + TurnState::CancelConfirmed { end_tx } => Some(end_tx), + } + } + + fn confirm_cancellation(self, id: &str) -> Self { + match self { + TurnState::CancelRequested { request_id, end_tx } if request_id == id => { + TurnState::CancelConfirmed { end_tx } + } + _ => self, + } + } } impl ClaudeAgentSession { async fn handle_message( mut thread_rx: watch::Receiver>, message: SdkMessage, - end_turn_tx: Rc>>>>, - pending_cancellation: Rc>, + turn_state: Rc>, cx: &mut AsyncApp, ) { match message { @@ -393,15 +420,13 @@ impl ClaudeAgentSession { for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - let state = pending_cancellation.take(); - if state != PendingCancellation::Confirmed { + if !turn_state.borrow().is_cancelled() { thread .update(cx, |thread, cx| { thread.push_user_content_block(text.into(), cx) }) .log_err(); } - pending_cancellation.set(state); } ContentChunk::ToolResult { content, @@ -414,7 +439,12 @@ impl ClaudeAgentSession { acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.into()), fields: acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::Completed), + status: if turn_state.borrow().is_cancelled() { + // Do not set to completed if turn was cancelled + None + } else { + Some(acp::ToolCallStatus::Completed) + }, content: (!content.is_empty()) .then(|| vec![content.into()]), ..Default::default() @@ -541,40 +571,38 @@ impl ClaudeAgentSession { result, .. } => { - if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { - if is_error - || (subtype == ResultErrorType::ErrorDuringExecution - && pending_cancellation.take() != PendingCancellation::Confirmed) - { - end_turn_tx - .send(Err(anyhow!( - "Error: {}", - result.unwrap_or_else(|| subtype.to_string()) - ))) - .ok(); - } else { - let stop_reason = match subtype { - ResultErrorType::Success => acp::StopReason::EndTurn, - ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, - ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, - }; - end_turn_tx - .send(Ok(acp::PromptResponse { stop_reason })) - .ok(); - } + let turn_state = turn_state.take(); + let was_cancelled = turn_state.is_cancelled(); + let Some(end_turn_tx) = turn_state.end_tx() else { + debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn"); + return; + }; + + if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution) + { + end_turn_tx + .send(Err(anyhow!( + "Error: {}", + result.unwrap_or_else(|| subtype.to_string()) + ))) + .ok(); + } else { + let stop_reason = match subtype { + ResultErrorType::Success => acp::StopReason::EndTurn, + ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, + ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, + }; + end_turn_tx + .send(Ok(acp::PromptResponse { stop_reason })) + .ok(); } } SdkMessage::ControlResponse { response } => { if matches!(response.subtype, ResultErrorType::Success) { - let pending_cancellation_value = pending_cancellation.take(); - - if let PendingCancellation::Sent { request_id } = &pending_cancellation_value - && request_id == &response.request_id - { - pending_cancellation.set(PendingCancellation::Confirmed); - } else { - pending_cancellation.set(pending_cancellation_value); - } + let new_state = turn_state.take().confirm_cancellation(&response.request_id); + turn_state.replace(new_state); + } else { + log::error!("Control response error: {:?}", response); } } SdkMessage::System { .. } => {} diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 05f874bd30..ec6ca29b9d 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; - let full_turn = thread.update(cx, |thread, cx| { + let _ = thread.update(cx, |thread, cx| { thread.send_raw( r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, cx, @@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon id.clone() }); - let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); - full_turn.await.unwrap(); - thread.read_with(cx, |thread, _| { + thread.update(cx, |thread, cx| thread.cancel(cx)).await; + thread.read_with(cx, |thread, _cx| { let AgentThreadEntry::ToolCall(ToolCall { status: ToolCallStatus::Canceled, ..