agent2: Emit cancellation stop reason on cancel (#36381)

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Ben Brandt 2025-08-18 09:58:30 +02:00 committed by GitHub
parent b3969ed427
commit ea828c0c59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 191 additions and 93 deletions

View file

@ -941,7 +941,15 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Cancel the current send and ensure that the event stream is closed, even // Cancel the current send and ensure that the event stream is closed, even
// if one of the tools is still running. // if one of the tools is still running.
thread.update(cx, |thread, _cx| thread.cancel()); thread.update(cx, |thread, _cx| thread.cancel());
events.collect::<Vec<_>>().await; let events = events.collect::<Vec<_>>().await;
let last_event = events.last();
assert!(
matches!(
last_event,
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
),
"unexpected event {last_event:?}"
);
// Ensure we can still send a new message after cancellation. // Ensure we can still send a new message after cancellation.
let events = thread let events = thread
@ -965,6 +973,62 @@ async fn test_cancellation(cx: &mut TestAppContext) {
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
} }
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events_1 = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello 1"], cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
cx.run_until_parked();
let events_2 = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello 2"], cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
fake_model.end_last_completion_stream();
let events_1 = events_1.collect::<Vec<_>>().await;
assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
let events_2 = events_2.collect::<Vec<_>>().await;
assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events_1 = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello 1"], cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
fake_model.end_last_completion_stream();
let events_1 = events_1.collect::<Vec<_>>().await;
let events_2 = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello 2"], cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
fake_model.end_last_completion_stream();
let events_2 = events_2.collect::<Vec<_>>().await;
assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
#[gpui::test] #[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) { async fn test_refusal(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;

View file

@ -461,7 +461,7 @@ pub struct Thread {
/// Holds the task that handles agent interaction until the end of the turn. /// Holds the task that handles agent interaction until the end of the turn.
/// Survives across multiple requests as the model performs tool calls and /// Survives across multiple requests as the model performs tool calls and
/// we run tools, report their results. /// we run tools, report their results.
running_turn: Option<Task<()>>, running_turn: Option<RunningTurn>,
pending_message: Option<AgentMessage>, pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>, tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
tool_use_limit_reached: bool, tool_use_limit_reached: bool,
@ -554,8 +554,9 @@ impl Thread {
} }
pub fn cancel(&mut self) { pub fn cancel(&mut self) {
// TODO: do we need to emit a stop::cancel for ACP? if let Some(running_turn) = self.running_turn.take() {
self.running_turn.take(); running_turn.cancel();
}
self.flush_pending_message(); self.flush_pending_message();
} }
@ -616,108 +617,118 @@ impl Thread {
&mut self, &mut self,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> { ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
self.cancel();
let model = self.model.clone(); let model = self.model.clone();
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>(); let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let event_stream = AgentResponseEventStream(events_tx); let event_stream = AgentResponseEventStream(events_tx);
let message_ix = self.messages.len().saturating_sub(1); let message_ix = self.messages.len().saturating_sub(1);
self.tool_use_limit_reached = false; self.tool_use_limit_reached = false;
self.running_turn = Some(cx.spawn(async move |this, cx| { self.running_turn = Some(RunningTurn {
log::info!("Starting agent turn execution"); event_stream: event_stream.clone(),
let turn_result: Result<()> = async { _task: cx.spawn(async move |this, cx| {
let mut completion_intent = CompletionIntent::UserPrompt; log::info!("Starting agent turn execution");
loop { let turn_result: Result<()> = async {
log::debug!( let mut completion_intent = CompletionIntent::UserPrompt;
"Building completion request with intent: {:?}", loop {
completion_intent log::debug!(
); "Building completion request with intent: {:?}",
let request = this.update(cx, |this, cx| { completion_intent
this.build_completion_request(completion_intent, cx) );
})?; let request = this.update(cx, |this, cx| {
this.build_completion_request(completion_intent, cx)
})?;
log::info!("Calling model.stream_completion"); log::info!("Calling model.stream_completion");
let mut events = model.stream_completion(request, cx).await?; let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully"); log::debug!("Stream completion started successfully");
let mut tool_use_limit_reached = false; let mut tool_use_limit_reached = false;
let mut tool_uses = FuturesUnordered::new(); let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await { while let Some(event) = events.next().await {
match event? { match event? {
LanguageModelCompletionEvent::StatusUpdate( LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached, CompletionRequestStatus::ToolUseLimitReached,
) => { ) => {
tool_use_limit_reached = true; tool_use_limit_reached = true;
} }
LanguageModelCompletionEvent::Stop(reason) => { LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason); event_stream.send_stop(reason);
if reason == StopReason::Refusal { if reason == StopReason::Refusal {
this.update(cx, |this, _cx| { this.update(cx, |this, _cx| {
this.flush_pending_message(); this.flush_pending_message();
this.messages.truncate(message_ix); this.messages.truncate(message_ix);
})?; })?;
return Ok(()); return Ok(());
}
}
event => {
log::trace!("Received completion event: {:?}", event);
this.update(cx, |this, cx| {
tool_uses.extend(this.handle_streamed_completion_event(
event,
&event_stream,
cx,
));
})
.ok();
} }
} }
event => { }
log::trace!("Received completion event: {:?}", event);
this.update(cx, |this, cx| { let used_tools = tool_uses.is_empty();
tool_uses.extend(this.handle_streamed_completion_event( while let Some(tool_result) = tool_uses.next().await {
event, log::info!("Tool finished {:?}", tool_result);
&event_stream,
cx, event_stream.update_tool_call_fields(
)); &tool_result.tool_use_id,
}) acp::ToolCallUpdateFields {
.ok(); status: Some(if tool_result.is_error {
} acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
}),
raw_output: tool_result.output.clone(),
..Default::default()
},
);
this.update(cx, |this, _cx| {
this.pending_message()
.tool_results
.insert(tool_result.tool_use_id.clone(), tool_result);
})
.ok();
}
if tool_use_limit_reached {
log::info!("Tool use limit reached, completing turn");
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
return Err(language_model::ToolUseLimitReachedError.into());
} else if used_tools {
log::info!("No tool uses found, completing turn");
return Ok(());
} else {
this.update(cx, |this, _| this.flush_pending_message())?;
completion_intent = CompletionIntent::ToolResults;
} }
} }
let used_tools = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
event_stream.update_tool_call_fields(
&tool_result.tool_use_id,
acp::ToolCallUpdateFields {
status: Some(if tool_result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
}),
raw_output: tool_result.output.clone(),
..Default::default()
},
);
this.update(cx, |this, _cx| {
this.pending_message()
.tool_results
.insert(tool_result.tool_use_id.clone(), tool_result);
})
.ok();
}
if tool_use_limit_reached {
log::info!("Tool use limit reached, completing turn");
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
return Err(language_model::ToolUseLimitReachedError.into());
} else if used_tools {
log::info!("No tool uses found, completing turn");
return Ok(());
} else {
this.update(cx, |this, _| this.flush_pending_message())?;
completion_intent = CompletionIntent::ToolResults;
}
} }
} .await;
.await;
this.update(cx, |this, _| this.flush_pending_message()).ok(); if let Err(error) = turn_result {
if let Err(error) = turn_result { log::error!("Turn execution failed: {:?}", error);
log::error!("Turn execution failed: {:?}", error); event_stream.send_error(error);
event_stream.send_error(error); } else {
} else { log::info!("Turn execution completed successfully");
log::info!("Turn execution completed successfully"); }
}
})); this.update(cx, |this, _| {
this.flush_pending_message();
this.running_turn.take();
})
.ok();
}),
});
events_rx events_rx
} }
@ -1125,6 +1136,23 @@ impl Thread {
} }
} }
struct RunningTurn {
/// Holds the task that handles agent interaction until the end of the turn.
/// Survives across multiple requests as the model performs tool calls and
/// we run tools, report their results.
_task: Task<()>,
/// The current event stream for the running turn. Used to report a final
/// cancellation event if we cancel the turn.
event_stream: AgentResponseEventStream,
}
impl RunningTurn {
fn cancel(self) {
log::debug!("Cancelling in progress turn");
self.event_stream.send_canceled();
}
}
pub trait AgentTool pub trait AgentTool
where where
Self: 'static + Sized, Self: 'static + Sized,
@ -1336,6 +1364,12 @@ impl AgentResponseEventStream {
} }
} }
fn send_canceled(&self) {
self.0
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
.ok();
}
fn send_error(&self, error: impl Into<anyhow::Error>) { fn send_error(&self, error: impl Into<anyhow::Error>) {
self.0.unbounded_send(Err(error.into())).ok(); self.0.unbounded_send(Err(error.into())).ok();
} }