agent: Improve error handling and retry for zed-provided models (#33565)

* Updates to `zed_llm_client-0.8.5`, which adds support for `retry_after`
when Anthropic provides it.

* Distinguishes upstream provider errors and rate limits from errors
that originate from Zed's servers.

* Moves `LanguageModelCompletionError::BadInputJson` to
`LanguageModelCompletionEvent::ToolUseJsonParseError`. While arguably
this is an error case, the logic in thread is cleaner with this move.
There is also precedent for inclusion of errors in the event type -
`CompletionRequestStatus::Failed` is how cloud errors arrive.

* Updates `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types
instead of `&str`, since they can be constructed in a const fashion.

* Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`
as the server no longer reads this header and just defaults to that
behavior.

Release notes for this are covered by #33275

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Michael Sloan 2025-06-30 21:01:32 -06:00 committed by GitHub
parent f022a13091
commit d497f52e17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 656 additions and 479 deletions

View file

@ -9,17 +9,18 @@ mod telemetry;
pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::Result;
use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http;
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
@ -34,11 +35,22 @@ pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Anthropic");
/// If we get a rate limit error that doesn't tell us when we can retry,
/// default to waiting this long before retrying.
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Google AI");
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("OpenAI");
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
LanguageModelProviderName::new("Zed");
pub fn init(client: Arc<Client>, cx: &mut App) {
init_settings(cx);
@ -71,6 +83,12 @@ pub enum LanguageModelCompletionEvent {
data: String,
},
ToolUse(LanguageModelToolUse),
ToolUseJsonParseError {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
StartMessage {
message_id: String,
},
@ -79,61 +97,179 @@ pub enum LanguageModelCompletionEvent {
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("rate limit exceeded, retry after {retry_after:?}")]
RateLimitExceeded { retry_after: Duration },
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
tool_name: Arc<str>,
raw_input: Arc<str>,
json_parse_error: String,
},
#[error("language model provider's API is overloaded")]
Overloaded,
#[error(transparent)]
Other(#[from] anyhow::Error),
#[error("invalid request format to language model provider's API")]
BadRequestFormat,
#[error("authentication error with language model provider's API")]
AuthenticationError,
#[error("permission error with language model provider's API")]
PermissionError,
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound,
#[error("prompt too large for context window")]
PromptTooLarge { tokens: Option<u64> },
#[error("internal server error in language model provider's API")]
ApiInternalServerError,
#[error("I/O error reading response from language model provider's API: {0:?}")]
ApiReadResponseError(io::Error),
#[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
HttpResponseError { status: u16, body: String },
#[error("error serializing request to language model provider API: {0}")]
SerializeRequest(serde_json::Error),
#[error("error building request body to language model provider API: {0}")]
BuildRequestBody(http::Error),
#[error("error sending HTTP request to language model provider API: {0}")]
HttpSend(anyhow::Error),
#[error("error deserializing language model provider API response: {0}")]
DeserializeResponse(serde_json::Error),
#[error("unexpected language model provider API response format: {0}")]
UnknownResponseFormat(String),
#[error("missing {provider} API key")]
NoApiKey { provider: LanguageModelProviderName },
#[error("{provider}'s API rate limit exceeded")]
RateLimitExceeded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API servers are overloaded right now")]
ServerOverloaded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API server reported an internal server error: {message}")]
ApiInternalServerError {
provider: LanguageModelProviderName,
message: String,
},
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
HttpResponseError {
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
},
// Client errors
#[error("invalid request format to {provider}'s API: {message}")]
BadRequestFormat {
provider: LanguageModelProviderName,
message: String,
},
#[error("authentication error with {provider}'s API: {message}")]
AuthenticationError {
provider: LanguageModelProviderName,
message: String,
},
#[error("permission error with {provider}'s API: {message}")]
PermissionError {
provider: LanguageModelProviderName,
message: String,
},
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound { provider: LanguageModelProviderName },
#[error("I/O error reading response from {provider}'s API")]
ApiReadResponseError {
provider: LanguageModelProviderName,
#[source]
error: io::Error,
},
#[error("error serializing request to {provider} API")]
SerializeRequest {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("error building request body to {provider} API")]
BuildRequestBody {
provider: LanguageModelProviderName,
#[source]
error: http::Error,
},
#[error("error sending HTTP request to {provider} API")]
HttpSend {
provider: LanguageModelProviderName,
#[source]
error: anyhow::Error,
},
#[error("error deserializing {provider} API response")]
DeserializeResponse {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl LanguageModelCompletionError {
    /// Converts a failure reported through the cloud completion API into a completion error.
    ///
    /// The `code` string decides who the error is attributed to:
    /// - `upstream_http_<status>` is attributed to `upstream_provider`.
    /// - `http_<status>` is attributed to `ZED_CLOUD_PROVIDER_NAME`.
    /// - Any other code is wrapped, together with `message`, in [`Self::Other`].
    ///
    /// A `message` recognized by `parse_prompt_too_long` takes precedence over `code`.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
        // to be reported. This is a temporary workaround to handle this in the case where the
        // token limit has been exceeded.
        if let Some(tokens) = parse_prompt_too_long(&message) {
            return Self::PromptTooLarge {
                tokens: Some(tokens),
            };
        }

        // Parses the HTTP status that follows `prefix` in `code`, if `code` has that shape.
        let status_after = |prefix: &str| {
            code.strip_prefix(prefix)
                .and_then(|status| StatusCode::from_str(status).ok())
        };

        if let Some(upstream_status) = status_after("upstream_http_") {
            return Self::from_http_status(upstream_provider, upstream_status, message, retry_after);
        }

        if let Some(cloud_status) = status_after("http_") {
            return Self::from_http_status(
                ZED_CLOUD_PROVIDER_NAME,
                cloud_status,
                message,
                retry_after,
            );
        }

        anyhow!("completion request failed, code: {code}, message: {message}").into()
    }

    /// Maps an HTTP status returned by `provider` to the matching error variant.
    ///
    /// `retry_after` is only carried by the rate-limit and overloaded variants. Statuses
    /// without a dedicated variant become [`Self::HttpResponseError`].
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => {
                let tokens = parse_prompt_too_long(&message);
                Self::PromptTooLarge { tokens }
            }
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            // 529 is matched numerically because no `StatusCode` constant is used for it here.
            // NOTE(review): presumably the non-standard "overloaded" status some providers send.
            status
                if status == StatusCode::SERVICE_UNAVAILABLE || status.as_u16() == 529 =>
            {
                Self::ServerOverloaded {
                    provider,
                    retry_after,
                }
            }
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
impl From<AnthropicError> for LanguageModelCompletionError {
fn from(error: AnthropicError) -> Self {
let provider = ANTHROPIC_PROVIDER_NAME;
match error {
AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
AnthropicError::HttpSend(error) => Self::HttpSend(error),
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
AnthropicError::HttpResponseError { status, body } => {
Self::HttpResponseError { status, body }
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
AnthropicError::DeserializeResponse(error) => {
Self::DeserializeResponse { provider, error }
}
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
AnthropicError::HttpResponseError {
status_code,
message,
} => Self::HttpResponseError {
provider,
status_code,
message,
},
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after: retry_after,
},
AnthropicError::ApiError(api_error) => api_error.into(),
AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
}
}
}
@ -141,23 +277,39 @@ impl From<AnthropicError> for LanguageModelCompletionError {
impl From<anthropic::ApiError> for LanguageModelCompletionError {
fn from(error: anthropic::ApiError) -> Self {
use anthropic::ApiErrorCode::*;
let provider = ANTHROPIC_PROVIDER_NAME;
match error.code() {
Some(code) => match code {
InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
AuthenticationError => LanguageModelCompletionError::AuthenticationError,
PermissionError => LanguageModelCompletionError::PermissionError,
NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
NotFoundError => Self::ApiEndpointNotFound { provider },
RequestTooLarge => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&error.message),
},
RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
ApiError => LanguageModelCompletionError::ApiInternalServerError,
OverloadedError => LanguageModelCompletionError::Overloaded,
},
None => LanguageModelCompletionError::Other(error.into()),
None => Self::Other(error.into()),
}
}
}
@ -278,6 +430,13 @@ pub trait LanguageModel: Send + Sync {
fn name(&self) -> LanguageModelName;
fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
/// Returns the id of the upstream provider for this model.
///
/// The default implementation returns the same value as `provider_id()`.
// NOTE(review): presumably overridden by models that are served through another
// provider (e.g. Zed's cloud) so errors can be attributed upstream — confirm.
fn upstream_provider_id(&self) -> LanguageModelProviderId {
self.provider_id()
}
/// Returns the display name of the upstream provider for this model.
///
/// The default implementation returns the same value as `provider_name()`.
// NOTE(review): presumably overridden by models that are served through another
// provider (e.g. Zed's cloud) so errors can be attributed upstream — confirm.
fn upstream_provider_name(&self) -> LanguageModelProviderName {
self.provider_name()
}
fn telemetry_id(&self) -> String;
fn api_key(&self, _cx: &App) -> Option<String> {
@ -365,6 +524,9 @@ pub trait LanguageModel: Send + Sync {
Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
..
}) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
None
@ -395,39 +557,6 @@ pub trait LanguageModel: Send + Sync {
}
}
/// Language model failures with a specifically identified cause.
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
/// The request exceeded the context window; `tokens` is included in the message.
#[error("Context window limit exceeded ({tokens})")]
ContextWindowLimitExceeded { tokens: u64 },
/// The provider's API reported that it is currently overloaded.
#[error("Language model provider's API is currently overloaded")]
Overloaded,
/// The provider's API reported an internal server error.
#[error("Language model provider's API encountered an internal server error")]
ApiInternalServerError,
/// Reading the response failed with an I/O error.
#[error("I/O error while reading response from language model provider's API: {0:?}")]
ReadResponseError(io::Error),
/// The response could not be deserialized.
#[error("Error deserializing response from language model provider's API: {0:?}")]
DeserializeResponse(serde_json::Error),
/// The response was in an unknown format; the payload string is not shown in the message.
#[error("Language model provider's API returned a response in an unknown format")]
UnknownResponseFormat(String),
/// The provider's rate limit was hit; `retry_after` is the wait reported in the message.
#[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
RateLimitExceeded { retry_after: Duration },
}
impl LanguageModelKnownError {
    /// Tries to classify an HTTP response status as one of the known error kinds.
    ///
    /// 429 becomes `RateLimitExceeded` with the default retry delay, 503 becomes
    /// `Overloaded`, and every other 5xx status becomes `ApiInternalServerError`.
    /// Returns `None` for all other statuses. The response body is currently unused.
    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
        let known = match status {
            429 => Self::RateLimitExceeded {
                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
            },
            503 => Self::Overloaded,
            _ if (500..=599).contains(&status) => Self::ApiInternalServerError,
            _ => return None,
        };
        Some(known)
    }
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
@ -509,12 +638,30 @@ pub struct LanguageModelProviderId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
impl LanguageModelProviderId {
pub const fn new(id: &'static str) -> Self {
Self(SharedString::new_static(id))
}
}
impl LanguageModelProviderName {
pub const fn new(id: &'static str) -> Self {
Self(SharedString::new_static(id))
}
}
impl fmt::Display for LanguageModelProviderId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl fmt::Display for LanguageModelProviderName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<String> for LanguageModelId {
fn from(value: String) -> Self {
Self(SharedString::from(value))

View file

@ -98,7 +98,7 @@ impl ConfiguredModel {
}
pub fn is_provided_by_zed(&self) -> bool {
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
}
}

View file

@ -1,3 +1,4 @@
use crate::ANTHROPIC_PROVIDER_ID;
use anthropic::ANTHROPIC_API_URL;
use anyhow::{Context as _, anyhow};
use client::telemetry::Telemetry;
@ -8,8 +9,6 @@ use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use util::ResultExt;
pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
pub fn report_assistant_event(
event: AssistantEventData,
telemetry: Option<Arc<Telemetry>>,
@ -19,7 +18,7 @@ pub fn report_assistant_event(
) {
if let Some(telemetry) = telemetry.as_ref() {
telemetry.report_assistant_event(event.clone());
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID {
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
if let Some(api_key) = model_api_key {
executor
.spawn(async move {