From ca0a20f3d53176459d2592a485365ab63ce2ecd7 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 21 Aug 2025 18:11:05 +0200 Subject: [PATCH] acp: Refactor agent2 `send` to have a clearer control flow (#36689) Release Notes: - N/A --- Cargo.lock | 1 + crates/agent2/Cargo.toml | 1 + crates/agent2/src/thread.rs | 295 ++++++++++++++++-------------------- 3 files changed, 134 insertions(+), 163 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 797e38fdac..203e9869e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -244,6 +244,7 @@ dependencies = [ "terminal", "text", "theme", + "thiserror 2.0.12", "tree-sitter-rust", "ui", "unindent", diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index 8dd79062f8..68246a96b0 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -61,6 +61,7 @@ sqlez.workspace = true task.workspace = true telemetry.workspace = true terminal.workspace = true +thiserror.workspace = true text.workspace = true ui.workspace = true util.workspace = true diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index d34c929152..6f560cd390 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -499,6 +499,16 @@ pub struct ToolCallAuthorization { pub response: oneshot::Sender, } +#[derive(Debug, thiserror::Error)] +enum CompletionError { + #[error("max tokens")] + MaxTokens, + #[error("refusal")] + Refusal, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + pub struct Thread { id: acp::SessionId, prompt_id: PromptId, @@ -1077,101 +1087,62 @@ impl Thread { _task: cx.spawn(async move |this, cx| { log::info!("Starting agent turn execution"); let mut update_title = None; - let turn_result: Result = async { - let mut completion_intent = CompletionIntent::UserPrompt; + let turn_result: Result<()> = async { + let mut intent = CompletionIntent::UserPrompt; loop { - log::debug!( - "Building completion request with intent: {:?}", - completion_intent - ); - let request = this.update(cx, |this, cx| { - this.build_completion_request(completion_intent, cx) - })??; - - log::info!("Calling model.stream_completion"); - - let mut tool_use_limit_reached = false; - let mut refused = false; - let mut reached_max_tokens = false; - let mut tool_uses = Self::stream_completion_with_retries( - this.clone(), - model.clone(), - request, - &event_stream, - &mut tool_use_limit_reached, - &mut refused, - &mut reached_max_tokens, - cx, - ) - .await?; - - if refused { - return Ok(StopReason::Refusal); - } else if reached_max_tokens { - return Ok(StopReason::MaxTokens); - } - - let end_turn = tool_uses.is_empty(); - while let Some(tool_result) = tool_uses.next().await { - log::info!("Tool finished {:?}", tool_result); - - event_stream.update_tool_call_fields( - &tool_result.tool_use_id, - acp::ToolCallUpdateFields { - status: Some(if tool_result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - }), - raw_output: tool_result.output.clone(), - ..Default::default() - }, - ); - this.update(cx, |this, _cx| { - this.pending_message() - .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); - })?; - } + Self::stream_completion(&this, &model, intent, &event_stream, cx).await?; + let mut end_turn = true; this.update(cx, |this, cx| { + // Generate title if needed. if this.title.is_none() && update_title.is_none() { update_title = Some(this.update_title(&event_stream, cx)); } + + // End the turn if the model didn't use tools. + let message = this.pending_message.as_ref(); + end_turn = + message.map_or(true, |message| message.tool_results.is_empty()); + this.flush_pending_message(cx); })?; - if tool_use_limit_reached { + if this.read_with(cx, |this, _| this.tool_use_limit_reached)? { log::info!("Tool use limit reached, completing turn"); - this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?; return Err(language_model::ToolUseLimitReachedError.into()); } else if end_turn { log::info!("No tool uses found, completing turn"); - return Ok(StopReason::EndTurn); + return Ok(()); } else { - this.update(cx, |this, cx| this.flush_pending_message(cx))?; - completion_intent = CompletionIntent::ToolResults; + intent = CompletionIntent::ToolResults; } } } .await; _ = this.update(cx, |this, cx| this.flush_pending_message(cx)); + if let Some(update_title) = update_title { + update_title.await.context("update title failed").log_err(); + } + match turn_result { - Ok(reason) => { - log::info!("Turn execution completed: {:?}", reason); - - if let Some(update_title) = update_title { - update_title.await.context("update title failed").log_err(); - } - - event_stream.send_stop(reason); - if reason == StopReason::Refusal { - _ = this.update(cx, |this, _| this.messages.truncate(message_ix)); - } + Ok(()) => { + log::info!("Turn execution completed"); + event_stream.send_stop(acp::StopReason::EndTurn); } Err(error) => { log::error!("Turn execution failed: {:?}", error); - event_stream.send_error(error); + match error.downcast::() { + Ok(CompletionError::Refusal) => { + event_stream.send_stop(acp::StopReason::Refusal); + _ = this.update(cx, |this, _| this.messages.truncate(message_ix)); + } + Ok(CompletionError::MaxTokens) => { + event_stream.send_stop(acp::StopReason::MaxTokens); + } + Ok(CompletionError::Other(error)) | Err(error) => { + event_stream.send_error(error); + } + } } } @@ -1181,17 +1152,17 @@ impl Thread { Ok(events_rx) } - async fn stream_completion_with_retries( - this: WeakEntity, - model: Arc, - request: LanguageModelRequest, + async fn stream_completion( + this: &WeakEntity, + model: &Arc, + completion_intent: CompletionIntent, event_stream: &ThreadEventStream, - tool_use_limit_reached: &mut bool, - refusal: &mut bool, - max_tokens_reached: &mut bool, cx: &mut AsyncApp, - ) -> Result>> { + ) -> Result<()> { log::debug!("Stream completion started successfully"); + let request = this.update(cx, |this, cx| { + this.build_completion_request(completion_intent, cx) + })??; let mut attempt = None; 'retry: loop { @@ -1204,68 +1175,33 @@ impl Thread { attempt ); - let mut events = model.stream_completion(request.clone(), cx).await?; - let mut tool_uses = FuturesUnordered::new(); + log::info!( + "Calling model.stream_completion, attempt {}", + attempt.unwrap_or(0) + ); + let mut events = model + .stream_completion(request.clone(), cx) + .await + .map_err(|error| anyhow!(error))?; + let mut tool_results = FuturesUnordered::new(); + while let Some(event) = events.next().await { match event { - Ok(LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::ToolUseLimitReached, - )) => { - *tool_use_limit_reached = true; - } - Ok(LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::UsageUpdated { amount, limit }, - )) => { - this.update(cx, |this, cx| { - this.update_model_request_usage(amount, limit, cx) - })?; - } - Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => { - telemetry::event!( - "Agent Thread Completion Usage Updated", - thread_id = this.read_with(cx, |this, _| this.id.to_string())?, - prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?, - model = model.telemetry_id(), - model_provider = model.provider_id().to_string(), - attempt, - input_tokens = usage.input_tokens, - output_tokens = usage.output_tokens, - cache_creation_input_tokens = usage.cache_creation_input_tokens, - cache_read_input_tokens = usage.cache_read_input_tokens, - ); - - this.update(cx, |this, cx| this.update_token_usage(usage, cx))?; - } - Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => { - *refusal = true; - return Ok(FuturesUnordered::default()); - } - Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => { - *max_tokens_reached = true; - return Ok(FuturesUnordered::default()); - } - Ok(LanguageModelCompletionEvent::Stop( - StopReason::ToolUse | StopReason::EndTurn, - )) => break, Ok(event) => { log::trace!("Received completion event: {:?}", event); - this.update(cx, |this, cx| { - tool_uses.extend(this.handle_streamed_completion_event( - event, - event_stream, - cx, - )); - })?; + tool_results.extend(this.update(cx, |this, cx| { + this.handle_streamed_completion_event(event, event_stream, cx) + })??); } Err(error) => { let completion_mode = this.read_with(cx, |thread, _cx| thread.completion_mode())?; if completion_mode == CompletionMode::Normal { - return Err(error.into()); + return Err(anyhow!(error))?; } let Some(strategy) = Self::retry_strategy_for(&error) else { - return Err(error.into()); + return Err(anyhow!(error))?; }; let max_attempts = match &strategy { @@ -1279,7 +1215,7 @@ impl Thread { let attempt = *attempt; if attempt > max_attempts { - return Err(error.into()); + return Err(anyhow!(error))?; } let delay = match &strategy { @@ -1306,7 +1242,29 @@ impl Thread { } } - return Ok(tool_uses); + while let Some(tool_result) = tool_results.next().await { + log::info!("Tool finished {:?}", tool_result); + + event_stream.update_tool_call_fields( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields { + status: Some(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }), + raw_output: tool_result.output.clone(), + ..Default::default() + }, + ); + this.update(cx, |this, _cx| { + this.pending_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); + })?; + } + + return Ok(()); } } @@ -1328,14 +1286,14 @@ impl Thread { } /// A helper method that's called on every streamed completion event. - /// Returns an optional tool result task, which the main agentic loop in - /// send will send back to the model when it resolves. + /// Returns an optional tool result task, which the main agentic loop will + /// send back to the model when it resolves. fn handle_streamed_completion_event( &mut self, event: LanguageModelCompletionEvent, event_stream: &ThreadEventStream, cx: &mut Context, - ) -> Option> { + ) -> Result>> { log::trace!("Handling streamed completion event: {:?}", event); use LanguageModelCompletionEvent::*; @@ -1350,7 +1308,7 @@ impl Thread { } RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx), ToolUse(tool_use) => { - return self.handle_tool_use_event(tool_use, event_stream, cx); + return Ok(self.handle_tool_use_event(tool_use, event_stream, cx)); } ToolUseJsonParseError { id, @@ -1358,18 +1316,46 @@ impl Thread { raw_input, json_parse_error, } => { - return Some(Task::ready(self.handle_tool_use_json_parse_error_event( - id, - tool_name, - raw_input, - json_parse_error, + return Ok(Some(Task::ready( + self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + ), ))); } - StatusUpdate(_) => {} - UsageUpdate(_) | Stop(_) => unreachable!(), + UsageUpdate(usage) => { + telemetry::event!( + "Agent Thread Completion Usage Updated", + thread_id = self.id.to_string(), + prompt_id = self.prompt_id.to_string(), + model = self.model.as_ref().map(|m| m.telemetry_id()), + model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()), + input_tokens = usage.input_tokens, + output_tokens = usage.output_tokens, + cache_creation_input_tokens = usage.cache_creation_input_tokens, + cache_read_input_tokens = usage.cache_read_input_tokens, + ); + self.update_token_usage(usage, cx); + } + StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => { + self.update_model_request_usage(amount, limit, cx); + } + StatusUpdate( + CompletionRequestStatus::Started + | CompletionRequestStatus::Queued { .. } + | CompletionRequestStatus::Failed { .. }, + ) => {} + StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => { + self.tool_use_limit_reached = true; + } + Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()), + Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()), + Stop(StopReason::ToolUse | StopReason::EndTurn) => {} } - None + Ok(None) } fn handle_text_event( @@ -2225,25 +2211,8 @@ impl ThreadEventStream { self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok(); } - fn send_stop(&self, reason: StopReason) { - match reason { - StopReason::EndTurn => { - self.0 - .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn))) - .ok(); - } - StopReason::MaxTokens => { - self.0 - .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens))) - .ok(); - } - StopReason::Refusal => { - self.0 - .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal))) - .ok(); - } - StopReason::ToolUse => {} - } + fn send_stop(&self, reason: acp::StopReason) { + self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok(); } fn send_canceled(&self) {