acp: Refactor agent2 send to have a clearer control flow (#36689)

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2025-08-21 18:11:05 +02:00 committed by GitHub
parent 132daef9f6
commit 190217a43b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 134 additions and 163 deletions

1
Cargo.lock generated
View file

@ -244,6 +244,7 @@ dependencies = [
"terminal",
"text",
"theme",
"thiserror 2.0.12",
"tree-sitter-rust",
"ui",
"unindent",

View file

@ -61,6 +61,7 @@ sqlez.workspace = true
task.workspace = true
telemetry.workspace = true
terminal.workspace = true
thiserror.workspace = true
text.workspace = true
ui.workspace = true
util.workspace = true

View file

@ -499,6 +499,16 @@ pub struct ToolCallAuthorization {
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
#[derive(Debug, thiserror::Error)]
enum CompletionError {
#[error("max tokens")]
MaxTokens,
#[error("refusal")]
Refusal,
#[error(transparent)]
Other(#[from] anyhow::Error),
}
pub struct Thread {
id: acp::SessionId,
prompt_id: PromptId,
@ -1077,101 +1087,62 @@ impl Thread {
_task: cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
let mut update_title = None;
let turn_result: Result<StopReason> = async {
let mut completion_intent = CompletionIntent::UserPrompt;
let turn_result: Result<()> = async {
let mut intent = CompletionIntent::UserPrompt;
loop {
log::debug!(
"Building completion request with intent: {:?}",
completion_intent
);
let request = this.update(cx, |this, cx| {
this.build_completion_request(completion_intent, cx)
})??;
log::info!("Calling model.stream_completion");
let mut tool_use_limit_reached = false;
let mut refused = false;
let mut reached_max_tokens = false;
let mut tool_uses = Self::stream_completion_with_retries(
this.clone(),
model.clone(),
request,
&event_stream,
&mut tool_use_limit_reached,
&mut refused,
&mut reached_max_tokens,
cx,
)
.await?;
if refused {
return Ok(StopReason::Refusal);
} else if reached_max_tokens {
return Ok(StopReason::MaxTokens);
}
let end_turn = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
event_stream.update_tool_call_fields(
&tool_result.tool_use_id,
acp::ToolCallUpdateFields {
status: Some(if tool_result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
}),
raw_output: tool_result.output.clone(),
..Default::default()
},
);
this.update(cx, |this, _cx| {
this.pending_message()
.tool_results
.insert(tool_result.tool_use_id.clone(), tool_result);
})?;
}
Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
let mut end_turn = true;
this.update(cx, |this, cx| {
// Generate title if needed.
if this.title.is_none() && update_title.is_none() {
update_title = Some(this.update_title(&event_stream, cx));
}
// End the turn if the model didn't use tools.
let message = this.pending_message.as_ref();
end_turn =
message.map_or(true, |message| message.tool_results.is_empty());
this.flush_pending_message(cx);
})?;
if tool_use_limit_reached {
if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
log::info!("Tool use limit reached, completing turn");
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
return Err(language_model::ToolUseLimitReachedError.into());
} else if end_turn {
log::info!("No tool uses found, completing turn");
return Ok(StopReason::EndTurn);
return Ok(());
} else {
this.update(cx, |this, cx| this.flush_pending_message(cx))?;
completion_intent = CompletionIntent::ToolResults;
intent = CompletionIntent::ToolResults;
}
}
}
.await;
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
if let Some(update_title) = update_title {
update_title.await.context("update title failed").log_err();
}
match turn_result {
Ok(reason) => {
log::info!("Turn execution completed: {:?}", reason);
if let Some(update_title) = update_title {
update_title.await.context("update title failed").log_err();
}
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
_ = this.update(cx, |this, _| this.messages.truncate(message_ix));
}
Ok(()) => {
log::info!("Turn execution completed");
event_stream.send_stop(acp::StopReason::EndTurn);
}
Err(error) => {
log::error!("Turn execution failed: {:?}", error);
event_stream.send_error(error);
match error.downcast::<CompletionError>() {
Ok(CompletionError::Refusal) => {
event_stream.send_stop(acp::StopReason::Refusal);
_ = this.update(cx, |this, _| this.messages.truncate(message_ix));
}
Ok(CompletionError::MaxTokens) => {
event_stream.send_stop(acp::StopReason::MaxTokens);
}
Ok(CompletionError::Other(error)) | Err(error) => {
event_stream.send_error(error);
}
}
}
}
@ -1181,17 +1152,17 @@ impl Thread {
Ok(events_rx)
}
async fn stream_completion_with_retries(
this: WeakEntity<Self>,
model: Arc<dyn LanguageModel>,
request: LanguageModelRequest,
async fn stream_completion(
this: &WeakEntity<Self>,
model: &Arc<dyn LanguageModel>,
completion_intent: CompletionIntent,
event_stream: &ThreadEventStream,
tool_use_limit_reached: &mut bool,
refusal: &mut bool,
max_tokens_reached: &mut bool,
cx: &mut AsyncApp,
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
) -> Result<()> {
log::debug!("Stream completion started successfully");
let request = this.update(cx, |this, cx| {
this.build_completion_request(completion_intent, cx)
})??;
let mut attempt = None;
'retry: loop {
@ -1204,68 +1175,33 @@ impl Thread {
attempt
);
let mut events = model.stream_completion(request.clone(), cx).await?;
let mut tool_uses = FuturesUnordered::new();
log::info!(
"Calling model.stream_completion, attempt {}",
attempt.unwrap_or(0)
);
let mut events = model
.stream_completion(request.clone(), cx)
.await
.map_err(|error| anyhow!(error))?;
let mut tool_results = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event {
Ok(LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached,
)) => {
*tool_use_limit_reached = true;
}
Ok(LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit },
)) => {
this.update(cx, |this, cx| {
this.update_model_request_usage(amount, limit, cx)
})?;
}
Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => {
telemetry::event!(
"Agent Thread Completion Usage Updated",
thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
model = model.telemetry_id(),
model_provider = model.provider_id().to_string(),
attempt,
input_tokens = usage.input_tokens,
output_tokens = usage.output_tokens,
cache_creation_input_tokens = usage.cache_creation_input_tokens,
cache_read_input_tokens = usage.cache_read_input_tokens,
);
this.update(cx, |this, cx| this.update_token_usage(usage, cx))?;
}
Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
*refusal = true;
return Ok(FuturesUnordered::default());
}
Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
*max_tokens_reached = true;
return Ok(FuturesUnordered::default());
}
Ok(LanguageModelCompletionEvent::Stop(
StopReason::ToolUse | StopReason::EndTurn,
)) => break,
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
this.update(cx, |this, cx| {
tool_uses.extend(this.handle_streamed_completion_event(
event,
event_stream,
cx,
));
})?;
tool_results.extend(this.update(cx, |this, cx| {
this.handle_streamed_completion_event(event, event_stream, cx)
})??);
}
Err(error) => {
let completion_mode =
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
if completion_mode == CompletionMode::Normal {
return Err(error.into());
return Err(anyhow!(error))?;
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(error.into());
return Err(anyhow!(error))?;
};
let max_attempts = match &strategy {
@ -1279,7 +1215,7 @@ impl Thread {
let attempt = *attempt;
if attempt > max_attempts {
return Err(error.into());
return Err(anyhow!(error))?;
}
let delay = match &strategy {
@ -1306,7 +1242,29 @@ impl Thread {
}
}
return Ok(tool_uses);
while let Some(tool_result) = tool_results.next().await {
log::info!("Tool finished {:?}", tool_result);
event_stream.update_tool_call_fields(
&tool_result.tool_use_id,
acp::ToolCallUpdateFields {
status: Some(if tool_result.is_error {
acp::ToolCallStatus::Failed
} else {
acp::ToolCallStatus::Completed
}),
raw_output: tool_result.output.clone(),
..Default::default()
},
);
this.update(cx, |this, _cx| {
this.pending_message()
.tool_results
.insert(tool_result.tool_use_id.clone(), tool_result);
})?;
}
return Ok(());
}
}
@ -1328,14 +1286,14 @@ impl Thread {
}
/// A helper method that's called on every streamed completion event.
/// Returns an optional tool result task, which the main agentic loop in
/// send will send back to the model when it resolves.
/// Returns an optional tool result task, which the main agentic loop will
/// send back to the model when it resolves.
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
) -> Result<Option<Task<LanguageModelToolResult>>> {
log::trace!("Handling streamed completion event: {:?}", event);
use LanguageModelCompletionEvent::*;
@ -1350,7 +1308,7 @@ impl Thread {
}
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
ToolUse(tool_use) => {
return self.handle_tool_use_event(tool_use, event_stream, cx);
return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
}
ToolUseJsonParseError {
id,
@ -1358,18 +1316,46 @@ impl Thread {
raw_input,
json_parse_error,
} => {
return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
id,
tool_name,
raw_input,
json_parse_error,
return Ok(Some(Task::ready(
self.handle_tool_use_json_parse_error_event(
id,
tool_name,
raw_input,
json_parse_error,
),
)));
}
StatusUpdate(_) => {}
UsageUpdate(_) | Stop(_) => unreachable!(),
UsageUpdate(usage) => {
telemetry::event!(
"Agent Thread Completion Usage Updated",
thread_id = self.id.to_string(),
prompt_id = self.prompt_id.to_string(),
model = self.model.as_ref().map(|m| m.telemetry_id()),
model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
input_tokens = usage.input_tokens,
output_tokens = usage.output_tokens,
cache_creation_input_tokens = usage.cache_creation_input_tokens,
cache_read_input_tokens = usage.cache_read_input_tokens,
);
self.update_token_usage(usage, cx);
}
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
self.update_model_request_usage(amount, limit, cx);
}
StatusUpdate(
CompletionRequestStatus::Started
| CompletionRequestStatus::Queued { .. }
| CompletionRequestStatus::Failed { .. },
) => {}
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
self.tool_use_limit_reached = true;
}
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
}
None
Ok(None)
}
fn handle_text_event(
@ -2225,25 +2211,8 @@ impl ThreadEventStream {
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {
self.0
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
.ok();
}
StopReason::MaxTokens => {
self.0
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
.ok();
}
StopReason::Refusal => {
self.0
.unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
.ok();
}
StopReason::ToolUse => {}
}
fn send_stop(&self, reason: acp::StopReason) {
self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
}
fn send_canceled(&self) {