Fix interrupting ACP threads and CC cancellation (#35752)

Fixes a bug where generation wouldn't continue after interrupting the
agent, and improves CC cancellation so we don't display "[Request
interrupted by user]"

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
This commit is contained in:
Agus Zubiaga 2025-08-06 22:55:17 -03:00 committed by GitHub
parent 1907b16fe6
commit bd1c26cb5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 171 additions and 80 deletions

View file

@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
use std::cell::RefCell;
use std::cell::{Cell, RefCell};
use std::fmt::Display;
use std::path::Path;
use std::rc::Rc;
@ -24,7 +24,7 @@ use futures::{
};
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
use serde::{Deserialize, Serialize};
use util::ResultExt;
use util::{ResultExt, debug_panic};
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
use crate::claude::tools::ClaudeTool;
@ -153,16 +153,20 @@ impl AgentConnection for ClaudeAgentConnection {
})
.detach();
let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None));
let end_turn_tx = Rc::new(RefCell::new(None));
let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone();
let mut thread_rx = thread_rx.clone();
let cancellation_state = pending_cancellation.clone();
async move |cx| {
while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message(
thread_rx.clone(),
message,
end_turn_tx.clone(),
cancellation_state.clone(),
cx,
)
.await
@ -189,6 +193,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession {
outgoing_tx,
end_turn_tx,
pending_cancellation,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
@ -255,7 +260,12 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err)));
}
cx.foreground_executor().spawn(async move { rx.await? })
let cancellation_state = session.pending_cancellation.clone();
cx.foreground_executor().spawn(async move {
let result = rx.await??;
cancellation_state.set(PendingCancellation::None);
Ok(result)
})
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@ -265,18 +275,19 @@ impl AgentConnection for ClaudeAgentConnection {
return;
};
let request_id = new_request_id();
session.pending_cancellation.set(PendingCancellation::Sent {
request_id: request_id.clone(),
});
session
.outgoing_tx
.unbounded_send(SdkMessage::new_interrupt_message())
.unbounded_send(SdkMessage::ControlRequest {
request_id,
request: ControlRequest::Interrupt,
})
.log_err();
if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() {
end_turn_tx
.send(Ok(acp::PromptResponse {
stop_reason: acp::StopReason::Cancelled,
}))
.ok();
}
}
}
@ -339,25 +350,107 @@ fn spawn_claude(
struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
pending_cancellation: Rc<Cell<PendingCancellation>>,
_mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>,
}
#[derive(Debug, Default, PartialEq)]
enum PendingCancellation {
#[default]
None,
Sent {
request_id: String,
},
Confirmed,
}
impl ClaudeAgentSession {
async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
pending_cancellation: Rc<Cell<PendingCancellation>>,
cx: &mut AsyncApp,
) {
match message {
// we should only be sending these out, they don't need to be in the thread
SdkMessage::ControlRequest { .. } => {}
SdkMessage::Assistant {
SdkMessage::User {
message,
session_id: _,
} => {
let Some(thread) = thread_rx
.recv()
.await
.log_err()
.and_then(|entity| entity.upgrade())
else {
log::error!("Received an SDK message but thread is gone");
return;
};
for chunk in message.content.chunks() {
match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
let state = pending_cancellation.take();
if state != PendingCancellation::Confirmed {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx)
})
.log_err();
}
pending_cancellation.set(state);
}
ContentChunk::ToolResult {
content,
tool_use_id,
} => {
let content = content.to_string();
thread
.update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
content: (!content.is_empty())
.then(|| vec![content.into()]),
..Default::default()
},
},
cx,
)
})
.log_err();
}
ContentChunk::Thinking { .. }
| ContentChunk::RedactedThinking
| ContentChunk::ToolUse { .. } => {
debug_panic!(
"Should not get {:?} with role: assistant. should we handle this?",
chunk
);
}
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::WebSearchToolResult => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
format!("Unsupported content: {:?}", chunk).into(),
false,
cx,
)
})
.log_err();
}
}
}
}
| SdkMessage::User {
SdkMessage::Assistant {
message,
session_id: _,
} => {
@ -423,31 +516,12 @@ impl ClaudeAgentSession {
})
.log_err();
}
ContentChunk::ToolResult {
content,
tool_use_id,
} => {
let content = content.to_string();
thread
.update(cx, |thread, cx| {
thread.update_tool_call(
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
content: (!content.is_empty())
.then(|| vec![content.into()]),
..Default::default()
},
},
cx,
)
})
.log_err();
ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
debug_panic!(
"Should not get tool results with role: assistant. should we handle this?"
);
}
ContentChunk::Image
| ContentChunk::Document
| ContentChunk::WebSearchToolResult => {
ContentChunk::Image | ContentChunk::Document => {
thread
.update(cx, |thread, cx| {
thread.push_assistant_content_block(
@ -468,7 +542,10 @@ impl ClaudeAgentSession {
..
} => {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
if is_error || subtype == ResultErrorType::ErrorDuringExecution {
if is_error
|| (subtype == ResultErrorType::ErrorDuringExecution
&& pending_cancellation.take() != PendingCancellation::Confirmed)
{
end_turn_tx
.send(Err(anyhow!(
"Error: {}",
@ -479,7 +556,7 @@ impl ClaudeAgentSession {
let stop_reason = match subtype {
ResultErrorType::Success => acp::StopReason::EndTurn,
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
ResultErrorType::ErrorDuringExecution => unreachable!(),
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
};
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
@ -487,7 +564,20 @@ impl ClaudeAgentSession {
}
}
}
SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {}
SdkMessage::ControlResponse { response } => {
if matches!(response.subtype, ResultErrorType::Success) {
let pending_cancellation_value = pending_cancellation.take();
if let PendingCancellation::Sent { request_id } = &pending_cancellation_value
&& request_id == &response.request_id
{
pending_cancellation.set(PendingCancellation::Confirmed);
} else {
pending_cancellation.set(pending_cancellation_value);
}
}
}
SdkMessage::System { .. } => {}
}
}
@ -728,22 +818,15 @@ impl Display for ResultErrorType {
}
}
impl SdkMessage {
fn new_interrupt_message() -> Self {
use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string,
// `Math.random().toString(36).substring(2, 15)`
let request_id = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(12)
.map(char::from)
.collect();
Self::ControlRequest {
request_id,
request: ControlRequest::Interrupt,
}
}
fn new_request_id() -> String {
use rand::Rng;
// In the Claude Code TS SDK they just generate a random 12 character string,
// `Math.random().toString(36).substring(2, 15)`
rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(12)
.map(char::from)
.collect()
}
#[derive(Debug, Clone, Serialize, Deserialize)]