agent2: Port retry logic (#36421)

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-08-19 11:41:55 +02:00 committed by GitHub
parent 47e1d4511c
commit 0ea0d466d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 514 additions and 52 deletions

View file

@ -12,12 +12,12 @@ use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
};
use gpui::{App, Context, Entity, SharedString, Task};
use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
@ -25,7 +25,12 @@ use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
use std::{collections::BTreeMap, path::Path, sync::Arc};
use std::{
collections::BTreeMap,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
use std::{fmt::Write, ops::Range};
use util::{ResultExt, markdown::MarkdownCodeBlock};
use uuid::Uuid;
@ -71,6 +76,21 @@ impl std::fmt::Display for PromptId {
}
}
/// Retry budget that `retry_strategy_for` assigns to rate-limit and overload errors.
pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4;
/// Delay used between retries when the error carries no `retry_after` hint; also the
/// initial delay of the exponential-backoff strategy.
pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
/// How a failed completion request should be retried.
///
/// In both variants `max_attempts` is the number of retries allowed; once the retry
/// counter exceeds it, the error is returned to the caller instead.
#[derive(Debug, Clone)]
enum RetryStrategy {
/// Wait `initial_delay` before the first retry and double the wait on each
/// subsequent retry.
ExponentialBackoff {
initial_delay: Duration,
max_attempts: u8,
},
/// Wait the same `delay` before every retry.
Fixed {
delay: Duration,
max_attempts: u8,
},
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
User(UserMessage),
@ -455,6 +475,7 @@ pub enum AgentResponseEvent {
ToolCall(acp::ToolCall),
ToolCallUpdate(acp_thread::ToolCallUpdate),
ToolCallAuthorization(ToolCallAuthorization),
Retry(acp_thread::RetryStatus),
Stop(acp::StopReason),
}
@ -662,41 +683,18 @@ impl Thread {
})??;
log::info!("Calling model.stream_completion");
let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully");
let mut tool_use_limit_reached = false;
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event? {
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached,
) => {
tool_use_limit_reached = true;
}
LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
this.update(cx, |this, _cx| {
this.flush_pending_message();
this.messages.truncate(message_ix);
})?;
return Ok(());
}
}
event => {
log::trace!("Received completion event: {:?}", event);
this.update(cx, |this, cx| {
tool_uses.extend(this.handle_streamed_completion_event(
event,
&event_stream,
cx,
));
})
.ok();
}
}
}
let mut tool_uses = Self::stream_completion_with_retries(
this.clone(),
model.clone(),
request,
message_ix,
&event_stream,
&mut tool_use_limit_reached,
cx,
)
.await?;
let used_tools = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
@ -754,10 +752,105 @@ impl Thread {
Ok(events_rx)
}
async fn stream_completion_with_retries(
this: WeakEntity<Self>,
model: Arc<dyn LanguageModel>,
request: LanguageModelRequest,
message_ix: usize,
event_stream: &AgentResponseEventStream,
tool_use_limit_reached: &mut bool,
cx: &mut AsyncApp,
) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
log::debug!("Stream completion started successfully");
let mut attempt = None;
'retry: loop {
let mut events = model.stream_completion(request.clone(), cx).await?;
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event {
Ok(LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::ToolUseLimitReached,
)) => {
*tool_use_limit_reached = true;
}
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
this.update(cx, |this, _cx| {
this.flush_pending_message();
this.messages.truncate(message_ix);
})?;
return Ok(tool_uses);
}
}
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
this.update(cx, |this, cx| {
tool_uses.extend(this.handle_streamed_completion_event(
event,
event_stream,
cx,
));
})
.ok();
}
Err(error) => {
let completion_mode =
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
if completion_mode == CompletionMode::Normal {
return Err(error.into());
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
return Err(error.into());
};
let max_attempts = match &strategy {
RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts,
RetryStrategy::Fixed { max_attempts, .. } => *max_attempts,
};
let attempt = attempt.get_or_insert(0u8);
*attempt += 1;
let attempt = *attempt;
if attempt > max_attempts {
return Err(error.into());
}
let delay = match &strategy {
RetryStrategy::ExponentialBackoff { initial_delay, .. } => {
let delay_secs =
initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32);
Duration::from_secs(delay_secs)
}
RetryStrategy::Fixed { delay, .. } => *delay,
};
log::debug!("Retry attempt {attempt} with delay {delay:?}");
event_stream.send_retry(acp_thread::RetryStatus {
last_error: error.to_string().into(),
attempt: attempt as usize,
max_attempts: max_attempts as usize,
started_at: Instant::now(),
duration: delay,
});
cx.background_executor().timer(delay).await;
continue 'retry;
}
}
}
return Ok(tool_uses);
}
}
pub fn build_system_message(&self, cx: &App) -> LanguageModelRequestMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
project: &self.project_context.read(cx),
project: self.project_context.read(cx),
available_tools: self.tools.keys().cloned().collect(),
}
.render(&self.templates)
@ -1158,6 +1251,113 @@ impl Thread {
/// Replaces `self.prompt_id` with a new id obtained from `PromptId::new()`.
fn advance_prompt_id(&mut self) {
    let next_id = PromptId::new();
    self.prompt_id = next_id;
}
/// Decides whether, and how, a failed completion should be retried.
///
/// Returns `None` when retrying cannot help (bad credentials, missing API key,
/// oversized prompt or payload, billing or request-limit errors). Otherwise returns
/// the `RetryStrategy` to apply:
/// - HTTP 429 responses use exponential backoff starting at `BASE_RETRY_DELAY`;
/// - overload and rate-limit errors wait for the server-provided `retry_after`
///   (falling back to `BASE_RETRY_DELAY`), up to `MAX_RETRY_ATTEMPTS` times;
/// - everything else waits a fixed delay for between 1 and 3 retries, as noted on
///   each arm below.
fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option<RetryStrategy> {
    use LanguageModelCompletionError::*;
    use http_client::StatusCode;

    match error {
        // Retrying cannot fix credential, permission, routing or size problems.
        HttpResponseError {
            status_code:
                StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED,
            ..
        }
        | AuthenticationError { .. }
        | PermissionError { .. }
        | NoApiKey { .. }
        | ApiEndpointNotFound { .. }
        | PromptTooLarge { .. } => None,

        // Payment Required and Model Request Limit errors need the user to upgrade to
        // usage-based billing or to wait a significant amount of time for the limit to
        // reset, so retrying is pointless.
        Other(source)
            if source.is::<language_model::PaymentRequiredError>()
                || source.is::<language_model::ModelRequestLimitReachedError>() =>
        {
            None
        }

        // Plain HTTP 429: back off exponentially.
        HttpResponseError {
            status_code: StatusCode::TOO_MANY_REQUESTS,
            ..
        } => Some(RetryStrategy::ExponentialBackoff {
            initial_delay: BASE_RETRY_DELAY,
            max_attempts: MAX_RETRY_ATTEMPTS,
        }),

        // The server may tell us how long to wait; honor that when present.
        ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => {
            Some(RetryStrategy::Fixed {
                delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                max_attempts: MAX_RETRY_ATTEMPTS,
            })
        }

        // Upstream provider failures always wait `retry_after` (or the base delay);
        // only the retry budget depends on the status code.
        UpstreamProviderError {
            status,
            retry_after,
            ..
        } => {
            let max_attempts = match *status {
                StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => {
                    MAX_RETRY_ATTEMPTS
                }
                // Internal Server Error could be anything, retry up to 3 times.
                StatusCode::INTERNAL_SERVER_ERROR => 3,
                // There is no StatusCode variant for the unofficial HTTP 529 ("The
                // service is overloaded"), but we frequently get them in practice.
                // See https://http.dev/529
                other if other.as_u16() == 529 => MAX_RETRY_ATTEMPTS,
                // Any other upstream status: retry up to 2 times.
                _ => 2,
            };
            Some(RetryStrategy::Fixed {
                delay: retry_after.unwrap_or(BASE_RETRY_DELAY),
                max_attempts,
            })
        }

        // Server-side and transport-level failures: retry up to 3 times.
        ApiInternalServerError { .. }
        | ApiReadResponseError { .. }
        | HttpSend { .. }
        | DeserializeResponse { .. }
        | BadRequestFormat { .. } => Some(RetryStrategy::Fixed {
            delay: BASE_RETRY_DELAY,
            max_attempts: 3,
        }),

        // Failures while building the request might be transient: retry a single time.
        SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed {
            delay: BASE_RETRY_DELAY,
            max_attempts: 1,
        }),

        // Any remaining 4xx or 5xx HTTP response: retry up to 3 times.
        HttpResponseError { status_code, .. }
            if status_code.is_client_error() || status_code.is_server_error() =>
        {
            Some(RetryStrategy::Fixed {
                delay: BASE_RETRY_DELAY,
                max_attempts: 3,
            })
        }

        // Everything else is treated as possibly transient: retry up to 2 times.
        HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
            delay: BASE_RETRY_DELAY,
            max_attempts: 2,
        }),
    }
}
}
struct RunningTurn {
@ -1367,6 +1567,12 @@ impl AgentResponseEventStream {
.ok();
}
/// Sends an `AgentResponseEvent::Retry` carrying `status` on this event stream.
fn send_retry(&self, status: acp_thread::RetryStatus) {
    let event = AgentResponseEvent::Retry(status);
    // Best effort: a failed send is deliberately ignored.
    self.0.unbounded_send(Ok(event)).ok();
}
fn send_stop(&self, reason: StopReason) {
match reason {
StopReason::EndTurn => {