Merge 5ab18e679f
into c14d84cfdb
This commit is contained in:
commit
bacf7396a4
6 changed files with 332 additions and 150 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -9110,6 +9110,7 @@ dependencies = [
|
||||||
"icons",
|
"icons",
|
||||||
"image",
|
"image",
|
||||||
"log",
|
"log",
|
||||||
|
"open_router",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"proto",
|
"proto",
|
||||||
"schemars",
|
"schemars",
|
||||||
|
@ -11192,6 +11193,8 @@ dependencies = [
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"strum 0.27.1",
|
||||||
|
"thiserror 2.0.12",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ test-support = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anthropic = { workspace = true, features = ["schemars"] }
|
anthropic = { workspace = true, features = ["schemars"] }
|
||||||
|
open_router.workspace = true
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
base64.workspace = true
|
base64.workspace = true
|
||||||
client.workspace = true
|
client.workspace = true
|
||||||
|
|
|
@ -17,6 +17,7 @@ use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||||
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||||
use http_client::{StatusCode, http};
|
use http_client::{StatusCode, http};
|
||||||
use icons::IconName;
|
use icons::IconName;
|
||||||
|
use open_router::OpenRouterError;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||||
|
@ -347,6 +348,72 @@ impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<OpenRouterError> for LanguageModelCompletionError {
|
||||||
|
fn from(error: OpenRouterError) -> Self {
|
||||||
|
let provider = LanguageModelProviderName::new("OpenRouter");
|
||||||
|
match error {
|
||||||
|
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
||||||
|
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
||||||
|
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
|
||||||
|
OpenRouterError::DeserializeResponse(error) => {
|
||||||
|
Self::DeserializeResponse { provider, error }
|
||||||
|
}
|
||||||
|
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
||||||
|
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||||
|
provider,
|
||||||
|
retry_after: Some(retry_after),
|
||||||
|
},
|
||||||
|
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
||||||
|
provider,
|
||||||
|
retry_after,
|
||||||
|
},
|
||||||
|
OpenRouterError::ApiError(api_error) => api_error.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<open_router::ApiError> for LanguageModelCompletionError {
|
||||||
|
fn from(error: open_router::ApiError) -> Self {
|
||||||
|
use open_router::ApiErrorCode::*;
|
||||||
|
let provider = LanguageModelProviderName::new("OpenRouter");
|
||||||
|
match error.code {
|
||||||
|
InvalidRequestError => Self::BadRequestFormat {
|
||||||
|
provider,
|
||||||
|
message: error.message,
|
||||||
|
},
|
||||||
|
AuthenticationError => Self::AuthenticationError {
|
||||||
|
provider,
|
||||||
|
message: error.message,
|
||||||
|
},
|
||||||
|
PaymentRequiredError => Self::AuthenticationError {
|
||||||
|
provider,
|
||||||
|
message: format!("Payment required: {}", error.message),
|
||||||
|
},
|
||||||
|
PermissionError => Self::PermissionError {
|
||||||
|
provider,
|
||||||
|
message: error.message,
|
||||||
|
},
|
||||||
|
RequestTimedOut => Self::HttpResponseError {
|
||||||
|
provider,
|
||||||
|
status_code: StatusCode::REQUEST_TIMEOUT,
|
||||||
|
message: error.message,
|
||||||
|
},
|
||||||
|
RateLimitError => Self::RateLimitExceeded {
|
||||||
|
provider,
|
||||||
|
retry_after: None,
|
||||||
|
},
|
||||||
|
ApiError => Self::ApiInternalServerError {
|
||||||
|
provider,
|
||||||
|
message: error.message,
|
||||||
|
},
|
||||||
|
OverloadedError => Self::ServerOverloaded {
|
||||||
|
provider,
|
||||||
|
retry_after: None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Indicates the format used to define the input schema for a language model tool.
|
/// Indicates the format used to define the input schema for a language model tool.
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
|
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
|
||||||
pub enum LanguageModelToolSchemaFormat {
|
pub enum LanguageModelToolSchemaFormat {
|
||||||
|
|
|
@ -152,6 +152,7 @@ impl State {
|
||||||
.open_router
|
.open_router
|
||||||
.api_url
|
.api_url
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
cx.spawn(async move |this, cx| {
|
cx.spawn(async move |this, cx| {
|
||||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
|
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
|
||||||
(api_key, true)
|
(api_key, true)
|
||||||
|
@ -161,11 +162,11 @@ impl State {
|
||||||
.await?
|
.await?
|
||||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||||
(
|
(
|
||||||
String::from_utf8(api_key)
|
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||||
.context(format!("invalid {} API key", PROVIDER_NAME))?,
|
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.api_key = Some(api_key);
|
this.api_key = Some(api_key);
|
||||||
this.api_key_from_env = from_env;
|
this.api_key_from_env = from_env;
|
||||||
|
@ -183,7 +184,9 @@ impl State {
|
||||||
let api_url = settings.api_url.clone();
|
let api_url = settings.api_url.clone();
|
||||||
|
|
||||||
cx.spawn(async move |this, cx| {
|
cx.spawn(async move |this, cx| {
|
||||||
let models = list_models(http_client.as_ref(), &api_url).await?;
|
let models = list_models(http_client.as_ref(), &api_url)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("OpenRouter error: {:?}", e))?;
|
||||||
|
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.available_models = models;
|
this.available_models = models;
|
||||||
|
@ -334,27 +337,37 @@ impl OpenRouterLanguageModel {
|
||||||
&self,
|
&self,
|
||||||
request: open_router::Request,
|
request: open_router::Request,
|
||||||
cx: &AsyncApp,
|
cx: &AsyncApp,
|
||||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
) -> BoxFuture<
|
||||||
{
|
'static,
|
||||||
|
Result<
|
||||||
|
futures::stream::BoxStream<
|
||||||
|
'static,
|
||||||
|
Result<ResponseStreamEvent, open_router::OpenRouterError>,
|
||||||
|
>,
|
||||||
|
LanguageModelCompletionError,
|
||||||
|
>,
|
||||||
|
> {
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
|
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
|
||||||
(state.api_key.clone(), settings.api_url.clone())
|
(state.api_key.clone(), settings.api_url.clone())
|
||||||
}) else {
|
}) else {
|
||||||
return futures::future::ready(Err(anyhow!(
|
return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
|
||||||
"App state dropped: Unable to read API key or API URL from the application state"
|
"App state dropped"
|
||||||
)))
|
))))
|
||||||
.boxed();
|
.boxed();
|
||||||
};
|
};
|
||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
async move {
|
||||||
let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?;
|
let Some(api_key) = api_key else {
|
||||||
|
return Err(LanguageModelCompletionError::NoApiKey {
|
||||||
|
provider: PROVIDER_NAME,
|
||||||
|
});
|
||||||
|
};
|
||||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
let response = request.await?;
|
request.await.map_err(Into::into)
|
||||||
Ok(response)
|
}
|
||||||
});
|
.boxed()
|
||||||
|
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -435,12 +448,12 @@ impl LanguageModel for OpenRouterLanguageModel {
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_open_router(request, &self.model, self.max_output_tokens());
|
let request = into_open_router(request, &self.model, self.max_output_tokens());
|
||||||
let completions = self.stream_completion(request, cx);
|
let request = self.stream_completion(request, cx);
|
||||||
async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let mapper = OpenRouterEventMapper::new();
|
let response = request.await?;
|
||||||
Ok(mapper.map_stream(completions.await?).boxed())
|
Ok(OpenRouterEventMapper::new().map_stream(response))
|
||||||
}
|
});
|
||||||
.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -608,13 +621,17 @@ impl OpenRouterEventMapper {
|
||||||
|
|
||||||
pub fn map_stream(
|
pub fn map_stream(
|
||||||
mut self,
|
mut self,
|
||||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
events: Pin<
|
||||||
|
Box<
|
||||||
|
dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
|
||||||
|
>,
|
||||||
|
>,
|
||||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||||
{
|
{
|
||||||
events.flat_map(move |event| {
|
events.flat_map(move |event| {
|
||||||
futures::stream::iter(match event {
|
futures::stream::iter(match event {
|
||||||
Ok(event) => self.map_event(event),
|
Ok(event) => self.map_event(event),
|
||||||
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
|
Err(error) => vec![Err(error.into())],
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,4 +22,6 @@ http_client.workspace = true
|
||||||
schemars = { workspace = true, optional = true }
|
schemars = { workspace = true, optional = true }
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
strum.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
|
@ -1,12 +1,31 @@
|
||||||
use anyhow::{Context, Result, anyhow};
|
use anyhow::{Result, anyhow};
|
||||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::convert::TryFrom;
|
use std::{convert::TryFrom, io, time::Duration};
|
||||||
|
use strum::EnumString;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
|
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
|
||||||
|
|
||||||
|
fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
|
||||||
|
if let Some(reset) = headers.get("X-RateLimit-Reset") {
|
||||||
|
if let Ok(s) = reset.to_str() {
|
||||||
|
if let Ok(epoch_ms) = s.parse::<u64>() {
|
||||||
|
let now = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as u64;
|
||||||
|
if epoch_ms > now {
|
||||||
|
return Some(std::time::Duration::from_millis(epoch_ms - now));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
|
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
|
||||||
opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
|
opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
|
||||||
}
|
}
|
||||||
|
@ -413,76 +432,12 @@ pub struct ModelArchitecture {
|
||||||
pub input_modalities: Vec<String>,
|
pub input_modalities: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn complete(
|
|
||||||
client: &dyn HttpClient,
|
|
||||||
api_url: &str,
|
|
||||||
api_key: &str,
|
|
||||||
request: Request,
|
|
||||||
) -> Result<Response> {
|
|
||||||
let uri = format!("{api_url}/chat/completions");
|
|
||||||
let request_builder = HttpRequest::builder()
|
|
||||||
.method(Method::POST)
|
|
||||||
.uri(uri)
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
|
||||||
.header("HTTP-Referer", "https://zed.dev")
|
|
||||||
.header("X-Title", "Zed Editor");
|
|
||||||
|
|
||||||
let mut request_body = request;
|
|
||||||
request_body.stream = false;
|
|
||||||
|
|
||||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
|
|
||||||
let mut response = client.send(request).await?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
let response: Response = serde_json::from_str(&body)?;
|
|
||||||
Ok(response)
|
|
||||||
} else {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenRouterResponse {
|
|
||||||
error: OpenRouterError,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct OpenRouterError {
|
|
||||||
message: String,
|
|
||||||
#[serde(default)]
|
|
||||||
code: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
match serde_json::from_str::<OpenRouterResponse>(&body) {
|
|
||||||
Ok(response) if !response.error.message.is_empty() => {
|
|
||||||
let error_message = if !response.error.code.is_empty() {
|
|
||||||
format!("{}: {}", response.error.code, response.error.message)
|
|
||||||
} else {
|
|
||||||
response.error.message
|
|
||||||
};
|
|
||||||
|
|
||||||
Err(anyhow!(
|
|
||||||
"Failed to connect to OpenRouter API: {}",
|
|
||||||
error_message
|
|
||||||
))
|
|
||||||
}
|
|
||||||
_ => Err(anyhow!(
|
|
||||||
"Failed to connect to OpenRouter API: {} {}",
|
|
||||||
response.status(),
|
|
||||||
body,
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn stream_completion(
|
pub async fn stream_completion(
|
||||||
client: &dyn HttpClient,
|
client: &dyn HttpClient,
|
||||||
api_url: &str,
|
api_url: &str,
|
||||||
api_key: &str,
|
api_key: &str,
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
|
) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
|
||||||
let uri = format!("{api_url}/chat/completions");
|
let uri = format!("{api_url}/chat/completions");
|
||||||
let request_builder = HttpRequest::builder()
|
let request_builder = HttpRequest::builder()
|
||||||
.method(Method::POST)
|
.method(Method::POST)
|
||||||
|
@ -492,8 +447,15 @@ pub async fn stream_completion(
|
||||||
.header("HTTP-Referer", "https://zed.dev")
|
.header("HTTP-Referer", "https://zed.dev")
|
||||||
.header("X-Title", "Zed Editor");
|
.header("X-Title", "Zed Editor");
|
||||||
|
|
||||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
let request = request_builder
|
||||||
let mut response = client.send(request).await?;
|
.body(AsyncBody::from(
|
||||||
|
serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
|
||||||
|
))
|
||||||
|
.map_err(OpenRouterError::BuildRequestBody)?;
|
||||||
|
let mut response = client
|
||||||
|
.send(request)
|
||||||
|
.await
|
||||||
|
.map_err(OpenRouterError::HttpSend)?;
|
||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
let reader = BufReader::new(response.into_body());
|
let reader = BufReader::new(response.into_body());
|
||||||
|
@ -513,86 +475,85 @@ pub async fn stream_completion(
|
||||||
match serde_json::from_str::<ResponseStreamEvent>(line) {
|
match serde_json::from_str::<ResponseStreamEvent>(line) {
|
||||||
Ok(response) => Some(Ok(response)),
|
Ok(response) => Some(Ok(response)),
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
#[derive(Deserialize)]
|
if line.trim().is_empty() {
|
||||||
struct ErrorResponse {
|
None
|
||||||
error: String,
|
} else {
|
||||||
}
|
Some(Err(OpenRouterError::DeserializeResponse(error)))
|
||||||
|
|
||||||
match serde_json::from_str::<ErrorResponse>(line) {
|
|
||||||
Ok(err_response) => Some(Err(anyhow!(err_response.error))),
|
|
||||||
Err(_) => {
|
|
||||||
if line.trim().is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(Err(anyhow!(
|
|
||||||
"Failed to parse response: {}. Original content: '{}'",
|
|
||||||
error, line
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(error) => Some(Err(anyhow!(error))),
|
Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.boxed())
|
.boxed())
|
||||||
} else {
|
} else {
|
||||||
|
let code = ApiErrorCode::from_status(response.status().as_u16());
|
||||||
|
|
||||||
let mut body = String::new();
|
let mut body = String::new();
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
response
|
||||||
|
.body_mut()
|
||||||
|
.read_to_string(&mut body)
|
||||||
|
.await
|
||||||
|
.map_err(OpenRouterError::ReadResponse)?;
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
|
||||||
struct OpenRouterResponse {
|
Ok(OpenRouterErrorResponse { error }) => error,
|
||||||
error: OpenRouterError,
|
Err(_) => OpenRouterErrorBody {
|
||||||
}
|
code: response.status().as_u16(),
|
||||||
|
message: body,
|
||||||
|
metadata: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
match code {
|
||||||
struct OpenRouterError {
|
ApiErrorCode::RateLimitError => {
|
||||||
message: String,
|
let retry_after = extract_retry_after(response.headers());
|
||||||
#[serde(default)]
|
Err(OpenRouterError::RateLimit {
|
||||||
code: String,
|
retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
|
||||||
}
|
})
|
||||||
|
|
||||||
match serde_json::from_str::<OpenRouterResponse>(&body) {
|
|
||||||
Ok(response) if !response.error.message.is_empty() => {
|
|
||||||
let error_message = if !response.error.code.is_empty() {
|
|
||||||
format!("{}: {}", response.error.code, response.error.message)
|
|
||||||
} else {
|
|
||||||
response.error.message
|
|
||||||
};
|
|
||||||
|
|
||||||
Err(anyhow!(
|
|
||||||
"Failed to connect to OpenRouter API: {}",
|
|
||||||
error_message
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
_ => Err(anyhow!(
|
ApiErrorCode::OverloadedError => {
|
||||||
"Failed to connect to OpenRouter API: {} {}",
|
let retry_after = extract_retry_after(response.headers());
|
||||||
response.status(),
|
Err(OpenRouterError::ServerOverloaded { retry_after })
|
||||||
body,
|
}
|
||||||
)),
|
_ => Err(OpenRouterError::ApiError(ApiError {
|
||||||
|
code: code,
|
||||||
|
message: error_response.message,
|
||||||
|
})),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
|
pub async fn list_models(
|
||||||
|
client: &dyn HttpClient,
|
||||||
|
api_url: &str,
|
||||||
|
) -> Result<Vec<Model>, OpenRouterError> {
|
||||||
let uri = format!("{api_url}/models");
|
let uri = format!("{api_url}/models");
|
||||||
let request_builder = HttpRequest::builder()
|
let request_builder = HttpRequest::builder()
|
||||||
.method(Method::GET)
|
.method(Method::GET)
|
||||||
.uri(uri)
|
.uri(uri)
|
||||||
.header("Accept", "application/json");
|
.header("Accept", "application/json");
|
||||||
|
|
||||||
let request = request_builder.body(AsyncBody::default())?;
|
let request = request_builder
|
||||||
let mut response = client.send(request).await?;
|
.body(AsyncBody::default())
|
||||||
|
.map_err(OpenRouterError::BuildRequestBody)?;
|
||||||
|
let mut response = client
|
||||||
|
.send(request)
|
||||||
|
.await
|
||||||
|
.map_err(OpenRouterError::HttpSend)?;
|
||||||
|
|
||||||
let mut body = String::new();
|
let mut body = String::new();
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
response
|
||||||
|
.body_mut()
|
||||||
|
.read_to_string(&mut body)
|
||||||
|
.await
|
||||||
|
.map_err(OpenRouterError::ReadResponse)?;
|
||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
let response: ListModelsResponse =
|
let response: ListModelsResponse =
|
||||||
serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
|
serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
|
||||||
|
|
||||||
let models = response
|
let models = response
|
||||||
.data
|
.data
|
||||||
|
@ -637,10 +598,141 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<M
|
||||||
|
|
||||||
Ok(models)
|
Ok(models)
|
||||||
} else {
|
} else {
|
||||||
Err(anyhow!(
|
let code = ApiErrorCode::from_status(response.status().as_u16());
|
||||||
"Failed to connect to OpenRouter API: {} {}",
|
|
||||||
response.status(),
|
let mut body = String::new();
|
||||||
body,
|
response
|
||||||
))
|
.body_mut()
|
||||||
|
.read_to_string(&mut body)
|
||||||
|
.await
|
||||||
|
.map_err(OpenRouterError::ReadResponse)?;
|
||||||
|
|
||||||
|
let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
|
||||||
|
Ok(OpenRouterErrorResponse { error }) => error,
|
||||||
|
Err(_) => OpenRouterErrorBody {
|
||||||
|
code: response.status().as_u16(),
|
||||||
|
message: body,
|
||||||
|
metadata: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
match code {
|
||||||
|
ApiErrorCode::RateLimitError => {
|
||||||
|
let retry_after = extract_retry_after(response.headers());
|
||||||
|
Err(OpenRouterError::RateLimit {
|
||||||
|
retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ApiErrorCode::OverloadedError => {
|
||||||
|
let retry_after = extract_retry_after(response.headers());
|
||||||
|
Err(OpenRouterError::ServerOverloaded { retry_after })
|
||||||
|
}
|
||||||
|
_ => Err(OpenRouterError::ApiError(ApiError {
|
||||||
|
code: code,
|
||||||
|
message: error_response.message,
|
||||||
|
})),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum OpenRouterError {
|
||||||
|
/// Failed to serialize the HTTP request body to JSON
|
||||||
|
SerializeRequest(serde_json::Error),
|
||||||
|
|
||||||
|
/// Failed to construct the HTTP request body
|
||||||
|
BuildRequestBody(http::Error),
|
||||||
|
|
||||||
|
/// Failed to send the HTTP request
|
||||||
|
HttpSend(anyhow::Error),
|
||||||
|
|
||||||
|
/// Failed to deserialize the response from JSON
|
||||||
|
DeserializeResponse(serde_json::Error),
|
||||||
|
|
||||||
|
/// Failed to read from response stream
|
||||||
|
ReadResponse(io::Error),
|
||||||
|
|
||||||
|
/// Rate limit exceeded
|
||||||
|
RateLimit { retry_after: Duration },
|
||||||
|
|
||||||
|
/// Server overloaded
|
||||||
|
ServerOverloaded { retry_after: Option<Duration> },
|
||||||
|
|
||||||
|
/// API returned an error response
|
||||||
|
ApiError(ApiError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct OpenRouterErrorBody {
|
||||||
|
pub code: u16,
|
||||||
|
pub message: String,
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct OpenRouterErrorResponse {
|
||||||
|
pub error: OpenRouterErrorBody,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Error)]
|
||||||
|
#[error("OpenRouter API Error: {code}: {message}")]
|
||||||
|
pub struct ApiError {
|
||||||
|
pub code: ApiErrorCode,
|
||||||
|
pub message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An OpenROuter API error code.
|
||||||
|
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
|
||||||
|
#[strum(serialize_all = "snake_case")]
|
||||||
|
pub enum ApiErrorCode {
|
||||||
|
/// 400: Bad Request (invalid or missing params, CORS)
|
||||||
|
InvalidRequestError,
|
||||||
|
/// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
|
||||||
|
AuthenticationError,
|
||||||
|
/// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
|
||||||
|
PaymentRequiredError,
|
||||||
|
/// 403: Your chosen model requires moderation and your input was flagged
|
||||||
|
PermissionError,
|
||||||
|
/// 408: Your request timed out
|
||||||
|
RequestTimedOut,
|
||||||
|
/// 429: You are being rate limited
|
||||||
|
RateLimitError,
|
||||||
|
/// 502: Your chosen model is down or we received an invalid response from it
|
||||||
|
ApiError,
|
||||||
|
/// 503: There is no available model provider that meets your routing requirements
|
||||||
|
OverloadedError,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ApiErrorCode {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let s = match self {
|
||||||
|
ApiErrorCode::InvalidRequestError => "invalid_request_error",
|
||||||
|
ApiErrorCode::AuthenticationError => "authentication_error",
|
||||||
|
ApiErrorCode::PaymentRequiredError => "payment_required_error",
|
||||||
|
ApiErrorCode::PermissionError => "permission_error",
|
||||||
|
ApiErrorCode::RequestTimedOut => "request_timed_out",
|
||||||
|
ApiErrorCode::RateLimitError => "rate_limit_error",
|
||||||
|
ApiErrorCode::ApiError => "api_error",
|
||||||
|
ApiErrorCode::OverloadedError => "overloaded_error",
|
||||||
|
};
|
||||||
|
write!(f, "{s}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ApiErrorCode {
|
||||||
|
pub fn from_status(status: u16) -> Self {
|
||||||
|
match status {
|
||||||
|
400 => ApiErrorCode::InvalidRequestError,
|
||||||
|
401 => ApiErrorCode::AuthenticationError,
|
||||||
|
402 => ApiErrorCode::PaymentRequiredError,
|
||||||
|
403 => ApiErrorCode::PermissionError,
|
||||||
|
408 => ApiErrorCode::RequestTimedOut,
|
||||||
|
429 => ApiErrorCode::RateLimitError,
|
||||||
|
502 => ApiErrorCode::ApiError,
|
||||||
|
503 => ApiErrorCode::OverloadedError,
|
||||||
|
_ => ApiErrorCode::ApiError,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue