From 7f9adae3a372702214fb4846caf828527c9b036f Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 7 Aug 2025 11:36:29 -0300 Subject: [PATCH] Combine end_turn_tx and cancellation_state into one enum --- crates/agent_servers/src/claude.rs | 143 ++++++++++++++++------------- 1 file changed, 80 insertions(+), 63 deletions(-) diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 98de041047..c00c3877cb 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection { }) .detach(); - let cancellation_state = Rc::new(RefCell::new(CancellationState::None)); + let turn_state = Rc::new(RefCell::new(TurnState::None)); - let end_turn_tx = Rc::new(RefCell::new(None)); let handler_task = cx.spawn({ - let end_turn_tx = end_turn_tx.clone(); + let turn_state = turn_state.clone(); let mut thread_rx = thread_rx.clone(); - let cancellation_state = cancellation_state.clone(); async move |cx| { while let Some(message) = incoming_message_rx.next().await { ClaudeAgentSession::handle_message( thread_rx.clone(), message, - end_turn_tx.clone(), - cancellation_state.clone(), + turn_state.clone(), cx, ) .await @@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection { let session = ClaudeAgentSession { outgoing_tx, - end_turn_tx, - cancellation_state, + turn_state, _handler_task: handler_task, _mcp_server: Some(permission_mcp_server), }; @@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection { ))); }; - let (tx, rx) = oneshot::channel(); - session.end_turn_tx.borrow_mut().replace(tx); + let (end_tx, end_rx) = oneshot::channel(); + session.turn_state.replace(TurnState::InProgress { end_tx }); let mut content = String::new(); for chunk in params.prompt { @@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection { return Task::ready(Err(anyhow!(err))); } - let cancellation_state = session.cancellation_state.clone(); - cx.foreground_executor().spawn(async move { - let result = rx.await??; - cancellation_state.replace(CancellationState::None); - Ok(result) - }) + cx.foreground_executor().spawn(async move { end_rx.await? }) } fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { @@ -277,11 +268,17 @@ impl AgentConnection for ClaudeAgentConnection { let request_id = new_request_id(); - session - .cancellation_state - .replace(CancellationState::Requested { - request_id: request_id.clone(), - }); + let turn_state = session.turn_state.take(); + let TurnState::InProgress { end_tx } = session.turn_state.take() else { + // Already cancelled or idle, put it back + session.turn_state.replace(turn_state); + return; + }; + + session.turn_state.replace(TurnState::CancelRequested { + end_tx, + request_id: request_id.clone(), + }); session .outgoing_tx @@ -351,25 +348,48 @@ fn spawn_claude( struct ClaudeAgentSession { outgoing_tx: UnboundedSender, - end_turn_tx: Rc>>>>, - cancellation_state: Rc>, + turn_state: Rc>, _mcp_server: Option, _handler_task: Task<()>, } #[derive(Debug, Default)] -enum CancellationState { +enum TurnState { #[default] None, - Requested { + InProgress { + end_tx: oneshot::Sender>, + }, + CancelRequested { + end_tx: oneshot::Sender>, request_id: String, }, - Confirmed, + CancelConfirmed { + end_tx: oneshot::Sender>, + }, } -impl CancellationState { - fn is_confirmed(&self) -> bool { - matches!(self, CancellationState::Confirmed) +impl TurnState { + fn is_cancelled(&self) -> bool { + matches!(self, TurnState::CancelConfirmed { .. }) + } + + fn end_tx(self) -> Option>> { + match self { + TurnState::None => None, + TurnState::InProgress { end_tx, .. } => Some(end_tx), + TurnState::CancelRequested { end_tx, .. } => Some(end_tx), + TurnState::CancelConfirmed { end_tx } => Some(end_tx), + } + } + + fn confirm_cancellation(self, id: &str) -> Self { + match self { + TurnState::CancelRequested { request_id, end_tx } if request_id == id => { + TurnState::CancelConfirmed { end_tx } + } + _ => self, + } } } @@ -377,8 +397,7 @@ impl ClaudeAgentSession { async fn handle_message( mut thread_rx: watch::Receiver>, message: SdkMessage, - end_turn_tx: Rc>>>>, - cancellation_state: Rc>, + turn_state: Rc>, cx: &mut AsyncApp, ) { match message { @@ -401,7 +420,7 @@ impl ClaudeAgentSession { for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - if !cancellation_state.borrow().is_confirmed() { + if !turn_state.borrow().is_cancelled() { thread .update(cx, |thread, cx| { thread.push_user_content_block(text.into(), cx) @@ -420,10 +439,7 @@ impl ClaudeAgentSession { acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.into()), fields: acp::ToolCallUpdateFields { - status: if cancellation_state - .borrow() - .is_confirmed() - { + status: if turn_state.borrow().is_cancelled() { // Do not set to completed if turn was cancelled None } else { @@ -555,37 +571,38 @@ impl ClaudeAgentSession { result, .. } => { - if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { - if is_error - || (subtype == ResultErrorType::ErrorDuringExecution - && !cancellation_state.borrow().is_confirmed()) - { - end_turn_tx - .send(Err(anyhow!( - "Error: {}", - result.unwrap_or_else(|| subtype.to_string()) - ))) - .ok(); - } else { - let stop_reason = match subtype { - ResultErrorType::Success => acp::StopReason::EndTurn, - ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, - ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, - }; - end_turn_tx - .send(Ok(acp::PromptResponse { stop_reason })) - .ok(); - } + let turn_state = turn_state.take(); + let was_cancelled = turn_state.is_cancelled(); + let Some(end_turn_tx) = turn_state.end_tx() else { + debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn"); + return; + }; + + if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution) + { + end_turn_tx + .send(Err(anyhow!( + "Error: {}", + result.unwrap_or_else(|| subtype.to_string()) + ))) + .ok(); + } else { + let stop_reason = match subtype { + ResultErrorType::Success => acp::StopReason::EndTurn, + ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, + ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, + }; + end_turn_tx + .send(Ok(acp::PromptResponse { stop_reason })) + .ok(); } } SdkMessage::ControlResponse { response } => { if matches!(response.subtype, ResultErrorType::Success) { - let mut cancellation_state = cancellation_state.borrow_mut(); - if let CancellationState::Requested { request_id } = &*cancellation_state - && request_id == &response.request_id - { - *cancellation_state = CancellationState::Confirmed; - } + let new_state = turn_state.take().confirm_cancellation(&response.request_id); + turn_state.replace(new_state); + } else { + log::error!("Control response error: {:?}", response); } } SdkMessage::System { .. } => {}