diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 09d08fdcf8..588d6e9f45 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -6,7 +6,7 @@ use context_server::listener::McpServerTool; use project::Project; use settings::SettingsStore; use smol::process::Child; -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use std::fmt::Display; use std::path::Path; use std::rc::Rc; @@ -153,13 +153,13 @@ impl AgentConnection for ClaudeAgentConnection { }) .detach(); - let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None)); + let cancellation_state = Rc::new(RefCell::new(CancellationState::None)); let end_turn_tx = Rc::new(RefCell::new(None)); let handler_task = cx.spawn({ let end_turn_tx = end_turn_tx.clone(); let mut thread_rx = thread_rx.clone(); - let cancellation_state = pending_cancellation.clone(); + let cancellation_state = cancellation_state.clone(); async move |cx| { while let Some(message) = incoming_message_rx.next().await { ClaudeAgentSession::handle_message( @@ -193,7 +193,7 @@ impl AgentConnection for ClaudeAgentConnection { let session = ClaudeAgentSession { outgoing_tx, end_turn_tx, - pending_cancellation, + cancellation_state, _handler_task: handler_task, _mcp_server: Some(permission_mcp_server), }; @@ -260,10 +260,10 @@ impl AgentConnection for ClaudeAgentConnection { return Task::ready(Err(anyhow!(err))); } - let cancellation_state = session.pending_cancellation.clone(); + let cancellation_state = session.cancellation_state.clone(); cx.foreground_executor().spawn(async move { let result = rx.await??; - cancellation_state.set(PendingCancellation::None); + *cancellation_state.borrow_mut() = CancellationState::None; Ok(result) }) } @@ -277,9 +277,9 @@ impl AgentConnection for ClaudeAgentConnection { let request_id = new_request_id(); - session.pending_cancellation.set(PendingCancellation::Sent { + *session.cancellation_state.borrow_mut() = CancellationState::Requested { request_id: request_id.clone(), - }); + }; session .outgoing_tx @@ -350,27 +350,33 @@ fn spawn_claude( struct ClaudeAgentSession { outgoing_tx: UnboundedSender, end_turn_tx: Rc>>>>, - pending_cancellation: Rc>, + cancellation_state: Rc>, _mcp_server: Option, _handler_task: Task<()>, } -#[derive(Debug, Default, PartialEq)] -enum PendingCancellation { +#[derive(Debug, Default)] +enum CancellationState { #[default] None, - Sent { + Requested { request_id: String, }, Confirmed, } +impl CancellationState { + fn is_confirmed(&self) -> bool { + matches!(self, CancellationState::Confirmed) + } +} + impl ClaudeAgentSession { async fn handle_message( mut thread_rx: watch::Receiver>, message: SdkMessage, end_turn_tx: Rc>>>>, - pending_cancellation: Rc>, + cancellation_state: Rc>, cx: &mut AsyncApp, ) { match message { @@ -393,15 +399,13 @@ impl ClaudeAgentSession { for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - let state = pending_cancellation.take(); - if state != PendingCancellation::Confirmed { + if !cancellation_state.borrow().is_confirmed() { thread .update(cx, |thread, cx| { thread.push_user_content_block(text.into(), cx) }) .log_err(); } - pending_cancellation.set(state); } ContentChunk::ToolResult { content, @@ -414,7 +418,15 @@ impl ClaudeAgentSession { acp::ToolCallUpdate { id: acp::ToolCallId(tool_use_id.into()), fields: acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::Completed), + status: if cancellation_state + .borrow() + .is_confirmed() + { + // Do not set to completed if turn was cancelled + None + } else { + Some(acp::ToolCallStatus::Completed) + }, content: (!content.is_empty()) .then(|| vec![content.into()]), ..Default::default() @@ -544,7 +556,7 @@ impl ClaudeAgentSession { if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { if is_error || (subtype == ResultErrorType::ErrorDuringExecution - && pending_cancellation.take() != PendingCancellation::Confirmed) + && !cancellation_state.borrow().is_confirmed()) { end_turn_tx .send(Err(anyhow!( @@ -566,14 +578,11 @@ impl ClaudeAgentSession { } SdkMessage::ControlResponse { response } => { if matches!(response.subtype, ResultErrorType::Success) { - let pending_cancellation_value = pending_cancellation.take(); - - if let PendingCancellation::Sent { request_id } = &pending_cancellation_value + let mut cancellation_state = cancellation_state.borrow_mut(); + if let CancellationState::Requested { request_id } = &*cancellation_state && request_id == &response.request_id { - pending_cancellation.set(PendingCancellation::Confirmed); - } else { - pending_cancellation.set(pending_cancellation_value); + *cancellation_state = CancellationState::Confirmed; } } }