Fix interrupting ACP threads and CC cancellation (#35752)
Fixes a bug where generation wouldn't continue after interrupting the agent, and improves CC cancellation so we don't display "[Request interrupted by user]" Release Notes: - N/A --------- Co-authored-by: Cole Miller <cole@zed.dev>
This commit is contained in:
parent
1907b16fe6
commit
bd1c26cb5b
2 changed files with 171 additions and 80 deletions
|
@ -6,7 +6,7 @@ use context_server::listener::McpServerTool;
|
|||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
use std::cell::RefCell;
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
@ -24,7 +24,7 @@ use futures::{
|
|||
};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, debug_panic};
|
||||
|
||||
use crate::claude::mcp_server::{ClaudeZedMcpServer, McpConfig};
|
||||
use crate::claude::tools::ClaudeTool;
|
||||
|
@ -153,16 +153,20 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
})
|
||||
.detach();
|
||||
|
||||
let pending_cancellation = Rc::new(Cell::new(PendingCancellation::None));
|
||||
|
||||
let end_turn_tx = Rc::new(RefCell::new(None));
|
||||
let handler_task = cx.spawn({
|
||||
let end_turn_tx = end_turn_tx.clone();
|
||||
let mut thread_rx = thread_rx.clone();
|
||||
let cancellation_state = pending_cancellation.clone();
|
||||
async move |cx| {
|
||||
while let Some(message) = incoming_message_rx.next().await {
|
||||
ClaudeAgentSession::handle_message(
|
||||
thread_rx.clone(),
|
||||
message,
|
||||
end_turn_tx.clone(),
|
||||
cancellation_state.clone(),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
|
@ -189,6 +193,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
let session = ClaudeAgentSession {
|
||||
outgoing_tx,
|
||||
end_turn_tx,
|
||||
pending_cancellation,
|
||||
_handler_task: handler_task,
|
||||
_mcp_server: Some(permission_mcp_server),
|
||||
};
|
||||
|
@ -255,7 +260,12 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
return Task::ready(Err(anyhow!(err)));
|
||||
}
|
||||
|
||||
cx.foreground_executor().spawn(async move { rx.await? })
|
||||
let cancellation_state = session.pending_cancellation.clone();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let result = rx.await??;
|
||||
cancellation_state.set(PendingCancellation::None);
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||
|
@ -265,18 +275,19 @@ impl AgentConnection for ClaudeAgentConnection {
|
|||
return;
|
||||
};
|
||||
|
||||
let request_id = new_request_id();
|
||||
|
||||
session.pending_cancellation.set(PendingCancellation::Sent {
|
||||
request_id: request_id.clone(),
|
||||
});
|
||||
|
||||
session
|
||||
.outgoing_tx
|
||||
.unbounded_send(SdkMessage::new_interrupt_message())
|
||||
.unbounded_send(SdkMessage::ControlRequest {
|
||||
request_id,
|
||||
request: ControlRequest::Interrupt,
|
||||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(end_turn_tx) = session.end_turn_tx.borrow_mut().take() {
|
||||
end_turn_tx
|
||||
.send(Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::Cancelled,
|
||||
}))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -339,25 +350,107 @@ fn spawn_claude(
|
|||
struct ClaudeAgentSession {
|
||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
||||
pending_cancellation: Rc<Cell<PendingCancellation>>,
|
||||
_mcp_server: Option<ClaudeZedMcpServer>,
|
||||
_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
enum PendingCancellation {
|
||||
#[default]
|
||||
None,
|
||||
Sent {
|
||||
request_id: String,
|
||||
},
|
||||
Confirmed,
|
||||
}
|
||||
|
||||
impl ClaudeAgentSession {
|
||||
async fn handle_message(
|
||||
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
message: SdkMessage,
|
||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
||||
pending_cancellation: Rc<Cell<PendingCancellation>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
match message {
|
||||
// we should only be sending these out, they don't need to be in the thread
|
||||
SdkMessage::ControlRequest { .. } => {}
|
||||
SdkMessage::Assistant {
|
||||
SdkMessage::User {
|
||||
message,
|
||||
session_id: _,
|
||||
} => {
|
||||
let Some(thread) = thread_rx
|
||||
.recv()
|
||||
.await
|
||||
.log_err()
|
||||
.and_then(|entity| entity.upgrade())
|
||||
else {
|
||||
log::error!("Received an SDK message but thread is gone");
|
||||
return;
|
||||
};
|
||||
|
||||
for chunk in message.content.chunks() {
|
||||
match chunk {
|
||||
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
|
||||
let state = pending_cancellation.take();
|
||||
if state != PendingCancellation::Confirmed {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(text.into(), cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
pending_cancellation.set(state);
|
||||
}
|
||||
ContentChunk::ToolResult {
|
||||
content,
|
||||
tool_use_id,
|
||||
} => {
|
||||
let content = content.to_string();
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use_id.into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
content: (!content.is_empty())
|
||||
.then(|| vec![content.into()]),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::Thinking { .. }
|
||||
| ContentChunk::RedactedThinking
|
||||
| ContentChunk::ToolUse { .. } => {
|
||||
debug_panic!(
|
||||
"Should not get {:?} with role: assistant. should we handle this?",
|
||||
chunk
|
||||
);
|
||||
}
|
||||
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
format!("Unsupported content: {:?}", chunk).into(),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
| SdkMessage::User {
|
||||
SdkMessage::Assistant {
|
||||
message,
|
||||
session_id: _,
|
||||
} => {
|
||||
|
@ -423,31 +516,12 @@ impl ClaudeAgentSession {
|
|||
})
|
||||
.log_err();
|
||||
}
|
||||
ContentChunk::ToolResult {
|
||||
content,
|
||||
tool_use_id,
|
||||
} => {
|
||||
let content = content.to_string();
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(
|
||||
acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(tool_use_id.into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
content: (!content.is_empty())
|
||||
.then(|| vec![content.into()]),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.log_err();
|
||||
ContentChunk::ToolResult { .. } | ContentChunk::WebSearchToolResult => {
|
||||
debug_panic!(
|
||||
"Should not get tool results with role: assistant. should we handle this?"
|
||||
);
|
||||
}
|
||||
ContentChunk::Image
|
||||
| ContentChunk::Document
|
||||
| ContentChunk::WebSearchToolResult => {
|
||||
ContentChunk::Image | ContentChunk::Document => {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
|
@ -468,7 +542,10 @@ impl ClaudeAgentSession {
|
|||
..
|
||||
} => {
|
||||
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
|
||||
if is_error || subtype == ResultErrorType::ErrorDuringExecution {
|
||||
if is_error
|
||||
|| (subtype == ResultErrorType::ErrorDuringExecution
|
||||
&& pending_cancellation.take() != PendingCancellation::Confirmed)
|
||||
{
|
||||
end_turn_tx
|
||||
.send(Err(anyhow!(
|
||||
"Error: {}",
|
||||
|
@ -479,7 +556,7 @@ impl ClaudeAgentSession {
|
|||
let stop_reason = match subtype {
|
||||
ResultErrorType::Success => acp::StopReason::EndTurn,
|
||||
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
|
||||
ResultErrorType::ErrorDuringExecution => unreachable!(),
|
||||
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
|
||||
};
|
||||
end_turn_tx
|
||||
.send(Ok(acp::PromptResponse { stop_reason }))
|
||||
|
@ -487,7 +564,20 @@ impl ClaudeAgentSession {
|
|||
}
|
||||
}
|
||||
}
|
||||
SdkMessage::System { .. } | SdkMessage::ControlResponse { .. } => {}
|
||||
SdkMessage::ControlResponse { response } => {
|
||||
if matches!(response.subtype, ResultErrorType::Success) {
|
||||
let pending_cancellation_value = pending_cancellation.take();
|
||||
|
||||
if let PendingCancellation::Sent { request_id } = &pending_cancellation_value
|
||||
&& request_id == &response.request_id
|
||||
{
|
||||
pending_cancellation.set(PendingCancellation::Confirmed);
|
||||
} else {
|
||||
pending_cancellation.set(pending_cancellation_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
SdkMessage::System { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -728,22 +818,15 @@ impl Display for ResultErrorType {
|
|||
}
|
||||
}
|
||||
|
||||
impl SdkMessage {
|
||||
fn new_interrupt_message() -> Self {
|
||||
use rand::Rng;
|
||||
// In the Claude Code TS SDK they just generate a random 12 character string,
|
||||
// `Math.random().toString(36).substring(2, 15)`
|
||||
let request_id = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(12)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
Self::ControlRequest {
|
||||
request_id,
|
||||
request: ControlRequest::Interrupt,
|
||||
}
|
||||
}
|
||||
fn new_request_id() -> String {
|
||||
use rand::Rng;
|
||||
// In the Claude Code TS SDK they just generate a random 12 character string,
|
||||
// `Math.random().toString(36).substring(2, 15)`
|
||||
rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(12)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue