Fix CC tool state on cancel

This commit is contained in:
Agus Zubiaga 2025-08-07 01:11:13 -03:00
parent f1e69f6311
commit 63cc3291e3

View file

@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project; use project::Project;
use settings::SettingsStore; use settings::SettingsStore;
use smol::process::Child; use smol::process::Child;
use std::cell::{Cell, RefCell}; use std::cell::RefCell;
use std::fmt::Display; use std::fmt::Display;
use std::path::Path; use std::path::Path;
use std::rc::Rc; use std::rc::Rc;
@ -153,13 +153,13 @@ impl AgentConnection for ClaudeAgentConnection {
}) })
.detach(); .detach();
let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None)); let cancellation_state = Rc::new(RefCell::new(CancellationState::None));
let end_turn_tx = Rc::new(RefCell::new(None)); let end_turn_tx = Rc::new(RefCell::new(None));
let handler_task = cx.spawn({ let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone(); let end_turn_tx = end_turn_tx.clone();
let mut thread_rx = thread_rx.clone(); let mut thread_rx = thread_rx.clone();
let cancellation_state = pending_cancellation.clone(); let cancellation_state = cancellation_state.clone();
async move |cx| { async move |cx| {
while let Some(message) = incoming_message_rx.next().await { while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message( ClaudeAgentSession::handle_message(
@ -193,7 +193,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession { let session = ClaudeAgentSession {
outgoing_tx, outgoing_tx,
end_turn_tx, end_turn_tx,
pending_cancellation, cancellation_state,
_handler_task: handler_task, _handler_task: handler_task,
_mcp_server: Some(permission_mcp_server), _mcp_server: Some(permission_mcp_server),
}; };
@ -260,10 +260,10 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err))); return Task::ready(Err(anyhow!(err)));
} }
let cancellation_state = session.pending_cancellation.clone(); let cancellation_state = session.cancellation_state.clone();
cx.foreground_executor().spawn(async move { cx.foreground_executor().spawn(async move {
let result = rx.await??; let result = rx.await??;
cancellation_state.set(PendingCancellation::None); *cancellation_state.borrow_mut() = CancellationState::None;
Ok(result) Ok(result)
}) })
} }
@ -277,9 +277,9 @@ impl AgentConnection for ClaudeAgentConnection {
let request_id = new_request_id(); let request_id = new_request_id();
session.pending_cancellation.set(PendingCancellation::Sent { *session.cancellation_state.borrow_mut() = CancellationState::Requested {
request_id: request_id.clone(), request_id: request_id.clone(),
}); };
session session
.outgoing_tx .outgoing_tx
@ -350,27 +350,33 @@ fn spawn_claude(
struct ClaudeAgentSession { struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>, outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>, end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
pending_cancellation: Rc<Cell<PendingCancellation>>, cancellation_state: Rc<RefCell<CancellationState>>,
_mcp_server: Option<ClaudeZedMcpServer>, _mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>, _handler_task: Task<()>,
} }
#[derive(Debug, Default, PartialEq)] #[derive(Debug, Default)]
enum PendingCancellation { enum CancellationState {
#[default] #[default]
None, None,
Sent { Requested {
request_id: String, request_id: String,
}, },
Confirmed, Confirmed,
} }
impl CancellationState {
fn is_confirmed(&self) -> bool {
matches!(self, CancellationState::Confirmed)
}
}
impl ClaudeAgentSession { impl ClaudeAgentSession {
async fn handle_message( async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>, mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage, message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>, end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
pending_cancellation: Rc<Cell<PendingCancellation>>, cancellation_state: Rc<RefCell<CancellationState>>,
cx: &mut AsyncApp, cx: &mut AsyncApp,
) { ) {
match message { match message {
@ -393,15 +399,13 @@ impl ClaudeAgentSession {
for chunk in message.content.chunks() { for chunk in message.content.chunks() {
match chunk { match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
let state = pending_cancellation.take(); if !cancellation_state.borrow().is_confirmed() {
if state != PendingCancellation::Confirmed {
thread thread
.update(cx, |thread, cx| { .update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx) thread.push_user_content_block(text.into(), cx)
}) })
.log_err(); .log_err();
} }
pending_cancellation.set(state);
} }
ContentChunk::ToolResult { ContentChunk::ToolResult {
content, content,
@ -414,7 +418,15 @@ impl ClaudeAgentSession {
acp::ToolCallUpdate { acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()), id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields { fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed), status: if cancellation_state
.borrow()
.is_confirmed()
{
// Do not set to completed if turn was cancelled
None
} else {
Some(acp::ToolCallStatus::Completed)
},
content: (!content.is_empty()) content: (!content.is_empty())
.then(|| vec![content.into()]), .then(|| vec![content.into()]),
..Default::default() ..Default::default()
@ -544,7 +556,7 @@ impl ClaudeAgentSession {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() { if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
if is_error if is_error
|| (subtype == ResultErrorType::ErrorDuringExecution || (subtype == ResultErrorType::ErrorDuringExecution
&& pending_cancellation.take() != PendingCancellation::Confirmed) && !cancellation_state.borrow().is_confirmed())
{ {
end_turn_tx end_turn_tx
.send(Err(anyhow!( .send(Err(anyhow!(
@ -566,14 +578,11 @@ impl ClaudeAgentSession {
} }
SdkMessage::ControlResponse { response } => { SdkMessage::ControlResponse { response } => {
if matches!(response.subtype, ResultErrorType::Success) { if matches!(response.subtype, ResultErrorType::Success) {
let pending_cancellation_value = pending_cancellation.take(); let mut cancellation_state = cancellation_state.borrow_mut();
if let CancellationState::Requested { request_id } = &*cancellation_state
if let PendingCancellation::Sent { request_id } = &pending_cancellation_value
&& request_id == &response.request_id && request_id == &response.request_id
{ {
pending_cancellation.set(PendingCancellation::Confirmed); *cancellation_state = CancellationState::Confirmed;
} else {
pending_cancellation.set(pending_cancellation_value);
} }
} }
} }