Surface upstream rate limits from Anthropic (#16118)

This PR makes it so hitting upstream rate limits from Anthropic result
in an HTTP 429 response instead of an HTTP 500.

To do this we need to surface structured errors out of the `anthropic`
crate.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-12 11:59:24 -04:00 committed by GitHub
parent fbb533b3e0
commit ebdb755fef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 144 additions and 47 deletions

View file

@ -24,6 +24,7 @@ schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
[dev-dependencies]
tokio.workspace = true

View file

@ -1,12 +1,14 @@
mod supported_countries;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::time::Duration;
use strum::EnumIter;
use strum::{EnumIter, EnumString};
use thiserror::Error;
pub use supported_countries::*;
@ -96,7 +98,7 @@ pub async fn complete(
api_url: &str,
api_key: &str,
request: Request,
) -> Result<Response> {
) -> Result<Response, AnthropicError> {
let uri = format!("{api_url}/v1/messages");
let request_builder = HttpRequest::builder()
.method(Method::POST)
@ -106,24 +108,40 @@ pub async fn complete(
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let serialized_request =
serde_json::to_string(&request).context("failed to serialize request")?;
let request = request_builder
.body(AsyncBody::from(serialized_request))
.context("failed to construct request body")?;
let mut response = client.send(request).await?;
let mut response = client
.send(request)
.await
.context("failed to send request to Anthropic")?;
if response.status().is_success() {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let response_message: Response = serde_json::from_slice(&body)?;
response
.body_mut()
.read_to_end(&mut body)
.await
.context("failed to read response body")?;
let response_message: Response =
serde_json::from_slice(&body).context("failed to deserialize response body")?;
Ok(response_message)
} else {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
Err(anyhow!(
response
.body_mut()
.read_to_end(&mut body)
.await
.context("failed to read response body")?;
let body_str =
std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
Err(AnthropicError::Other(anyhow!(
"Failed to connect to API: {} {}",
response.status(),
body_str
))
)))
}
}
@ -133,7 +151,7 @@ pub async fn stream_completion(
api_key: &str,
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<Event>>> {
) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
let request = StreamingRequest {
base: request,
stream: true,
@ -149,10 +167,16 @@ pub async fn stream_completion(
if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
}
let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let serialized_request =
serde_json::to_string(&request).context("failed to serialize request")?;
let request = request_builder
.body(AsyncBody::from(serialized_request))
.context("failed to construct request body")?;
let mut response = client.send(request).await?;
let mut response = client
.send(request)
.await
.context("failed to send request to Anthropic")?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
Ok(reader
@ -163,36 +187,41 @@ pub async fn stream_completion(
let line = line.strip_prefix("data: ")?;
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
Err(error) => Some(Err(anyhow!(error))),
Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
}
}
Err(error) => Some(Err(anyhow!(error))),
Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
}
})
.boxed())
} else {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
response
.body_mut()
.read_to_end(&mut body)
.await
.context("failed to read response body")?;
let body_str = std::str::from_utf8(&body)?;
let body_str =
std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
match serde_json::from_str::<Event>(body_str) {
Ok(Event::Error { error }) => Err(api_error_to_err(error)),
Ok(_) => Err(anyhow!(
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
Ok(_) => Err(AnthropicError::Other(anyhow!(
"Unexpected success response while expecting an error: '{body_str}'",
)),
Err(_) => Err(anyhow!(
))),
Err(_) => Err(AnthropicError::Other(anyhow!(
"Failed to connect to API: {} {}",
response.status(),
body_str,
)),
))),
}
}
}
pub fn extract_text_from_events(
response: impl Stream<Item = Result<Event>>,
) -> impl Stream<Item = Result<String>> {
response: impl Stream<Item = Result<Event, AnthropicError>>,
) -> impl Stream<Item = Result<String, AnthropicError>> {
response.filter_map(|response| async move {
match response {
Ok(response) => match response {
@ -204,7 +233,7 @@ pub fn extract_text_from_events(
ContentDelta::TextDelta { text } => Some(Ok(text)),
_ => None,
},
Event::Error { error } => Some(Err(api_error_to_err(error))),
Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
_ => None,
},
Err(error) => Some(Err(error)),
@ -212,15 +241,6 @@ pub fn extract_text_from_events(
})
}
fn api_error_to_err(
ApiError {
error_type,
message,
}: ApiError,
) -> anyhow::Error {
anyhow!("API error. Type: '{error_type}', message: '{message}'",)
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
@ -374,9 +394,53 @@ pub struct MessageDelta {
pub stop_sequence: Option<String>,
}
#[derive(Error, Debug)]
pub enum AnthropicError {
#[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
ApiError(ApiError),
#[error("{0}")]
Other(#[from] anyhow::Error),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
#[serde(rename = "type")]
pub error_type: String,
pub message: String,
}
/// An Anthropic API error code.
/// https://docs.anthropic.com/en/api/errors#http-errors
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
/// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
InvalidRequestError,
/// 401 - `authentication_error`: There's an issue with your API key.
AuthenticationError,
/// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
PermissionError,
/// 404 - `not_found_error`: The requested resource was not found.
NotFoundError,
/// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
RequestTooLarge,
/// 429 - `rate_limit_error`: Your account has hit a rate limit.
RateLimitError,
/// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
ApiError,
/// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
OverloadedError,
}
impl ApiError {
pub fn code(&self) -> Option<ApiErrorCode> {
ApiErrorCode::from_str(&self.error_type).ok()
}
pub fn is_rate_limit_error(&self) -> bool {
match self.error_type.as_str() {
"rate_limit_error" => true,
_ => false,
}
}
}