Combine end_turn_tx and cancellation_state into one enum
This commit is contained in:
parent
4b94e90899
commit
7f9adae3a3
1 changed files with 80 additions and 63 deletions
|
@ -153,20 +153,17 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
|
|
||||||
let cancellation_state = Rc::new(RefCell::new(CancellationState::None));
|
let turn_state = Rc::new(RefCell::new(TurnState::None));
|
||||||
|
|
||||||
let end_turn_tx = Rc::new(RefCell::new(None));
|
|
||||||
let handler_task = cx.spawn({
|
let handler_task = cx.spawn({
|
||||||
let end_turn_tx = end_turn_tx.clone();
|
let turn_state = turn_state.clone();
|
||||||
let mut thread_rx = thread_rx.clone();
|
let mut thread_rx = thread_rx.clone();
|
||||||
let cancellation_state = cancellation_state.clone();
|
|
||||||
async move |cx| {
|
async move |cx| {
|
||||||
while let Some(message) = incoming_message_rx.next().await {
|
while let Some(message) = incoming_message_rx.next().await {
|
||||||
ClaudeAgentSession::handle_message(
|
ClaudeAgentSession::handle_message(
|
||||||
thread_rx.clone(),
|
thread_rx.clone(),
|
||||||
message,
|
message,
|
||||||
end_turn_tx.clone(),
|
turn_state.clone(),
|
||||||
cancellation_state.clone(),
|
|
||||||
cx,
|
cx,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
@ -192,8 +189,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
|
|
||||||
let session = ClaudeAgentSession {
|
let session = ClaudeAgentSession {
|
||||||
outgoing_tx,
|
outgoing_tx,
|
||||||
end_turn_tx,
|
turn_state,
|
||||||
cancellation_state,
|
|
||||||
_handler_task: handler_task,
|
_handler_task: handler_task,
|
||||||
_mcp_server: Some(permission_mcp_server),
|
_mcp_server: Some(permission_mcp_server),
|
||||||
};
|
};
|
||||||
|
@ -225,8 +221,8 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
)));
|
)));
|
||||||
};
|
};
|
||||||
|
|
||||||
let (tx, rx) = oneshot::channel();
|
let (end_tx, end_rx) = oneshot::channel();
|
||||||
session.end_turn_tx.borrow_mut().replace(tx);
|
session.turn_state.replace(TurnState::InProgress { end_tx });
|
||||||
|
|
||||||
let mut content = String::new();
|
let mut content = String::new();
|
||||||
for chunk in params.prompt {
|
for chunk in params.prompt {
|
||||||
|
@ -260,12 +256,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
return Task::ready(Err(anyhow!(err)));
|
return Task::ready(Err(anyhow!(err)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let cancellation_state = session.cancellation_state.clone();
|
cx.foreground_executor().spawn(async move { end_rx.await? })
|
||||||
cx.foreground_executor().spawn(async move {
|
|
||||||
let result = rx.await??;
|
|
||||||
cancellation_state.replace(CancellationState::None);
|
|
||||||
Ok(result)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
|
||||||
|
@ -277,11 +268,17 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||||
|
|
||||||
let request_id = new_request_id();
|
let request_id = new_request_id();
|
||||||
|
|
||||||
session
|
let turn_state = session.turn_state.take();
|
||||||
.cancellation_state
|
let TurnState::InProgress { end_tx } = session.turn_state.take() else {
|
||||||
.replace(CancellationState::Requested {
|
// Already cancelled or idle, put it back
|
||||||
request_id: request_id.clone(),
|
session.turn_state.replace(turn_state);
|
||||||
});
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
session.turn_state.replace(TurnState::CancelRequested {
|
||||||
|
end_tx,
|
||||||
|
request_id: request_id.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
session
|
session
|
||||||
.outgoing_tx
|
.outgoing_tx
|
||||||
|
@ -351,25 +348,48 @@ fn spawn_claude(
|
||||||
|
|
||||||
struct ClaudeAgentSession {
|
struct ClaudeAgentSession {
|
||||||
outgoing_tx: UnboundedSender<SdkMessage>,
|
outgoing_tx: UnboundedSender<SdkMessage>,
|
||||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
turn_state: Rc<RefCell<TurnState>>,
|
||||||
cancellation_state: Rc<RefCell<CancellationState>>,
|
|
||||||
_mcp_server: Option<ClaudeZedMcpServer>,
|
_mcp_server: Option<ClaudeZedMcpServer>,
|
||||||
_handler_task: Task<()>,
|
_handler_task: Task<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
enum CancellationState {
|
enum TurnState {
|
||||||
#[default]
|
#[default]
|
||||||
None,
|
None,
|
||||||
Requested {
|
InProgress {
|
||||||
|
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||||
|
},
|
||||||
|
CancelRequested {
|
||||||
|
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
},
|
},
|
||||||
Confirmed,
|
CancelConfirmed {
|
||||||
|
end_tx: oneshot::Sender<Result<acp::PromptResponse>>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CancellationState {
|
impl TurnState {
|
||||||
fn is_confirmed(&self) -> bool {
|
fn is_cancelled(&self) -> bool {
|
||||||
matches!(self, CancellationState::Confirmed)
|
matches!(self, TurnState::CancelConfirmed { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn end_tx(self) -> Option<oneshot::Sender<Result<acp::PromptResponse>>> {
|
||||||
|
match self {
|
||||||
|
TurnState::None => None,
|
||||||
|
TurnState::InProgress { end_tx, .. } => Some(end_tx),
|
||||||
|
TurnState::CancelRequested { end_tx, .. } => Some(end_tx),
|
||||||
|
TurnState::CancelConfirmed { end_tx } => Some(end_tx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn confirm_cancellation(self, id: &str) -> Self {
|
||||||
|
match self {
|
||||||
|
TurnState::CancelRequested { request_id, end_tx } if request_id == id => {
|
||||||
|
TurnState::CancelConfirmed { end_tx }
|
||||||
|
}
|
||||||
|
_ => self,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -377,8 +397,7 @@ impl ClaudeAgentSession {
|
||||||
async fn handle_message(
|
async fn handle_message(
|
||||||
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
mut thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||||
message: SdkMessage,
|
message: SdkMessage,
|
||||||
end_turn_tx: Rc<RefCell<Option<oneshot::Sender<Result<acp::PromptResponse>>>>>,
|
turn_state: Rc<RefCell<TurnState>>,
|
||||||
cancellation_state: Rc<RefCell<CancellationState>>,
|
|
||||||
cx: &mut AsyncApp,
|
cx: &mut AsyncApp,
|
||||||
) {
|
) {
|
||||||
match message {
|
match message {
|
||||||
|
@ -401,7 +420,7 @@ impl ClaudeAgentSession {
|
||||||
for chunk in message.content.chunks() {
|
for chunk in message.content.chunks() {
|
||||||
match chunk {
|
match chunk {
|
||||||
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
|
ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => {
|
||||||
if !cancellation_state.borrow().is_confirmed() {
|
if !turn_state.borrow().is_cancelled() {
|
||||||
thread
|
thread
|
||||||
.update(cx, |thread, cx| {
|
.update(cx, |thread, cx| {
|
||||||
thread.push_user_content_block(text.into(), cx)
|
thread.push_user_content_block(text.into(), cx)
|
||||||
|
@ -420,10 +439,7 @@ impl ClaudeAgentSession {
|
||||||
acp::ToolCallUpdate {
|
acp::ToolCallUpdate {
|
||||||
id: acp::ToolCallId(tool_use_id.into()),
|
id: acp::ToolCallId(tool_use_id.into()),
|
||||||
fields: acp::ToolCallUpdateFields {
|
fields: acp::ToolCallUpdateFields {
|
||||||
status: if cancellation_state
|
status: if turn_state.borrow().is_cancelled() {
|
||||||
.borrow()
|
|
||||||
.is_confirmed()
|
|
||||||
{
|
|
||||||
// Do not set to completed if turn was cancelled
|
// Do not set to completed if turn was cancelled
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
|
@ -555,37 +571,38 @@ impl ClaudeAgentSession {
|
||||||
result,
|
result,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
if let Some(end_turn_tx) = end_turn_tx.borrow_mut().take() {
|
let turn_state = turn_state.take();
|
||||||
if is_error
|
let was_cancelled = turn_state.is_cancelled();
|
||||||
|| (subtype == ResultErrorType::ErrorDuringExecution
|
let Some(end_turn_tx) = turn_state.end_tx() else {
|
||||||
&& !cancellation_state.borrow().is_confirmed())
|
debug_panic!("Received `SdkMessage::Result` but there wasn't an active turn");
|
||||||
{
|
return;
|
||||||
end_turn_tx
|
};
|
||||||
.send(Err(anyhow!(
|
|
||||||
"Error: {}",
|
if is_error || (!was_cancelled && subtype == ResultErrorType::ErrorDuringExecution)
|
||||||
result.unwrap_or_else(|| subtype.to_string())
|
{
|
||||||
)))
|
end_turn_tx
|
||||||
.ok();
|
.send(Err(anyhow!(
|
||||||
} else {
|
"Error: {}",
|
||||||
let stop_reason = match subtype {
|
result.unwrap_or_else(|| subtype.to_string())
|
||||||
ResultErrorType::Success => acp::StopReason::EndTurn,
|
)))
|
||||||
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
|
.ok();
|
||||||
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
|
} else {
|
||||||
};
|
let stop_reason = match subtype {
|
||||||
end_turn_tx
|
ResultErrorType::Success => acp::StopReason::EndTurn,
|
||||||
.send(Ok(acp::PromptResponse { stop_reason }))
|
ResultErrorType::ErrorMaxTurns => acp::StopReason::MaxTurnRequests,
|
||||||
.ok();
|
ResultErrorType::ErrorDuringExecution => acp::StopReason::Cancelled,
|
||||||
}
|
};
|
||||||
|
end_turn_tx
|
||||||
|
.send(Ok(acp::PromptResponse { stop_reason }))
|
||||||
|
.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SdkMessage::ControlResponse { response } => {
|
SdkMessage::ControlResponse { response } => {
|
||||||
if matches!(response.subtype, ResultErrorType::Success) {
|
if matches!(response.subtype, ResultErrorType::Success) {
|
||||||
let mut cancellation_state = cancellation_state.borrow_mut();
|
let new_state = turn_state.take().confirm_cancellation(&response.request_id);
|
||||||
if let CancellationState::Requested { request_id } = &*cancellation_state
|
turn_state.replace(new_state);
|
||||||
&& request_id == &response.request_id
|
} else {
|
||||||
{
|
log::error!("Control response error: {:?}", response);
|
||||||
*cancellation_state = CancellationState::Confirmed;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SdkMessage::System { .. } => {}
|
SdkMessage::System { .. } => {}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue