diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index cc8bd483bb..48a16bf685 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -941,7 +941,15 @@ async fn test_cancellation(cx: &mut TestAppContext) { // Cancel the current send and ensure that the event stream is closed, even // if one of the tools is still running. thread.update(cx, |thread, _cx| thread.cancel()); - events.collect::>().await; + let events = events.collect::>().await; + let last_event = events.last(); + assert!( + matches!( + last_event, + Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + ), + "unexpected event {last_event:?}" + ); // Ensure we can still send a new message after cancellation. let events = thread @@ -965,6 +973,62 @@ async fn test_cancellation(cx: &mut TestAppContext) { assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]); } +#[gpui::test] +async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events_1 = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello 1"], cx) + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey 1!"); + cx.run_until_parked(); + + let events_2 = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello 2"], cx) + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey 2!"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + + let events_1 = events_1.collect::>().await; + assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]); + let events_2 = events_2.collect::>().await; + assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]); +} + +#[gpui::test] +async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) { + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + + let events_1 = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello 1"], cx) + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey 1!"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + let events_1 = events_1.collect::>().await; + + let events_2 = thread.update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello 2"], cx) + }); + cx.run_until_parked(); + fake_model.send_last_completion_stream_text_chunk("Hey 2!"); + fake_model + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + fake_model.end_last_completion_stream(); + let events_2 = events_2.collect::>().await; + + assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]); + assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]); +} + #[gpui::test] async fn test_refusal(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 0741bb9e08..d8b6286f60 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -461,7 +461,7 @@ pub struct Thread { /// Holds the task that handles agent interaction until the end of the turn. /// Survives across multiple requests as the model performs tool calls and /// we run tools, report their results. - running_turn: Option>, + running_turn: Option, pending_message: Option, tools: BTreeMap>, tool_use_limit_reached: bool, @@ -554,8 +554,9 @@ impl Thread { } pub fn cancel(&mut self) { - // TODO: do we need to emit a stop::cancel for ACP? - self.running_turn.take(); + if let Some(running_turn) = self.running_turn.take() { + running_turn.cancel(); + } self.flush_pending_message(); } @@ -616,108 +617,118 @@ impl Thread { &mut self, cx: &mut Context, ) -> mpsc::UnboundedReceiver> { + self.cancel(); + let model = self.model.clone(); let (events_tx, events_rx) = mpsc::unbounded::>(); let event_stream = AgentResponseEventStream(events_tx); let message_ix = self.messages.len().saturating_sub(1); self.tool_use_limit_reached = false; - self.running_turn = Some(cx.spawn(async move |this, cx| { - log::info!("Starting agent turn execution"); - let turn_result: Result<()> = async { - let mut completion_intent = CompletionIntent::UserPrompt; - loop { - log::debug!( - "Building completion request with intent: {:?}", - completion_intent - ); - let request = this.update(cx, |this, cx| { - this.build_completion_request(completion_intent, cx) - })?; + self.running_turn = Some(RunningTurn { + event_stream: event_stream.clone(), + _task: cx.spawn(async move |this, cx| { + log::info!("Starting agent turn execution"); + let turn_result: Result<()> = async { + let mut completion_intent = CompletionIntent::UserPrompt; + loop { + log::debug!( + "Building completion request with intent: {:?}", + completion_intent + ); + let request = this.update(cx, |this, cx| { + this.build_completion_request(completion_intent, cx) + })?; - log::info!("Calling model.stream_completion"); - let mut events = model.stream_completion(request, cx).await?; - log::debug!("Stream completion started successfully"); + log::info!("Calling model.stream_completion"); + let mut events = model.stream_completion(request, cx).await?; + log::debug!("Stream completion started successfully"); - let mut tool_use_limit_reached = false; - let mut tool_uses = FuturesUnordered::new(); - while let Some(event) = events.next().await { - match event? { - LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::ToolUseLimitReached, - ) => { - tool_use_limit_reached = true; - } - LanguageModelCompletionEvent::Stop(reason) => { - event_stream.send_stop(reason); - if reason == StopReason::Refusal { - this.update(cx, |this, _cx| { - this.flush_pending_message(); - this.messages.truncate(message_ix); - })?; - return Ok(()); + let mut tool_use_limit_reached = false; + let mut tool_uses = FuturesUnordered::new(); + while let Some(event) = events.next().await { + match event? { + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::ToolUseLimitReached, + ) => { + tool_use_limit_reached = true; + } + LanguageModelCompletionEvent::Stop(reason) => { + event_stream.send_stop(reason); + if reason == StopReason::Refusal { + this.update(cx, |this, _cx| { + this.flush_pending_message(); + this.messages.truncate(message_ix); + })?; + return Ok(()); + } + } + event => { + log::trace!("Received completion event: {:?}", event); + this.update(cx, |this, cx| { + tool_uses.extend(this.handle_streamed_completion_event( + event, + &event_stream, + cx, + )); + }) + .ok(); } } - event => { - log::trace!("Received completion event: {:?}", event); - this.update(cx, |this, cx| { - tool_uses.extend(this.handle_streamed_completion_event( - event, - &event_stream, - cx, - )); - }) - .ok(); - } + } + + let used_tools = tool_uses.is_empty(); + while let Some(tool_result) = tool_uses.next().await { + log::info!("Tool finished {:?}", tool_result); + + event_stream.update_tool_call_fields( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields { + status: Some(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }), + raw_output: tool_result.output.clone(), + ..Default::default() + }, + ); + this.update(cx, |this, _cx| { + this.pending_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); + }) + .ok(); + } + + if tool_use_limit_reached { + log::info!("Tool use limit reached, completing turn"); + this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?; + return Err(language_model::ToolUseLimitReachedError.into()); + } else if used_tools { + log::info!("No tool uses found, completing turn"); + return Ok(()); + } else { + this.update(cx, |this, _| this.flush_pending_message())?; + completion_intent = CompletionIntent::ToolResults; } } - - let used_tools = tool_uses.is_empty(); - while let Some(tool_result) = tool_uses.next().await { - log::info!("Tool finished {:?}", tool_result); - - event_stream.update_tool_call_fields( - &tool_result.tool_use_id, - acp::ToolCallUpdateFields { - status: Some(if tool_result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - }), - raw_output: tool_result.output.clone(), - ..Default::default() - }, - ); - this.update(cx, |this, _cx| { - this.pending_message() - .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); - }) - .ok(); - } - - if tool_use_limit_reached { - log::info!("Tool use limit reached, completing turn"); - this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?; - return Err(language_model::ToolUseLimitReachedError.into()); - } else if used_tools { - log::info!("No tool uses found, completing turn"); - return Ok(()); - } else { - this.update(cx, |this, _| this.flush_pending_message())?; - completion_intent = CompletionIntent::ToolResults; - } } - } - .await; + .await; - this.update(cx, |this, _| this.flush_pending_message()).ok(); - if let Err(error) = turn_result { - log::error!("Turn execution failed: {:?}", error); - event_stream.send_error(error); - } else { - log::info!("Turn execution completed successfully"); - } - })); + if let Err(error) = turn_result { + log::error!("Turn execution failed: {:?}", error); + event_stream.send_error(error); + } else { + log::info!("Turn execution completed successfully"); + } + + this.update(cx, |this, _| { + this.flush_pending_message(); + this.running_turn.take(); + }) + .ok(); + }), + }); events_rx } @@ -1125,6 +1136,23 @@ impl Thread { } } +struct RunningTurn { + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + _task: Task<()>, + /// The current event stream for the running turn. Used to report a final + /// cancellation event if we cancel the turn. + event_stream: AgentResponseEventStream, +} + +impl RunningTurn { + fn cancel(self) { + log::debug!("Cancelling in progress turn"); + self.event_stream.send_canceled(); + } +} + pub trait AgentTool where Self: 'static + Sized, @@ -1336,6 +1364,12 @@ impl AgentResponseEventStream { } } + fn send_canceled(&self) { + self.0 + .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled))) + .ok(); + } + fn send_error(&self, error: impl Into) { self.0.unbounded_send(Err(error.into())).ok(); }