Thread Anthropic errors into LanguageModelKnownError (#33261)
This PR is in preparation for doing automatic retries for certain errors, e.g. Overloaded. It doesn't change behavior yet (aside from some granularity of error messages shown to the user), but rather mostly changes some error handling to be exhaustive enum matches instead of `anyhow` downcasts, and leaves some comments for where the behavior change will be in a future PR. Release Notes: - N/A
This commit is contained in:
parent
aabfea4c10
commit
c610ebfb03
5 changed files with 283 additions and 123 deletions
|
@ -8,19 +8,21 @@ mod telemetry;
|
|||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod fake_provider;
|
||||
|
||||
use anthropic::{AnthropicError, parse_prompt_too_long};
|
||||
use anyhow::Result;
|
||||
use client::Client;
|
||||
use futures::FutureExt;
|
||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
use http_client::http;
|
||||
use icons::IconName;
|
||||
use parking_lot::Mutex;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use std::fmt;
|
||||
use std::ops::{Add, Sub};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, io};
|
||||
use thiserror::Error;
|
||||
use util::serde::is_default;
|
||||
use zed_llm_client::CompletionRequestStatus;
|
||||
|
@ -34,6 +36,10 @@ pub use crate::telemetry::*;
|
|||
|
||||
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
|
||||
|
||||
/// If we get a rate limit error that doesn't tell us when we can retry,
|
||||
/// default to waiting this long before retrying.
|
||||
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
init_settings(cx);
|
||||
RefreshLlmTokenListener::register(client.clone(), cx);
|
||||
|
@ -70,8 +76,8 @@ pub enum LanguageModelCompletionEvent {
|
|||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("rate limit exceeded, retry after {0:?}")]
|
||||
RateLimit(Duration),
|
||||
#[error("rate limit exceeded, retry after {retry_after:?}")]
|
||||
RateLimitExceeded { retry_after: Duration },
|
||||
#[error("received bad input JSON")]
|
||||
BadInputJson {
|
||||
id: LanguageModelToolUseId,
|
||||
|
@ -79,8 +85,78 @@ pub enum LanguageModelCompletionError {
|
|||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
#[error("language model provider's API is overloaded")]
|
||||
Overloaded,
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
#[error("invalid request format to language model provider's API")]
|
||||
BadRequestFormat,
|
||||
#[error("authentication error with language model provider's API")]
|
||||
AuthenticationError,
|
||||
#[error("permission error with language model provider's API")]
|
||||
PermissionError,
|
||||
#[error("language model provider API endpoint not found")]
|
||||
ApiEndpointNotFound,
|
||||
#[error("prompt too large for context window")]
|
||||
PromptTooLarge { tokens: Option<u64> },
|
||||
#[error("internal server error in language model provider's API")]
|
||||
ApiInternalServerError,
|
||||
#[error("I/O error reading response from language model provider's API: {0:?}")]
|
||||
ApiReadResponseError(io::Error),
|
||||
#[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
|
||||
HttpResponseError { status: u16, body: String },
|
||||
#[error("error serializing request to language model provider API: {0}")]
|
||||
SerializeRequest(serde_json::Error),
|
||||
#[error("error building request body to language model provider API: {0}")]
|
||||
BuildRequestBody(http::Error),
|
||||
#[error("error sending HTTP request to language model provider API: {0}")]
|
||||
HttpSend(anyhow::Error),
|
||||
#[error("error deserializing language model provider API response: {0}")]
|
||||
DeserializeResponse(serde_json::Error),
|
||||
#[error("unexpected language model provider API response format: {0}")]
|
||||
UnknownResponseFormat(String),
|
||||
}
|
||||
|
||||
impl From<AnthropicError> for LanguageModelCompletionError {
|
||||
fn from(error: AnthropicError) -> Self {
|
||||
match error {
|
||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
|
||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
|
||||
AnthropicError::HttpSend(error) => Self::HttpSend(error),
|
||||
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
|
||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
|
||||
AnthropicError::HttpResponseError { status, body } => {
|
||||
Self::HttpResponseError { status, body }
|
||||
}
|
||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
|
||||
AnthropicError::ApiError(api_error) => api_error.into(),
|
||||
AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: anthropic::ApiError) -> Self {
|
||||
use anthropic::ApiErrorCode::*;
|
||||
|
||||
match error.code() {
|
||||
Some(code) => match code {
|
||||
InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
|
||||
AuthenticationError => LanguageModelCompletionError::AuthenticationError,
|
||||
PermissionError => LanguageModelCompletionError::PermissionError,
|
||||
NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
|
||||
RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
|
||||
tokens: parse_prompt_too_long(&error.message),
|
||||
},
|
||||
RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
|
||||
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
|
||||
},
|
||||
ApiError => LanguageModelCompletionError::ApiInternalServerError,
|
||||
OverloadedError => LanguageModelCompletionError::Overloaded,
|
||||
},
|
||||
None => LanguageModelCompletionError::Other(error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the format used to define the input schema for a language model tool.
|
||||
|
@ -319,6 +395,33 @@ pub trait LanguageModel: Send + Sync {
|
|||
pub enum LanguageModelKnownError {
|
||||
#[error("Context window limit exceeded ({tokens})")]
|
||||
ContextWindowLimitExceeded { tokens: u64 },
|
||||
#[error("Language model provider's API is currently overloaded")]
|
||||
Overloaded,
|
||||
#[error("Language model provider's API encountered an internal server error")]
|
||||
ApiInternalServerError,
|
||||
#[error("I/O error while reading response from language model provider's API: {0:?}")]
|
||||
ReadResponseError(io::Error),
|
||||
#[error("Error deserializing response from language model provider's API: {0:?}")]
|
||||
DeserializeResponse(serde_json::Error),
|
||||
#[error("Language model provider's API returned a response in an unknown format")]
|
||||
UnknownResponseFormat(String),
|
||||
#[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
|
||||
RateLimitExceeded { retry_after: Duration },
|
||||
}
|
||||
|
||||
impl LanguageModelKnownError {
|
||||
/// Attempts to map an HTTP response status code to a known error type.
|
||||
/// Returns None if the status code doesn't map to a specific known error.
|
||||
pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
|
||||
match status {
|
||||
429 => Some(Self::RateLimitExceeded {
|
||||
retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
|
||||
}),
|
||||
503 => Some(Self::Overloaded),
|
||||
500..=599 => Some(Self::ApiInternalServerError),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue