agent: Improve error handling and retry for zed-provided models (#33565)
* Updates to `zed_llm_client-0.8.5`, which adds support for `retry_after` when Anthropic provides it. * Distinguishes upstream provider errors and rate limits from errors that originate from Zed's servers. * Moves `LanguageModelCompletionError::BadInputJson` to `LanguageModelCompletionEvent::ToolUseJsonParseError`. While arguably this is an error case, the logic in thread is cleaner with this move. There is also precedent for including errors in the event type - `CompletionRequestStatus::Failed` is how cloud errors arrive. * Updates `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types instead of `&str`, since they can be constructed in a const fashion. * Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`, as the server no longer reads this header and just defaults to that behavior. Release notes for this are covered by #33275 Release Notes: - N/A --------- Co-authored-by: Richard Feldman <oss@rtfeldman.com> Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
parent
f022a13091
commit
d497f52e17
25 changed files with 656 additions and 479 deletions
|
@ -9,17 +9,18 @@ mod telemetry;
|
|||
pub mod fake_provider;
|
||||
|
||||
use anthropic::{AnthropicError, parse_prompt_too_long};
|
||||
use anyhow::Result;
|
||||
use anyhow::{Result, anyhow};
|
||||
use client::Client;
|
||||
use futures::FutureExt;
|
||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
use http_client::http;
|
||||
use http_client::{StatusCode, http};
|
||||
use icons::IconName;
|
||||
use parking_lot::Mutex;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use std::ops::{Add, Sub};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, io};
|
||||
|
@ -34,11 +35,22 @@ pub use crate::request::*;
|
|||
pub use crate::role::*;
|
||||
pub use crate::telemetry::*;
|
||||
|
||||
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
|
||||
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
|
||||
LanguageModelProviderId::new("anthropic");
|
||||
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Anthropic");
|
||||
|
||||
/// If we get a rate limit error that doesn't tell us when we can retry,
|
||||
/// default to waiting this long before retrying.
|
||||
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
|
||||
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
|
||||
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Google AI");
|
||||
|
||||
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
|
||||
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("OpenAI");
|
||||
|
||||
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
|
||||
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
|
||||
LanguageModelProviderName::new("Zed");
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
init_settings(cx);
|
||||
|
@ -71,6 +83,12 @@ pub enum LanguageModelCompletionEvent {
|
|||
data: String,
|
||||
},
|
||||
ToolUse(LanguageModelToolUse),
|
||||
ToolUseJsonParseError {
|
||||
id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
StartMessage {
|
||||
message_id: String,
|
||||
},
|
||||
|
@ -79,61 +97,179 @@ pub enum LanguageModelCompletionEvent {
|
|||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("rate limit exceeded, retry after {retry_after:?}")]
|
||||
RateLimitExceeded { retry_after: Duration },
|
||||
#[error("received bad input JSON")]
|
||||
BadInputJson {
|
||||
id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
#[error("language model provider's API is overloaded")]
|
||||
Overloaded,
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
#[error("invalid request format to language model provider's API")]
|
||||
BadRequestFormat,
|
||||
#[error("authentication error with language model provider's API")]
|
||||
AuthenticationError,
|
||||
#[error("permission error with language model provider's API")]
|
||||
PermissionError,
|
||||
#[error("language model provider API endpoint not found")]
|
||||
ApiEndpointNotFound,
|
||||
#[error("prompt too large for context window")]
|
||||
PromptTooLarge { tokens: Option<u64> },
|
||||
#[error("internal server error in language model provider's API")]
|
||||
ApiInternalServerError,
|
||||
#[error("I/O error reading response from language model provider's API: {0:?}")]
|
||||
ApiReadResponseError(io::Error),
|
||||
#[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
|
||||
HttpResponseError { status: u16, body: String },
|
||||
#[error("error serializing request to language model provider API: {0}")]
|
||||
SerializeRequest(serde_json::Error),
|
||||
#[error("error building request body to language model provider API: {0}")]
|
||||
BuildRequestBody(http::Error),
|
||||
#[error("error sending HTTP request to language model provider API: {0}")]
|
||||
HttpSend(anyhow::Error),
|
||||
#[error("error deserializing language model provider API response: {0}")]
|
||||
DeserializeResponse(serde_json::Error),
|
||||
#[error("unexpected language model provider API response format: {0}")]
|
||||
UnknownResponseFormat(String),
|
||||
#[error("missing {provider} API key")]
|
||||
NoApiKey { provider: LanguageModelProviderName },
|
||||
#[error("{provider}'s API rate limit exceeded")]
|
||||
RateLimitExceeded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API servers are overloaded right now")]
|
||||
ServerOverloaded {
|
||||
provider: LanguageModelProviderName,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
#[error("{provider}'s API server reported an internal server error: {message}")]
|
||||
ApiInternalServerError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
|
||||
HttpResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
},
|
||||
|
||||
// Client errors
|
||||
#[error("invalid request format to {provider}'s API: {message}")]
|
||||
BadRequestFormat {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("authentication error with {provider}'s API: {message}")]
|
||||
AuthenticationError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("permission error with {provider}'s API: {message}")]
|
||||
PermissionError {
|
||||
provider: LanguageModelProviderName,
|
||||
message: String,
|
||||
},
|
||||
#[error("language model provider API endpoint not found")]
|
||||
ApiEndpointNotFound { provider: LanguageModelProviderName },
|
||||
#[error("I/O error reading response from {provider}'s API")]
|
||||
ApiReadResponseError {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: io::Error,
|
||||
},
|
||||
#[error("error serializing request to {provider} API")]
|
||||
SerializeRequest {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
#[error("error building request body to {provider} API")]
|
||||
BuildRequestBody {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: http::Error,
|
||||
},
|
||||
#[error("error sending HTTP request to {provider} API")]
|
||||
HttpSend {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: anyhow::Error,
|
||||
},
|
||||
#[error("error deserializing {provider} API response")]
|
||||
DeserializeResponse {
|
||||
provider: LanguageModelProviderName,
|
||||
#[source]
|
||||
error: serde_json::Error,
|
||||
},
|
||||
|
||||
// TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl LanguageModelCompletionError {
|
||||
pub fn from_cloud_failure(
|
||||
upstream_provider: LanguageModelProviderName,
|
||||
code: String,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
if let Some(tokens) = parse_prompt_too_long(&message) {
|
||||
// TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
|
||||
// to be reported. This is a temporary workaround to handle this in the case where the
|
||||
// token limit has been exceeded.
|
||||
Self::PromptTooLarge {
|
||||
tokens: Some(tokens),
|
||||
}
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(upstream_provider, status_code, message, retry_after)
|
||||
} else if let Some(status_code) = code
|
||||
.strip_prefix("http_")
|
||||
.and_then(|code| StatusCode::from_str(code).ok())
|
||||
{
|
||||
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
|
||||
} else {
|
||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_http_status(
|
||||
provider: LanguageModelProviderName,
|
||||
status_code: StatusCode,
|
||||
message: String,
|
||||
retry_after: Option<Duration>,
|
||||
) -> Self {
|
||||
match status_code {
|
||||
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
|
||||
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
|
||||
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
|
||||
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
|
||||
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&message),
|
||||
},
|
||||
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
|
||||
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
_ => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AnthropicError> for LanguageModelCompletionError {
|
||||
fn from(error: AnthropicError) -> Self {
|
||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||
match error {
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend(error),
|
||||
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
|
||||
AnthropicError::HttpResponseError { status, body } => {
|
||||
Self::HttpResponseError { status, body }
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
AnthropicError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
AnthropicError::HttpResponseError {
|
||||
status_code,
|
||||
message,
|
||||
} => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code,
|
||||
message,
|
||||
},
|
||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: retry_after,
|
||||
},
|
||||
AnthropicError::ApiError(api_error) => api_error.into(),
|
||||
AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -141,23 +277,39 @@ impl From<AnthropicError> for LanguageModelCompletionError {
|
|||
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: anthropic::ApiError) -> Self {
|
||||
use anthropic::ApiErrorCode::*;
|
||||
|
||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
||||
match error.code() {
|
||||
Some(code) => match code {
|
||||
InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
|
||||
AuthenticationError => LanguageModelCompletionError::AuthenticationError,
|
||||
PermissionError => LanguageModelCompletionError::PermissionError,
|
||||
NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
|
||||
RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
NotFoundError => Self::ApiEndpointNotFound { provider },
|
||||
RequestTooLarge => Self::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&error.message),
|
||||
},
|
||||
RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
|
||||
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => LanguageModelCompletionError::ApiInternalServerError,
|
||||
OverloadedError => LanguageModelCompletionError::Overloaded,
|
||||
},
|
||||
None => LanguageModelCompletionError::Other(error.into()),
|
||||
None => Self::Other(error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -278,6 +430,13 @@ pub trait LanguageModel: Send + Sync {
|
|||
fn name(&self) -> LanguageModelName;
|
||||
fn provider_id(&self) -> LanguageModelProviderId;
|
||||
fn provider_name(&self) -> LanguageModelProviderName;
|
||||
fn upstream_provider_id(&self) -> LanguageModelProviderId {
|
||||
self.provider_id()
|
||||
}
|
||||
fn upstream_provider_name(&self) -> LanguageModelProviderName {
|
||||
self.provider_name()
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String;
|
||||
|
||||
fn api_key(&self, _cx: &App) -> Option<String> {
|
||||
|
@ -365,6 +524,9 @@ pub trait LanguageModel: Send + Sync {
|
|||
Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
|
||||
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
|
||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
..
|
||||
}) => None,
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
None
|
||||
|
@ -395,39 +557,6 @@ pub trait LanguageModel: Send + Sync {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LanguageModelKnownError {
|
||||
#[error("Context window limit exceeded ({tokens})")]
|
||||
ContextWindowLimitExceeded { tokens: u64 },
|
||||
#[error("Language model provider's API is currently overloaded")]
|
||||
Overloaded,
|
||||
#[error("Language model provider's API encountered an internal server error")]
|
||||
ApiInternalServerError,
|
||||
#[error("I/O error while reading response from language model provider's API: {0:?}")]
|
||||
ReadResponseError(io::Error),
|
||||
#[error("Error deserializing response from language model provider's API: {0:?}")]
|
||||
DeserializeResponse(serde_json::Error),
|
||||
#[error("Language model provider's API returned a response in an unknown format")]
|
||||
UnknownResponseFormat(String),
|
||||
#[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
|
||||
RateLimitExceeded { retry_after: Duration },
|
||||
}
|
||||
|
||||
impl LanguageModelKnownError {
|
||||
/// Attempts to map an HTTP response status code to a known error type.
|
||||
/// Returns None if the status code doesn't map to a specific known error.
|
||||
pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
|
||||
match status {
|
||||
429 => Some(Self::RateLimitExceeded {
|
||||
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
|
||||
}),
|
||||
503 => Some(Self::Overloaded),
|
||||
500..=599 => Some(Self::ApiInternalServerError),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||
fn name() -> String;
|
||||
fn description() -> String;
|
||||
|
@ -509,12 +638,30 @@ pub struct LanguageModelProviderId(pub SharedString);
|
|||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct LanguageModelProviderName(pub SharedString);
|
||||
|
||||
impl LanguageModelProviderId {
|
||||
pub const fn new(id: &'static str) -> Self {
|
||||
Self(SharedString::new_static(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderName {
|
||||
pub const fn new(id: &'static str) -> Self {
|
||||
Self(SharedString::new_static(id))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageModelProviderId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageModelProviderName {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LanguageModelId {
|
||||
fn from(value: String) -> Self {
|
||||
Self(SharedString::from(value))
|
||||
|
|
|
@ -98,7 +98,7 @@ impl ConfiguredModel {
|
|||
}
|
||||
|
||||
pub fn is_provided_by_zed(&self) -> bool {
|
||||
self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
|
||||
self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::ANTHROPIC_PROVIDER_ID;
|
||||
use anthropic::ANTHROPIC_API_URL;
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use client::telemetry::Telemetry;
|
||||
|
@ -8,8 +9,6 @@ use std::sync::Arc;
|
|||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use util::ResultExt;
|
||||
|
||||
pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
|
||||
|
||||
pub fn report_assistant_event(
|
||||
event: AssistantEventData,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
|
@ -19,7 +18,7 @@ pub fn report_assistant_event(
|
|||
) {
|
||||
if let Some(telemetry) = telemetry.as_ref() {
|
||||
telemetry.report_assistant_event(event.clone());
|
||||
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID {
|
||||
if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
|
||||
if let Some(api_key) = model_api_key {
|
||||
executor
|
||||
.spawn(async move {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue