Merge 5ab18e679f
into c14d84cfdb
This commit is contained in:
commit
bacf7396a4
6 changed files with 332 additions and 150 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -9110,6 +9110,7 @@ dependencies = [
|
|||
"icons",
|
||||
"image",
|
||||
"log",
|
||||
"open_router",
|
||||
"parking_lot",
|
||||
"proto",
|
||||
"schemars",
|
||||
|
@ -11192,6 +11193,8 @@ dependencies = [
|
|||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum 0.27.1",
|
||||
"thiserror 2.0.12",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ test-support = []
|
|||
|
||||
[dependencies]
|
||||
anthropic = { workspace = true, features = ["schemars"] }
|
||||
open_router.workspace = true
|
||||
anyhow.workspace = true
|
||||
base64.workspace = true
|
||||
client.workspace = true
|
||||
|
|
|
@ -17,6 +17,7 @@ use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
|||
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
use http_client::{StatusCode, http};
|
||||
use icons::IconName;
|
||||
use open_router::OpenRouterError;
|
||||
use parking_lot::Mutex;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
|
@ -347,6 +348,72 @@ impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<OpenRouterError> for LanguageModelCompletionError {
|
||||
fn from(error: OpenRouterError) -> Self {
|
||||
let provider = LanguageModelProviderName::new("OpenRouter");
|
||||
match error {
|
||||
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||
OpenRouterError::DeserializeResponse(error) => {
|
||||
Self::DeserializeResponse { provider, error }
|
||||
}
|
||||
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: Some(retry_after),
|
||||
},
|
||||
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after,
|
||||
},
|
||||
OpenRouterError::ApiError(api_error) => api_error.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<open_router::ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: open_router::ApiError) -> Self {
|
||||
use open_router::ApiErrorCode::*;
|
||||
let provider = LanguageModelProviderName::new("OpenRouter");
|
||||
match error.code {
|
||||
InvalidRequestError => Self::BadRequestFormat {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
AuthenticationError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
PaymentRequiredError => Self::AuthenticationError {
|
||||
provider,
|
||||
message: format!("Payment required: {}", error.message),
|
||||
},
|
||||
PermissionError => Self::PermissionError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
RequestTimedOut => Self::HttpResponseError {
|
||||
provider,
|
||||
status_code: StatusCode::REQUEST_TIMEOUT,
|
||||
message: error.message,
|
||||
},
|
||||
RateLimitError => Self::RateLimitExceeded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
ApiError => Self::ApiInternalServerError {
|
||||
provider,
|
||||
message: error.message,
|
||||
},
|
||||
OverloadedError => Self::ServerOverloaded {
|
||||
provider,
|
||||
retry_after: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the format used to define the input schema for a language model tool.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
|
||||
pub enum LanguageModelToolSchemaFormat {
|
||||
|
|
|
@ -152,6 +152,7 @@ impl State {
|
|||
.open_router
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
|
@ -161,11 +162,11 @@ impl State {
|
|||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key)
|
||||
.context(format!("invalid {} API key", PROVIDER_NAME))?,
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
|
@ -183,7 +184,9 @@ impl State {
|
|||
let api_url = settings.api_url.clone();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let models = list_models(http_client.as_ref(), &api_url).await?;
|
||||
let models = list_models(http_client.as_ref(), &api_url)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("OpenRouter error: {:?}", e))?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.available_models = models;
|
||||
|
@ -334,27 +337,37 @@ impl OpenRouterLanguageModel {
|
|||
&self,
|
||||
request: open_router::Request,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<ResponseStreamEvent, open_router::OpenRouterError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!(
|
||||
"App state dropped: Unable to read API key or API URL from the application state"
|
||||
)))
|
||||
return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"App state dropped"
|
||||
))))
|
||||
.boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?;
|
||||
async move {
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
};
|
||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
Ok(response)
|
||||
});
|
||||
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
request.await.map_err(Into::into)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -435,12 +448,12 @@ impl LanguageModel for OpenRouterLanguageModel {
|
|||
>,
|
||||
> {
|
||||
let request = into_open_router(request, &self.model, self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
let mapper = OpenRouterEventMapper::new();
|
||||
Ok(mapper.map_stream(completions.await?).boxed())
|
||||
}
|
||||
.boxed()
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.await?;
|
||||
Ok(OpenRouterEventMapper::new().map_stream(response))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -608,13 +621,17 @@ impl OpenRouterEventMapper {
|
|||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||
events: Pin<
|
||||
Box<
|
||||
dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
|
||||
>,
|
||||
>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Ok(event) => self.map_event(event),
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
|
||||
Err(error) => vec![Err(error.into())],
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
@ -22,4 +22,6 @@ http_client.workspace = true
|
|||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
thiserror.workspace = true
|
||||
strum.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
|
|
@ -1,12 +1,31 @@
|
|||
use anyhow::{Context, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::convert::TryFrom;
|
||||
use std::{convert::TryFrom, io, time::Duration};
|
||||
use strum::EnumString;
|
||||
use thiserror::Error;
|
||||
|
||||
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
|
||||
if let Some(reset) = headers.get("X-RateLimit-Reset") {
|
||||
if let Ok(s) = reset.to_str() {
|
||||
if let Ok(epoch_ms) = s.parse::<u64>() {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64;
|
||||
if epoch_ms > now {
|
||||
return Some(std::time::Duration::from_millis(epoch_ms - now));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
|
||||
opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
|
||||
}
|
||||
|
@ -413,76 +432,12 @@ pub struct ModelArchitecture {
|
|||
pub input_modalities: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn complete(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<Response> {
|
||||
let uri = format!("{api_url}/chat/completions");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("HTTP-Referer", "https://zed.dev")
|
||||
.header("X-Title", "Zed Editor");
|
||||
|
||||
let mut request_body = request;
|
||||
request_body.stream = false;
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
let response: Response = serde_json::from_str(&body)?;
|
||||
Ok(response)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenRouterResponse {
|
||||
error: OpenRouterError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenRouterError {
|
||||
message: String,
|
||||
#[serde(default)]
|
||||
code: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenRouterResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => {
|
||||
let error_message = if !response.error.code.is_empty() {
|
||||
format!("{}: {}", response.error.code, response.error.message)
|
||||
} else {
|
||||
response.error.message
|
||||
};
|
||||
|
||||
Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {}",
|
||||
error_message
|
||||
))
|
||||
}
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
|
||||
) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
|
||||
let uri = format!("{api_url}/chat/completions");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
|
@ -492,8 +447,15 @@ pub async fn stream_completion(
|
|||
.header("HTTP-Referer", "https://zed.dev")
|
||||
.header("X-Title", "Zed Editor");
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
let request = request_builder
|
||||
.body(AsyncBody::from(
|
||||
serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
|
||||
))
|
||||
.map_err(OpenRouterError::BuildRequestBody)?;
|
||||
let mut response = client
|
||||
.send(request)
|
||||
.await
|
||||
.map_err(OpenRouterError::HttpSend)?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
|
@ -513,86 +475,85 @@ pub async fn stream_completion(
|
|||
match serde_json::from_str::<ResponseStreamEvent>(line) {
|
||||
Ok(response) => Some(Ok(response)),
|
||||
Err(error) => {
|
||||
#[derive(Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<ErrorResponse>(line) {
|
||||
Ok(err_response) => Some(Err(anyhow!(err_response.error))),
|
||||
Err(_) => {
|
||||
if line.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Err(anyhow!(
|
||||
"Failed to parse response: {}. Original content: '{}'",
|
||||
error, line
|
||||
)))
|
||||
}
|
||||
}
|
||||
if line.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Err(OpenRouterError::DeserializeResponse(error)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let code = ApiErrorCode::from_status(response.status().as_u16());
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(OpenRouterError::ReadResponse)?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenRouterResponse {
|
||||
error: OpenRouterError,
|
||||
}
|
||||
let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
|
||||
Ok(OpenRouterErrorResponse { error }) => error,
|
||||
Err(_) => OpenRouterErrorBody {
|
||||
code: response.status().as_u16(),
|
||||
message: body,
|
||||
metadata: None,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenRouterError {
|
||||
message: String,
|
||||
#[serde(default)]
|
||||
code: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenRouterResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => {
|
||||
let error_message = if !response.error.code.is_empty() {
|
||||
format!("{}: {}", response.error.code, response.error.message)
|
||||
} else {
|
||||
response.error.message
|
||||
};
|
||||
|
||||
Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {}",
|
||||
error_message
|
||||
))
|
||||
match code {
|
||||
ApiErrorCode::RateLimitError => {
|
||||
let retry_after = extract_retry_after(response.headers());
|
||||
Err(OpenRouterError::RateLimit {
|
||||
retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
|
||||
})
|
||||
}
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
ApiErrorCode::OverloadedError => {
|
||||
let retry_after = extract_retry_after(response.headers());
|
||||
Err(OpenRouterError::ServerOverloaded { retry_after })
|
||||
}
|
||||
_ => Err(OpenRouterError::ApiError(ApiError {
|
||||
code: code,
|
||||
message: error_response.message,
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
|
||||
pub async fn list_models(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
) -> Result<Vec<Model>, OpenRouterError> {
|
||||
let uri = format!("{api_url}/models");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
.uri(uri)
|
||||
.header("Accept", "application/json");
|
||||
|
||||
let request = request_builder.body(AsyncBody::default())?;
|
||||
let mut response = client.send(request).await?;
|
||||
let request = request_builder
|
||||
.body(AsyncBody::default())
|
||||
.map_err(OpenRouterError::BuildRequestBody)?;
|
||||
let mut response = client
|
||||
.send(request)
|
||||
.await
|
||||
.map_err(OpenRouterError::HttpSend)?;
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(OpenRouterError::ReadResponse)?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let response: ListModelsResponse =
|
||||
serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
|
||||
serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
|
||||
|
||||
let models = response
|
||||
.data
|
||||
|
@ -637,10 +598,141 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<M
|
|||
|
||||
Ok(models)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
))
|
||||
let code = ApiErrorCode::from_status(response.status().as_u16());
|
||||
|
||||
let mut body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut body)
|
||||
.await
|
||||
.map_err(OpenRouterError::ReadResponse)?;
|
||||
|
||||
let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
|
||||
Ok(OpenRouterErrorResponse { error }) => error,
|
||||
Err(_) => OpenRouterErrorBody {
|
||||
code: response.status().as_u16(),
|
||||
message: body,
|
||||
metadata: None,
|
||||
},
|
||||
};
|
||||
|
||||
match code {
|
||||
ApiErrorCode::RateLimitError => {
|
||||
let retry_after = extract_retry_after(response.headers());
|
||||
Err(OpenRouterError::RateLimit {
|
||||
retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
|
||||
})
|
||||
}
|
||||
ApiErrorCode::OverloadedError => {
|
||||
let retry_after = extract_retry_after(response.headers());
|
||||
Err(OpenRouterError::ServerOverloaded { retry_after })
|
||||
}
|
||||
_ => Err(OpenRouterError::ApiError(ApiError {
|
||||
code: code,
|
||||
message: error_response.message,
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum OpenRouterError {
|
||||
/// Failed to serialize the HTTP request body to JSON
|
||||
SerializeRequest(serde_json::Error),
|
||||
|
||||
/// Failed to construct the HTTP request body
|
||||
BuildRequestBody(http::Error),
|
||||
|
||||
/// Failed to send the HTTP request
|
||||
HttpSend(anyhow::Error),
|
||||
|
||||
/// Failed to deserialize the response from JSON
|
||||
DeserializeResponse(serde_json::Error),
|
||||
|
||||
/// Failed to read from response stream
|
||||
ReadResponse(io::Error),
|
||||
|
||||
/// Rate limit exceeded
|
||||
RateLimit { retry_after: Duration },
|
||||
|
||||
/// Server overloaded
|
||||
ServerOverloaded { retry_after: Option<Duration> },
|
||||
|
||||
/// API returned an error response
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct OpenRouterErrorBody {
|
||||
pub code: u16,
|
||||
pub message: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct OpenRouterErrorResponse {
|
||||
pub error: OpenRouterErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Error)]
|
||||
#[error("OpenRouter API Error: {code}: {message}")]
|
||||
pub struct ApiError {
|
||||
pub code: ApiErrorCode,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// An OpenROuter API error code.
|
||||
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum ApiErrorCode {
|
||||
/// 400: Bad Request (invalid or missing params, CORS)
|
||||
InvalidRequestError,
|
||||
/// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
|
||||
AuthenticationError,
|
||||
/// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
|
||||
PaymentRequiredError,
|
||||
/// 403: Your chosen model requires moderation and your input was flagged
|
||||
PermissionError,
|
||||
/// 408: Your request timed out
|
||||
RequestTimedOut,
|
||||
/// 429: You are being rate limited
|
||||
RateLimitError,
|
||||
/// 502: Your chosen model is down or we received an invalid response from it
|
||||
ApiError,
|
||||
/// 503: There is no available model provider that meets your routing requirements
|
||||
OverloadedError,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ApiErrorCode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = match self {
|
||||
ApiErrorCode::InvalidRequestError => "invalid_request_error",
|
||||
ApiErrorCode::AuthenticationError => "authentication_error",
|
||||
ApiErrorCode::PaymentRequiredError => "payment_required_error",
|
||||
ApiErrorCode::PermissionError => "permission_error",
|
||||
ApiErrorCode::RequestTimedOut => "request_timed_out",
|
||||
ApiErrorCode::RateLimitError => "rate_limit_error",
|
||||
ApiErrorCode::ApiError => "api_error",
|
||||
ApiErrorCode::OverloadedError => "overloaded_error",
|
||||
};
|
||||
write!(f, "{s}")
|
||||
}
|
||||
}
|
||||
|
||||
impl ApiErrorCode {
|
||||
pub fn from_status(status: u16) -> Self {
|
||||
match status {
|
||||
400 => ApiErrorCode::InvalidRequestError,
|
||||
401 => ApiErrorCode::AuthenticationError,
|
||||
402 => ApiErrorCode::PaymentRequiredError,
|
||||
403 => ApiErrorCode::PermissionError,
|
||||
408 => ApiErrorCode::RequestTimedOut,
|
||||
429 => ApiErrorCode::RateLimitError,
|
||||
502 => ApiErrorCode::ApiError,
|
||||
503 => ApiErrorCode::OverloadedError,
|
||||
_ => ApiErrorCode::ApiError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue