agent2: Emit cancellation stop reason on cancel (#36381)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
b3969ed427
commit
ea828c0c59
2 changed files with 191 additions and 93 deletions
|
@ -941,7 +941,15 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
// Cancel the current send and ensure that the event stream is closed, even
|
// Cancel the current send and ensure that the event stream is closed, even
|
||||||
// if one of the tools is still running.
|
// if one of the tools is still running.
|
||||||
thread.update(cx, |thread, _cx| thread.cancel());
|
thread.update(cx, |thread, _cx| thread.cancel());
|
||||||
events.collect::<Vec<_>>().await;
|
let events = events.collect::<Vec<_>>().await;
|
||||||
|
let last_event = events.last();
|
||||||
|
assert!(
|
||||||
|
matches!(
|
||||||
|
last_event,
|
||||||
|
Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
|
||||||
|
),
|
||||||
|
"unexpected event {last_event:?}"
|
||||||
|
);
|
||||||
|
|
||||||
// Ensure we can still send a new message after cancellation.
|
// Ensure we can still send a new message after cancellation.
|
||||||
let events = thread
|
let events = thread
|
||||||
|
@ -965,6 +973,62 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let events_1 = thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
|
||||||
|
cx.run_until_parked();
|
||||||
|
|
||||||
|
let events_2 = thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
|
||||||
|
fake_model
|
||||||
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
|
||||||
|
let events_1 = events_1.collect::<Vec<_>>().await;
|
||||||
|
assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
|
||||||
|
let events_2 = events_2.collect::<Vec<_>>().await;
|
||||||
|
assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
|
||||||
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
let fake_model = model.as_fake();
|
||||||
|
|
||||||
|
let events_1 = thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hello 1"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey 1!");
|
||||||
|
fake_model
|
||||||
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
let events_1 = events_1.collect::<Vec<_>>().await;
|
||||||
|
|
||||||
|
let events_2 = thread.update(cx, |thread, cx| {
|
||||||
|
thread.send(UserMessageId::new(), ["Hello 2"], cx)
|
||||||
|
});
|
||||||
|
cx.run_until_parked();
|
||||||
|
fake_model.send_last_completion_stream_text_chunk("Hey 2!");
|
||||||
|
fake_model
|
||||||
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
|
||||||
|
fake_model.end_last_completion_stream();
|
||||||
|
let events_2 = events_2.collect::<Vec<_>>().await;
|
||||||
|
|
||||||
|
assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
|
||||||
|
assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_refusal(cx: &mut TestAppContext) {
|
async fn test_refusal(cx: &mut TestAppContext) {
|
||||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||||
|
|
|
@ -461,7 +461,7 @@ pub struct Thread {
|
||||||
/// Holds the task that handles agent interaction until the end of the turn.
|
/// Holds the task that handles agent interaction until the end of the turn.
|
||||||
/// Survives across multiple requests as the model performs tool calls and
|
/// Survives across multiple requests as the model performs tool calls and
|
||||||
/// we run tools, report their results.
|
/// we run tools, report their results.
|
||||||
running_turn: Option<Task<()>>,
|
running_turn: Option<RunningTurn>,
|
||||||
pending_message: Option<AgentMessage>,
|
pending_message: Option<AgentMessage>,
|
||||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||||
tool_use_limit_reached: bool,
|
tool_use_limit_reached: bool,
|
||||||
|
@ -554,8 +554,9 @@ impl Thread {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cancel(&mut self) {
|
pub fn cancel(&mut self) {
|
||||||
// TODO: do we need to emit a stop::cancel for ACP?
|
if let Some(running_turn) = self.running_turn.take() {
|
||||||
self.running_turn.take();
|
running_turn.cancel();
|
||||||
|
}
|
||||||
self.flush_pending_message();
|
self.flush_pending_message();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -616,12 +617,16 @@ impl Thread {
|
||||||
&mut self,
|
&mut self,
|
||||||
cx: &mut Context<Self>,
|
cx: &mut Context<Self>,
|
||||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
||||||
|
self.cancel();
|
||||||
|
|
||||||
let model = self.model.clone();
|
let model = self.model.clone();
|
||||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||||
let event_stream = AgentResponseEventStream(events_tx);
|
let event_stream = AgentResponseEventStream(events_tx);
|
||||||
let message_ix = self.messages.len().saturating_sub(1);
|
let message_ix = self.messages.len().saturating_sub(1);
|
||||||
self.tool_use_limit_reached = false;
|
self.tool_use_limit_reached = false;
|
||||||
self.running_turn = Some(cx.spawn(async move |this, cx| {
|
self.running_turn = Some(RunningTurn {
|
||||||
|
event_stream: event_stream.clone(),
|
||||||
|
_task: cx.spawn(async move |this, cx| {
|
||||||
log::info!("Starting agent turn execution");
|
log::info!("Starting agent turn execution");
|
||||||
let turn_result: Result<()> = async {
|
let turn_result: Result<()> = async {
|
||||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||||
|
@ -710,14 +715,20 @@ impl Thread {
|
||||||
}
|
}
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
this.update(cx, |this, _| this.flush_pending_message()).ok();
|
|
||||||
if let Err(error) = turn_result {
|
if let Err(error) = turn_result {
|
||||||
log::error!("Turn execution failed: {:?}", error);
|
log::error!("Turn execution failed: {:?}", error);
|
||||||
event_stream.send_error(error);
|
event_stream.send_error(error);
|
||||||
} else {
|
} else {
|
||||||
log::info!("Turn execution completed successfully");
|
log::info!("Turn execution completed successfully");
|
||||||
}
|
}
|
||||||
}));
|
|
||||||
|
this.update(cx, |this, _| {
|
||||||
|
this.flush_pending_message();
|
||||||
|
this.running_turn.take();
|
||||||
|
})
|
||||||
|
.ok();
|
||||||
|
}),
|
||||||
|
});
|
||||||
events_rx
|
events_rx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1125,6 +1136,23 @@ impl Thread {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct RunningTurn {
|
||||||
|
/// Holds the task that handles agent interaction until the end of the turn.
|
||||||
|
/// Survives across multiple requests as the model performs tool calls and
|
||||||
|
/// we run tools, report their results.
|
||||||
|
_task: Task<()>,
|
||||||
|
/// The current event stream for the running turn. Used to report a final
|
||||||
|
/// cancellation event if we cancel the turn.
|
||||||
|
event_stream: AgentResponseEventStream,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RunningTurn {
|
||||||
|
fn cancel(self) {
|
||||||
|
log::debug!("Cancelling in progress turn");
|
||||||
|
self.event_stream.send_canceled();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait AgentTool
|
pub trait AgentTool
|
||||||
where
|
where
|
||||||
Self: 'static + Sized,
|
Self: 'static + Sized,
|
||||||
|
@ -1336,6 +1364,12 @@ impl AgentResponseEventStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn send_canceled(&self) {
|
||||||
|
self.0
|
||||||
|
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
fn send_error(&self, error: impl Into<anyhow::Error>) {
|
fn send_error(&self, error: impl Into<anyhow::Error>) {
|
||||||
self.0.unbounded_send(Err(error.into())).ok();
|
self.0.unbounded_send(Err(error.into())).ok();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue