agent2: Emit cancellation stop reason on cancel (#36381)
Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
parent
b3969ed427
commit
ea828c0c59
2 changed files with 191 additions and 93 deletions
|
@ -461,7 +461,7 @@ pub struct Thread {
|
|||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
running_turn: Option<RunningTurn>,
|
||||
pending_message: Option<AgentMessage>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
tool_use_limit_reached: bool,
|
||||
|
@ -554,8 +554,9 @@ impl Thread {
|
|||
}
|
||||
|
||||
pub fn cancel(&mut self) {
|
||||
// TODO: do we need to emit a stop::cancel for ACP?
|
||||
self.running_turn.take();
|
||||
if let Some(running_turn) = self.running_turn.take() {
|
||||
running_turn.cancel();
|
||||
}
|
||||
self.flush_pending_message();
|
||||
}
|
||||
|
||||
|
@ -616,108 +617,118 @@ impl Thread {
|
|||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
||||
self.cancel();
|
||||
|
||||
let model = self.model.clone();
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||
let event_stream = AgentResponseEventStream(events_tx);
|
||||
let message_ix = self.messages.len().saturating_sub(1);
|
||||
self.tool_use_limit_reached = false;
|
||||
self.running_turn = Some(cx.spawn(async move |this, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let turn_result: Result<()> = async {
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
log::debug!(
|
||||
"Building completion request with intent: {:?}",
|
||||
completion_intent
|
||||
);
|
||||
let request = this.update(cx, |this, cx| {
|
||||
this.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
self.running_turn = Some(RunningTurn {
|
||||
event_stream: event_stream.clone(),
|
||||
_task: cx.spawn(async move |this, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let turn_result: Result<()> = async {
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
log::debug!(
|
||||
"Building completion request with intent: {:?}",
|
||||
completion_intent
|
||||
);
|
||||
let request = this.update(cx, |this, cx| {
|
||||
this.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
|
||||
let mut tool_use_limit_reached = false;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
) => {
|
||||
tool_use_limit_reached = true;
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
return Ok(());
|
||||
let mut tool_use_limit_reached = false;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
) => {
|
||||
tool_use_limit_reached = true;
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
event => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
&event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
event => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
&event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
|
||||
let used_tools = tool_uses.is_empty();
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
raw_output: tool_result.output.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
if tool_use_limit_reached {
|
||||
log::info!("Tool use limit reached, completing turn");
|
||||
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
||||
return Err(language_model::ToolUseLimitReachedError.into());
|
||||
} else if used_tools {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
return Ok(());
|
||||
} else {
|
||||
this.update(cx, |this, _| this.flush_pending_message())?;
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
}
|
||||
|
||||
let used_tools = tool_uses.is_empty();
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
event_stream.update_tool_call_fields(
|
||||
&tool_result.tool_use_id,
|
||||
acp::ToolCallUpdateFields {
|
||||
status: Some(if tool_result.is_error {
|
||||
acp::ToolCallStatus::Failed
|
||||
} else {
|
||||
acp::ToolCallStatus::Completed
|
||||
}),
|
||||
raw_output: tool_result.output.clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
if tool_use_limit_reached {
|
||||
log::info!("Tool use limit reached, completing turn");
|
||||
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
||||
return Err(language_model::ToolUseLimitReachedError.into());
|
||||
} else if used_tools {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
return Ok(());
|
||||
} else {
|
||||
this.update(cx, |this, _| this.flush_pending_message())?;
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
}
|
||||
}
|
||||
.await;
|
||||
.await;
|
||||
|
||||
this.update(cx, |this, _| this.flush_pending_message()).ok();
|
||||
if let Err(error) = turn_result {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
event_stream.send_error(error);
|
||||
} else {
|
||||
log::info!("Turn execution completed successfully");
|
||||
}
|
||||
}));
|
||||
if let Err(error) = turn_result {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
event_stream.send_error(error);
|
||||
} else {
|
||||
log::info!("Turn execution completed successfully");
|
||||
}
|
||||
|
||||
this.update(cx, |this, _| {
|
||||
this.flush_pending_message();
|
||||
this.running_turn.take();
|
||||
})
|
||||
.ok();
|
||||
}),
|
||||
});
|
||||
events_rx
|
||||
}
|
||||
|
||||
|
@ -1125,6 +1136,23 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
struct RunningTurn {
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
_task: Task<()>,
|
||||
/// The current event stream for the running turn. Used to report a final
|
||||
/// cancellation event if we cancel the turn.
|
||||
event_stream: AgentResponseEventStream,
|
||||
}
|
||||
|
||||
impl RunningTurn {
|
||||
fn cancel(self) {
|
||||
log::debug!("Cancelling in progress turn");
|
||||
self.event_stream.send_canceled();
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
|
@ -1336,6 +1364,12 @@ impl AgentResponseEventStream {
|
|||
}
|
||||
}
|
||||
|
||||
fn send_canceled(&self) {
|
||||
self.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn send_error(&self, error: impl Into<anyhow::Error>) {
|
||||
self.0.unbounded_send(Err(error.into())).ok();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue