Fix CC tool state on cancel (#35763)

When we stop the generation, CC tells us the tool completed, but it was
actually cancelled.

Release Notes:

- N/A
This commit is contained in:
Agus Zubiaga 2025-08-07 20:26:19 -03:00 committed by GitHub
parent 952e3713d7
commit 6912dc8399
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 91 additions and 64 deletions

View file

@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use smol::process::Child; use smol::process::Child;
use std::cell::{Cell, RefCell}; use std::cell::RefCell;
use std::fmt::Display; use std::fmt::Display;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection {
}) })
.detach(); .detach();
let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None)); let turn_state = Rc::new(RefCell::new(TurnState::None));
let end_turn_tx = Rc::new(RefCell::new(None));
let handler_task = cx.spawn({ let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone(); let turn_state = turn_state.clone();
let mut thread_rx = thread_rx.clone(); let mut thread_rx = thread_rx.clone();
let cancellation_state = pending_cancellation.clone();
async move |cx| { async move |cx| {
while let Some(message) = incoming_message_rx.next().await { while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message( ClaudeAgentSession::handle_message(
thread_rx.clone(), thread_rx.clone(),
message, message,
end_turn_tx.clone(), turn_state.clone(),
cancellation_state.clone(),
cx, cx,
) )
.await .await
@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession { let session = ClaudeAgentSession {
outgoing_tx, outgoing_tx,
end_turn_tx, turn_state,
pending_cancellation,
_handler_task: handler_task, _handler_task: handler_task,
_mcp_server: Some(permission_mcp_server), _mcp_server: Some(permission_mcp_server),
}; };
@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
))); )));
}; };
let (tx, rx) = oneshot::channel(); let (end_tx, end_rx) = oneshot::channel();
session.end_turn_tx.borrow_mut().replace(tx); session.turn_state.replace(TurnState::InProgress { end_tx });
let mut content = String::new(); let mut content = String::new();
for chunk in params.prompt { for chunk in params.prompt {
@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err))); return Task::ready(Err(anyhow!(err)));
} }
let cancellation_state = session.pending_cancellation.clone(); cx.foreground_executor().spawn(async move { end_rx.await? })
cx.foreground_executor().spawn(async move {
let result = rx.await??;
cancellation_state.set(PendingCancellation::None);
Ok(result)
})
} }
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) { fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@ -277,7 +268,15 @@ impl AgentConnection for ClaudeAgentConnection {
let request_id = new_request_id(); let request_id = new_request_id();
session.pending_cancellation.set(PendingCancellation::Sent { let turn_state = session.turn_state.take();
let TurnState::InProgress { end_tx } = turn_state else {
// Already cancelled or idle, put it back
session.turn_state.replace(turn_state);
return;
};
session.turn_state.replace(TurnState::CancelRequested {
end_tx,
request_id: request_id.clone(), request_id: request_id.clone(),
}); });
@ -349,28 +348,56 @@ fn spawn_claude(
struct ClaudeAgentSession { struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>, outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>, turn_state: Rc<RefCell<TurnState>>,
pending_cancellation: Rc<Cell<PendingCancellation>>,
_mcp_server: Option<ClaudeZedMcpServer>, _mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>, _handler_task: Task<()>,
} }
#[derive(Debug, Default, PartialEq)] #[derive(Debug, Default)]
enum PendingCancellation { enum TurnState {
#[default] #[default]
None, None,
Sent { InProgress {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
CancelRequested {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
request_id: String, request_id: String,
}, },
Confirmed, CancelConfirmed {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
}
impl TurnState {
fn is_cancelled(&self) -> bool {
matches!(self, TurnState::CancelConfirmed { .. })
}
fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
match self {
TurnState::None => None,
TurnState::InProgress { end_tx, .. } => Some(end_tx),
TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
TurnState::CancelConfirmed { end_tx } => Some(end_tx),
}
}
fn confirm_cancellation(self, id: &str) -> Self {
match self {
TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
TurnState::CancelConfirmed { end_tx }
}
_ => self,
}
}
} }
impl ClaudeAgentSession { impl ClaudeAgentSession {
async fn handle_message( async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>, mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage, message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>, turn_state: Rc<RefCell<TurnState>>,
pending_cancellation: Rc<Cell<PendingCancellation>>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) { ) {
match message { match message {
@ -393,15 +420,13 @@ impl ClaudeAgentSession {
for chunk in message.content.chunks() { for chunk in message.content.chunks() {
match chunk { match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
let state = pending_cancellation.take(); if !turn_state.borrow().is_cancelled() {
if state != PendingCancellation::Confirmed {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx) thread.push_user_content_block(text.into(), cx)
}) })
.log_err(); .log_err();
} }
pending_cancellation.set(state);
} }
ContentChunk::ToolResult { ContentChunk::ToolResult {
content, content,
@ -414,7 +439,12 @@ impl ClaudeAgentSession {
acp::ToolCallUpdate { acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()), id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields { fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed), status: if turn_state.borrow().is_cancelled() {
// Do not set to completed if turn was cancelled
None
} else {
Some(acp::ToolCallStatus::Completed)
},
content: (!content.is_empty()) content: (!content.is_empty())
.then(|| vec![content.into()]), .then(|| vec![content.into()]),
..Default::default() ..Default::default()
@ -541,40 +571,38 @@ impl ClaudeAgentSession {
result, result,
.. ..
} => { } => {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { let turn_state = turn_state.take();
if is_error let was_cancelled = turn_state.is_cancelled();
|| (subtype == ResultErrorType::ErrorDuringExecution let Some(end_turn_tx) = turn_state.end_tx() else {
&& pending_cancellation.take() != PendingCancellation::Confirmed) debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
{ return;
end_turn_tx };
.send(Err(anyhow!(
"Error: {}", if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
result.unwrap_or_else(|| subtype.to_string()) {
))) end_turn_tx
.ok(); .send(Err(anyhow!(
} else { "Error: {}",
let stop_reason = match subtype { result.unwrap_or_else(|| subtype.to_string())
ResultErrorType::Success => acp::StopReason::EndTurn, )))
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests, .ok();
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled, } else {
}; let stop_reason = match subtype {
end_turn_tx ResultErrorType::Success => acp::StopReason::EndTurn,
.send(Ok(acp::PromptResponse { stop_reason })) ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
.ok(); ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
} };
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
.ok();
} }
} }
SdkMessage::ControlResponse { response } => { SdkMessage::ControlResponse { response } => {
if matches!(response.subtype, ResultErrorType::Success) { if matches!(response.subtype, ResultErrorType::Success) {
let pending_cancellation_value = pending_cancellation.take(); let new_state = turn_state.take().confirm_cancellation(&response.request_id);
turn_state.replace(new_state);
if let PendingCancellation::Sent { request_id } = &pending_cancellation_value } else {
&& request_id == &response.request_id log::error!("Control response error: {:?}", response);
{
pending_cancellation.set(PendingCancellation::Confirmed);
} else {
pending_cancellation.set(pending_cancellation_value);
}
} }
} }
SdkMessage::System { .. } => {} SdkMessage::System { .. } => {}

View file

@ -246,7 +246,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await; let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await; let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| { let _ = thread.update(cx, |thread, cx| {
thread.send_raw( thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#, r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
cx, cx,
@ -285,9 +285,8 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
id.clone() id.clone()
}); });
let _ = thread.update(cx, |thread, cx| thread.cancel(cx)); thread.update(cx, |thread, cx| thread.cancel(cx)).await;
full_turn.await.unwrap(); thread.read_with(cx, |thread, _cx| {
thread.read_with(cx, |thread, _| {
let AgentThreadEntry::ToolCall(ToolCall { let AgentThreadEntry::ToolCall(ToolCall {
status: ToolCallStatus::Canceled, status: ToolCallStatus::Canceled,
.. ..