Retry on 5xx errors from cloud language model providers (#27584)

Release Notes:

- N/A
This commit is contained in:
Richard Feldman 2025-03-27 09:35:16 -04:00 committed by GitHub
parent e6c473a488
commit 76d3a9a0f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -27,9 +27,11 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::io::{AsyncReadExt, BufReader};
use smol::Timer;
use std::{
future,
sync::{Arc, LazyLock},
time::Duration,
};
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};
@@ -456,6 +458,8 @@ pub struct CloudLanguageModel {
}
impl CloudLanguageModel {
const MAX_RETRIES: usize = 3;
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
@@ -464,9 +468,10 @@ impl CloudLanguageModel {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
let mut did_retry = false;
let mut retries_remaining = Self::MAX_RETRIES;
let mut retry_delay = Duration::from_secs(1);
let response = loop {
loop {
let request_builder = http_client::Request::builder();
let request = request_builder
.method(Method::POST)
@@ -475,36 +480,53 @@ impl CloudLanguageModel {
.header("Authorization", format!("Bearer {token}"))
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client.send(request).await?;
if response.status().is_success() {
break response;
} else if !did_retry
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
let status = response.status();
if status.is_success() {
return Ok(response);
} else if response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
did_retry = true;
retries_remaining -= 1;
token = llm_api_token.refresh(&client).await?;
} else if response.status() == StatusCode::FORBIDDEN
} else if status == StatusCode::FORBIDDEN
&& response
.headers()
.get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
.is_some()
{
break Err(anyhow!(MaxMonthlySpendReachedError))?;
} else if response.status() == StatusCode::PAYMENT_REQUIRED {
break Err(anyhow!(PaymentRequiredError))?;
return Err(anyhow!(MaxMonthlySpendReachedError));
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
// If we encounter an error in the 500 range, retry after a delay.
// We've seen at least these in the wild from API providers:
// * 500 Internal Server Error
// * 502 Bad Gateway
// * 529 Service Overloaded
if retries_remaining == 0 {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(
"cloud language model completion failed after {} retries with status {status}: {body}",
Self::MAX_RETRIES
));
}
Timer::after(retry_delay).await;
retries_remaining -= 1;
retry_delay *= 2; // If it fails again, wait longer.
} else if status == StatusCode::PAYMENT_REQUIRED {
return Err(anyhow!(PaymentRequiredError));
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
break Err(anyhow!(
"cloud language model completion failed with status {}: {body}",
response.status()
))?;
return Err(anyhow!(
"cloud language model completion failed with status {status}: {body}",
));
}
};
Ok(response)
}
}
}