Retry on 5xx errors from cloud language model providers (#27584)
Release Notes: - N/A
This commit is contained in:
parent
e6c473a488
commit
76d3a9a0f0
1 changed files with 43 additions and 21 deletions
|
@ -27,9 +27,11 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
|||
use serde_json::value::RawValue;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use smol::io::{AsyncReadExt, BufReader};
|
||||
use smol::Timer;
|
||||
use std::{
|
||||
future,
|
||||
sync::{Arc, LazyLock},
|
||||
time::Duration,
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::{prelude::*, TintColor};
|
||||
|
@ -456,6 +458,8 @@ pub struct CloudLanguageModel {
|
|||
}
|
||||
|
||||
impl CloudLanguageModel {
|
||||
const MAX_RETRIES: usize = 3;
|
||||
|
||||
async fn perform_llm_completion(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
|
@ -464,9 +468,10 @@ impl CloudLanguageModel {
|
|||
let http_client = &client.http_client();
|
||||
|
||||
let mut token = llm_api_token.acquire(&client).await?;
|
||||
let mut did_retry = false;
|
||||
let mut retries_remaining = Self::MAX_RETRIES;
|
||||
let mut retry_delay = Duration::from_secs(1);
|
||||
|
||||
let response = loop {
|
||||
loop {
|
||||
let request_builder = http_client::Request::builder();
|
||||
let request = request_builder
|
||||
.method(Method::POST)
|
||||
|
@ -475,36 +480,53 @@ impl CloudLanguageModel {
|
|||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(serde_json::to_string(&body)?.into())?;
|
||||
let mut response = http_client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
break response;
|
||||
} else if !did_retry
|
||||
&& response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
return Ok(response);
|
||||
} else if response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
did_retry = true;
|
||||
retries_remaining -= 1;
|
||||
token = llm_api_token.refresh(&client).await?;
|
||||
} else if response.status() == StatusCode::FORBIDDEN
|
||||
} else if status == StatusCode::FORBIDDEN
|
||||
&& response
|
||||
.headers()
|
||||
.get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
break Err(anyhow!(MaxMonthlySpendReachedError))?;
|
||||
} else if response.status() == StatusCode::PAYMENT_REQUIRED {
|
||||
break Err(anyhow!(PaymentRequiredError))?;
|
||||
return Err(anyhow!(MaxMonthlySpendReachedError));
|
||||
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
|
||||
// If we encounter an error in the 500 range, retry after a delay.
|
||||
// We've seen at least these in the wild from API providers:
|
||||
// * 500 Internal Server Error
|
||||
// * 502 Bad Gateway
|
||||
// * 529 Service Overloaded
|
||||
|
||||
if retries_remaining == 0 {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"cloud language model completion failed after {} retries with status {status}: {body}",
|
||||
Self::MAX_RETRIES
|
||||
));
|
||||
}
|
||||
|
||||
Timer::after(retry_delay).await;
|
||||
|
||||
retries_remaining -= 1;
|
||||
retry_delay *= 2; // If it fails again, wait longer.
|
||||
} else if status == StatusCode::PAYMENT_REQUIRED {
|
||||
return Err(anyhow!(PaymentRequiredError));
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
break Err(anyhow!(
|
||||
"cloud language model completion failed with status {}: {body}",
|
||||
response.status()
|
||||
))?;
|
||||
return Err(anyhow!(
|
||||
"cloud language model completion failed with status {status}: {body}",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue