From 76d3a9a0f0d36e2c9170ce2fb43e557079a514a0 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Thu, 27 Mar 2025 09:35:16 -0400 Subject: [PATCH] Retry on 5xx errors from cloud language model providers (#27584) Release Notes: - N/A --- crates/language_models/src/provider/cloud.rs | 64 +++++++++++++------- 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index f91c2799e7..bc3cc87181 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -27,9 +27,11 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::value::RawValue; use settings::{Settings, SettingsStore}; use smol::io::{AsyncReadExt, BufReader}; +use smol::Timer; use std::{ future, sync::{Arc, LazyLock}, + time::Duration, }; use strum::IntoEnumIterator; use ui::{prelude::*, TintColor}; @@ -456,6 +458,8 @@ pub struct CloudLanguageModel { } impl CloudLanguageModel { + const MAX_RETRIES: usize = 3; + async fn perform_llm_completion( client: Arc, llm_api_token: LlmApiToken, @@ -464,9 +468,10 @@ impl CloudLanguageModel { let http_client = &client.http_client(); let mut token = llm_api_token.acquire(&client).await?; - let mut did_retry = false; + let mut retries_remaining = Self::MAX_RETRIES; + let mut retry_delay = Duration::from_secs(1); - let response = loop { + loop { let request_builder = http_client::Request::builder(); let request = request_builder .method(Method::POST) @@ -475,36 +480,53 @@ impl CloudLanguageModel { .header("Authorization", format!("Bearer {token}")) .body(serde_json::to_string(&body)?.into())?; let mut response = http_client.send(request).await?; - if response.status().is_success() { - break response; - } else if !did_retry - && response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() + let status = response.status(); + if status.is_success() { + return Ok(response); + } else if response + .headers() + .get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() { - did_retry = true; + retries_remaining -= 1; token = llm_api_token.refresh(&client).await?; - } else if response.status() == StatusCode::FORBIDDEN + } else if status == StatusCode::FORBIDDEN && response .headers() .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME) .is_some() { - break Err(anyhow!(MaxMonthlySpendReachedError))?; - } else if response.status() == StatusCode::PAYMENT_REQUIRED { - break Err(anyhow!(PaymentRequiredError))?; + return Err(anyhow!(MaxMonthlySpendReachedError)); + } else if status.as_u16() >= 500 && status.as_u16() < 600 { + // If we encounter an error in the 500 range, retry after a delay. + // We've seen at least these in the wild from API providers: + // * 500 Internal Server Error + // * 502 Bad Gateway + // * 529 Service Overloaded + + if retries_remaining == 0 { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "cloud language model completion failed after {} retries with status {status}: {body}", + Self::MAX_RETRIES + )); + } + + Timer::after(retry_delay).await; + + retries_remaining -= 1; + retry_delay *= 2; // If it fails again, wait longer. + } else if status == StatusCode::PAYMENT_REQUIRED { + return Err(anyhow!(PaymentRequiredError)); } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; - break Err(anyhow!( - "cloud language model completion failed with status {}: {body}", - response.status() - ))?; + return Err(anyhow!( + "cloud language model completion failed with status {status}: {body}", + )); } - }; - - Ok(response) + } } }