agent: Improve error handling and retry for zed-provided models (#33565)

* Updates to `zed_llm_client-0.8.5`, which adds support for `retry_after`
when Anthropic provides it.

* Distinguishes upstream provider errors and rate limits from errors
that originate from Zed's servers.

* Moves `LanguageModelCompletionError::BadInputJson` to
`LanguageModelCompletionEvent::ToolUseJsonParseError`. While arguably
this is an error case, the logic in thread is cleaner with this move.
There is also precedent for inclusion of errors in the event type -
`CompletionRequestStatus::Failed` is how cloud errors arrive.

* Updates `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types
instead of `&str`, since they can be constructed in a const fashion.

* Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`
as the server no longer reads this header and just defaults to that
behavior.

Release notes for this are covered by #33275

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Michael Sloan 2025-06-30 21:01:32 -06:00 committed by GitHub
parent f022a13091
commit d497f52e17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 656 additions and 479 deletions

View file

@ -1,4 +1,4 @@
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
@ -8,25 +8,21 @@ use google_ai::GoogleModelMode;
use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
RefreshLlmTokenListener,
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
pub const PROVIDER_NAME: &str = "Zed";
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
impl LanguageModelProvider for CloudLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn icon(&self) -> IconName {
@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse {
}
impl CloudLanguageModel {
const MAX_RETRIES: usize = 3;
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
@ -547,8 +542,7 @@ impl CloudLanguageModel {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
let mut retries_remaining = Self::MAX_RETRIES;
let mut retry_delay = Duration::from_secs(1);
let mut refreshed_token = false;
loop {
let request_builder = http_client::Request::builder()
@ -590,14 +584,20 @@ impl CloudLanguageModel {
includes_status_messages,
tool_use_limit_reached,
});
} else if response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
}
if !refreshed_token
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
retries_remaining -= 1;
token = llm_api_token.refresh(&client).await?;
} else if status == StatusCode::FORBIDDEN
refreshed_token = true;
continue;
}
if status == StatusCode::FORBIDDEN
&& response
.headers()
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
@ -622,35 +622,18 @@ impl CloudLanguageModel {
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
}
}
anyhow::bail!("Forbidden");
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
// If we encounter an error in the 500 range, retry after a delay.
// We've seen at least these in the wild from API providers:
// * 500 Internal Server Error
// * 502 Bad Gateway
// * 529 Service Overloaded
if retries_remaining == 0 {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"cloud language model completion failed after {} retries with status {status}: {body}",
Self::MAX_RETRIES
);
}
Timer::after(retry_delay).await;
retries_remaining -= 1;
retry_delay *= 2; // If it fails again, wait longer.
} else if status == StatusCode::PAYMENT_REQUIRED {
return Err(anyhow!(PaymentRequiredError));
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError { status, body }));
}
let mut body = String::new();
let headers = response.headers().clone();
response.body_mut().read_to_string(&mut body).await?;
return Err(anyhow!(ApiError {
status,
body,
headers
}));
}
}
}
@ -660,6 +643,19 @@ impl CloudLanguageModel {
/// Details captured from an HTTP response that was not handled as a success.
struct ApiError {
/// HTTP status code of the response.
status: StatusCode,
/// Response body, read in full into a string.
body: String,
/// Copy of the response headers, cloned before the body is read.
headers: HeaderMap<HeaderValue>,
}
impl From<ApiError> for LanguageModelCompletionError {
fn from(error: ApiError) -> Self {
let retry_after = None;
LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
error.body,
retry_after,
)
}
}
impl LanguageModel for CloudLanguageModel {
@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
PROVIDER_NAME
}
fn upstream_provider_id(&self) -> LanguageModelProviderId {
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
OpenAi => language_model::OPEN_AI_PROVIDER_ID,
Google => language_model::GOOGLE_PROVIDER_ID,
}
}
fn upstream_provider_name(&self) -> LanguageModelProviderName {
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
Google => language_model::GOOGLE_PROVIDER_NAME,
}
}
fn supports_tools(&self) -> bool {
@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel {
.body(serde_json::to_string(&request_body)?.into())?;
let mut response = http_client.send(request).await?;
let status = response.status();
let headers = response.headers().clone();
let mut response_body = String::new();
response
.body_mut()
@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel {
} else {
Err(anyhow!(ApiError {
status,
body: response_body
body: response_body,
headers
}))
}
}
@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel {
)
.await
.map_err(|err| match err.downcast::<ApiError>() {
Ok(api_err) => {
if api_err.status == StatusCode::BAD_REQUEST {
if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
return anyhow!(
LanguageModelKnownError::ContextWindowLimitExceeded {
tokens
}
);
}
}
anyhow!(api_err)
}
Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
Err(err) => anyhow!(err),
})?;
@ -995,7 +1000,7 @@ where
.flat_map(move |event| {
futures::stream::iter(match event {
Err(error) => {
vec![Err(LanguageModelCompletionError::Other(error))]
vec![Err(LanguageModelCompletionError::from(error))]
}
Ok(CloudCompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]