Combine end_turn_tx and cancellation_state into one enum

This commit is contained in:
Agus Zubiaga 2025-08-07 11:36:29 -03:00
parent 4b94e90899
commit 7f9adae3a3

View file

@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection {
})
.detach();
let cancellation_state = Rc::new(RefCell::new(CancellationState::None));
let turn_state = Rc::new(RefCell::new(TurnState::None));
let end_turn_tx = Rc::new(RefCell::new(None));
let handler_task = cx.spawn({
let end_turn_tx = end_turn_tx.clone();
let turn_state = turn_state.clone();
let mut thread_rx = thread_rx.clone();
let cancellation_state = cancellation_state.clone();
async move |cx| {
while let Some(message) = incoming_message_rx.next().await {
ClaudeAgentSession::handle_message(
thread_rx.clone(),
message,
end_turn_tx.clone(),
cancellation_state.clone(),
turn_state.clone(),
cx,
)
.await
@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
let session = ClaudeAgentSession {
outgoing_tx,
end_turn_tx,
cancellation_state,
turn_state,
_handler_task: handler_task,
_mcp_server: Some(permission_mcp_server),
};
@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
)));
};
let (tx, rx) = oneshot::channel();
session.end_turn_tx.borrow_mut().replace(tx);
let (end_tx, end_rx) = oneshot::channel();
session.turn_state.replace(TurnState::InProgress { end_tx });
let mut content = String::new();
for chunk in params.prompt {
@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
return Task::ready(Err(anyhow!(err)));
}
let cancellation_state = session.cancellation_state.clone();
cx.foreground_executor().spawn(async move {
let result = rx.await??;
cancellation_state.replace(CancellationState::None);
Ok(result)
})
cx.foreground_executor().spawn(async move { end_rx.await? })
}
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
@ -277,11 +268,17 @@ impl AgentConnection for ClaudeAgentConnection {
let request_id = new_request_id();
session
.cancellation_state
.replace(CancellationState::Requested {
request_id: request_id.clone(),
});
let turn_state = session.turn_state.take();
let TurnState::InProgress { end_tx } = session.turn_state.take() else {
// Already cancelled or idle, put it back
session.turn_state.replace(turn_state);
return;
};
session.turn_state.replace(TurnState::CancelRequested {
end_tx,
request_id: request_id.clone(),
});
session
.outgoing_tx
@ -351,25 +348,48 @@ fn spawn_claude(
struct ClaudeAgentSession {
outgoing_tx: UnboundedSender<SdkMessage>,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
cancellation_state: Rc<RefCell<CancellationState>>,
turn_state: Rc<RefCell<TurnState>>,
_mcp_server: Option<ClaudeZedMcpServer>,
_handler_task: Task<()>,
}
#[derive(Debug, Default)]
enum CancellationState {
enum TurnState {
#[default]
None,
Requested {
InProgress {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
CancelRequested {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
request_id: String,
},
Confirmed,
CancelConfirmed {
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
},
}
impl CancellationState {
fn is_confirmed(&self) -> bool {
matches!(self, CancellationState::Confirmed)
impl TurnState {
fn is_cancelled(&self) -> bool {
matches!(self, TurnState::CancelConfirmed { .. })
}
fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
match self {
TurnState::None => None,
TurnState::InProgress { end_tx, .. } => Some(end_tx),
TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
TurnState::CancelConfirmed { end_tx } => Some(end_tx),
}
}
fn confirm_cancellation(self, id: &str) -> Self {
match self {
TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
TurnState::CancelConfirmed { end_tx }
}
_ => self,
}
}
}
@ -377,8 +397,7 @@ impl ClaudeAgentSession {
async fn handle_message(
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
message: SdkMessage,
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
cancellation_state: Rc<RefCell<CancellationState>>,
turn_state: Rc<RefCell<TurnState>>,
cx: &mut AsyncApp,
) {
match message {
@ -401,7 +420,7 @@ impl ClaudeAgentSession {
for chunk in message.content.chunks() {
match chunk {
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
if !cancellation_state.borrow().is_confirmed() {
if !turn_state.borrow().is_cancelled() {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(text.into(), cx)
@ -420,10 +439,7 @@ impl ClaudeAgentSession {
acp::ToolCallUpdate {
id: acp::ToolCallId(tool_use_id.into()),
fields: acp::ToolCallUpdateFields {
status: if cancellation_state
.borrow()
.is_confirmed()
{
status: if turn_state.borrow().is_cancelled() {
// Do not set to completed if turn was cancelled
None
} else {
@ -555,37 +571,38 @@ impl ClaudeAgentSession {
result,
..
} => {
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
if is_error
|| (subtype == ResultErrorType::ErrorDuringExecution
&& !cancellation_state.borrow().is_confirmed())
{
end_turn_tx
.send(Err(anyhow!(
"Error: {}",
result.unwrap_or_else(|| subtype.to_string())
)))
.ok();
} else {
let stop_reason = match subtype {
ResultErrorType::Success => acp::StopReason::EndTurn,
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
};
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
.ok();
}
let turn_state = turn_state.take();
let was_cancelled = turn_state.is_cancelled();
let Some(end_turn_tx) = turn_state.end_tx() else {
debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
return;
};
if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
{
end_turn_tx
.send(Err(anyhow!(
"Error: {}",
result.unwrap_or_else(|| subtype.to_string())
)))
.ok();
} else {
let stop_reason = match subtype {
ResultErrorType::Success => acp::StopReason::EndTurn,
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
};
end_turn_tx
.send(Ok(acp::PromptResponse { stop_reason }))
.ok();
}
}
SdkMessage::ControlResponse { response } => {
if matches!(response.subtype, ResultErrorType::Success) {
let mut cancellation_state = cancellation_state.borrow_mut();
if let CancellationState::Requested { request_id } = &*cancellation_state
&& request_id == &response.request_id
{
*cancellation_state = CancellationState::Confirmed;
}
let new_state = turn_state.take().confirm_cancellation(&response.request_id);
turn_state.replace(new_state);
} else {
log::error!("Control response error: {:?}", response);
}
}
SdkMessage::System { .. } => {}