This commit is contained in:
Umesh Yadav 2025-08-26 15:08:05 +05:30 committed by GitHub
commit bacf7396a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 332 additions and 150 deletions

3
Cargo.lock generated
View file

@ -9110,6 +9110,7 @@ dependencies = [
"icons",
"image",
"log",
"open_router",
"parking_lot",
"proto",
"schemars",
@ -11192,6 +11193,8 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"thiserror 2.0.12",
"workspace-hack",
]

View file

@ -17,6 +17,7 @@ test-support = []
[dependencies]
anthropic = { workspace = true, features = ["schemars"] }
open_router.workspace = true
anyhow.workspace = true
base64.workspace = true
client.workspace = true

View file

@ -17,6 +17,7 @@ use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
@ -347,6 +348,72 @@ impl From<anthropic::ApiError> for LanguageModelCompletionError {
}
}
impl From<OpenRouterError> for LanguageModelCompletionError {
fn from(error: OpenRouterError) -> Self {
let provider = LanguageModelProviderName::new("OpenRouter");
match error {
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
OpenRouterError::DeserializeResponse(error) => {
Self::DeserializeResponse { provider, error }
}
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after,
},
OpenRouterError::ApiError(api_error) => api_error.into(),
}
}
}
impl From<open_router::ApiError> for LanguageModelCompletionError {
fn from(error: open_router::ApiError) -> Self {
use open_router::ApiErrorCode::*;
let provider = LanguageModelProviderName::new("OpenRouter");
match error.code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PaymentRequiredError => Self::AuthenticationError {
provider,
message: format!("Payment required: {}", error.message),
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
RequestTimedOut => Self::HttpResponseError {
provider,
status_code: StatusCode::REQUEST_TIMEOUT,
message: error.message,
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
}
}
}
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {

View file

@ -152,6 +152,7 @@ impl State {
.open_router
.api_url
.clone();
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
(api_key, true)
@ -161,11 +162,11 @@ impl State {
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key)
.context(format!("invalid {} API key", PROVIDER_NAME))?,
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
@ -183,7 +184,9 @@ impl State {
let api_url = settings.api_url.clone();
cx.spawn(async move |this, cx| {
let models = list_models(http_client.as_ref(), &api_url).await?;
let models = list_models(http_client.as_ref(), &api_url)
.await
.map_err(|e| anyhow::anyhow!("OpenRouter error: {:?}", e))?;
this.update(cx, |this, cx| {
this.available_models = models;
@ -334,27 +337,37 @@ impl OpenRouterLanguageModel {
&self,
request: open_router::Request,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
) -> BoxFuture<
'static,
Result<
futures::stream::BoxStream<
'static,
Result<ResponseStreamEvent, open_router::OpenRouterError>,
>,
LanguageModelCompletionError,
>,
> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!(
"App state dropped: Unable to read API key or API URL from the application state"
)))
return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
"App state dropped"
))))
.boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?;
async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
Ok(response)
});
async move { Ok(future.await?.boxed()) }.boxed()
request.await.map_err(Into::into)
}
.boxed()
}
}
@ -435,12 +448,12 @@ impl LanguageModel for OpenRouterLanguageModel {
>,
> {
let request = into_open_router(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move {
let mapper = OpenRouterEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
}
.boxed()
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await?;
Ok(OpenRouterEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
@ -608,13 +621,17 @@ impl OpenRouterEventMapper {
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
events: Pin<
Box<
dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
>,
>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
Err(error) => vec![Err(error.into())],
})
})
}

View file

@ -22,4 +22,6 @@ http_client.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
strum.workspace = true
workspace-hack.workspace = true

View file

@ -1,12 +1,31 @@
use anyhow::{Context, Result, anyhow};
use anyhow::{Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::convert::TryFrom;
use std::{convert::TryFrom, io, time::Duration};
use strum::EnumString;
use thiserror::Error;
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
if let Some(reset) = headers.get("X-RateLimit-Reset") {
if let Ok(s) = reset.to_str() {
if let Ok(epoch_ms) = s.parse::<u64>() {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
if epoch_ms > now {
return Some(std::time::Duration::from_millis(epoch_ms - now));
}
}
}
}
None
}
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
}
@ -413,76 +432,12 @@ pub struct ModelArchitecture {
pub input_modalities: Vec<String>,
}
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
) -> Result<Response> {
let uri = format!("{api_url}/chat/completions");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.header("HTTP-Referer", "https://zed.dev")
.header("X-Title", "Zed Editor");
let mut request_body = request;
request_body.stream = false;
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: Response = serde_json::from_str(&body)?;
Ok(response)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenRouterResponse {
error: OpenRouterError,
}
#[derive(Deserialize)]
struct OpenRouterError {
message: String,
#[serde(default)]
code: String,
}
match serde_json::from_str::<OpenRouterResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => {
let error_message = if !response.error.code.is_empty() {
format!("{}: {}", response.error.code, response.error.message)
} else {
response.error.message
};
Err(anyhow!(
"Failed to connect to OpenRouter API: {}",
error_message
))
}
_ => Err(anyhow!(
"Failed to connect to OpenRouter API: {} {}",
response.status(),
body,
)),
}
}
}
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
let uri = format!("{api_url}/chat/completions");
let request_builder = HttpRequest::builder()
.method(Method::POST)
@ -492,8 +447,15 @@ pub async fn stream_completion(
.header("HTTP-Referer", "https://zed.dev")
.header("X-Title", "Zed Editor");
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;
let request = request_builder
.body(AsyncBody::from(
serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
))
.map_err(OpenRouterError::BuildRequestBody)?;
let mut response = client
.send(request)
.await
.map_err(OpenRouterError::HttpSend)?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
@ -513,86 +475,85 @@ pub async fn stream_completion(
match serde_json::from_str::<ResponseStreamEvent>(line) {
Ok(response) => Some(Ok(response)),
Err(error) => {
#[derive(Deserialize)]
struct ErrorResponse {
error: String,
}
match serde_json::from_str::<ErrorResponse>(line) {
Ok(err_response) => Some(Err(anyhow!(err_response.error))),
Err(_) => {
if line.trim().is_empty() {
None
} else {
Some(Err(anyhow!(
"Failed to parse response: {}. Original content: '{}'",
error, line
)))
}
}
if line.trim().is_empty() {
None
} else {
Some(Err(OpenRouterError::DeserializeResponse(error)))
}
}
}
}
}
Err(error) => Some(Err(anyhow!(error))),
Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
}
})
.boxed())
} else {
let code = ApiErrorCode::from_status(response.status().as_u16());
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
response
.body_mut()
.read_to_string(&mut body)
.await
.map_err(OpenRouterError::ReadResponse)?;
#[derive(Deserialize)]
struct OpenRouterResponse {
error: OpenRouterError,
}
let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
Ok(OpenRouterErrorResponse { error }) => error,
Err(_) => OpenRouterErrorBody {
code: response.status().as_u16(),
message: body,
metadata: None,
},
};
#[derive(Deserialize)]
struct OpenRouterError {
message: String,
#[serde(default)]
code: String,
}
match serde_json::from_str::<OpenRouterResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => {
let error_message = if !response.error.code.is_empty() {
format!("{}: {}", response.error.code, response.error.message)
} else {
response.error.message
};
Err(anyhow!(
"Failed to connect to OpenRouter API: {}",
error_message
))
match code {
ApiErrorCode::RateLimitError => {
let retry_after = extract_retry_after(response.headers());
Err(OpenRouterError::RateLimit {
retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
})
}
_ => Err(anyhow!(
"Failed to connect to OpenRouter API: {} {}",
response.status(),
body,
)),
ApiErrorCode::OverloadedError => {
let retry_after = extract_retry_after(response.headers());
Err(OpenRouterError::ServerOverloaded { retry_after })
}
_ => Err(OpenRouterError::ApiError(ApiError {
code: code,
message: error_response.message,
})),
}
}
}
pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
pub async fn list_models(
client: &dyn HttpClient,
api_url: &str,
) -> Result<Vec<Model>, OpenRouterError> {
let uri = format!("{api_url}/models");
let request_builder = HttpRequest::builder()
.method(Method::GET)
.uri(uri)
.header("Accept", "application/json");
let request = request_builder.body(AsyncBody::default())?;
let mut response = client.send(request).await?;
let request = request_builder
.body(AsyncBody::default())
.map_err(OpenRouterError::BuildRequestBody)?;
let mut response = client
.send(request)
.await
.map_err(OpenRouterError::HttpSend)?;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
response
.body_mut()
.read_to_string(&mut body)
.await
.map_err(OpenRouterError::ReadResponse)?;
if response.status().is_success() {
let response: ListModelsResponse =
serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
let models = response
.data
@ -637,10 +598,141 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<M
Ok(models)
} else {
Err(anyhow!(
"Failed to connect to OpenRouter API: {} {}",
response.status(),
body,
))
let code = ApiErrorCode::from_status(response.status().as_u16());
let mut body = String::new();
response
.body_mut()
.read_to_string(&mut body)
.await
.map_err(OpenRouterError::ReadResponse)?;
let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
Ok(OpenRouterErrorResponse { error }) => error,
Err(_) => OpenRouterErrorBody {
code: response.status().as_u16(),
message: body,
metadata: None,
},
};
match code {
ApiErrorCode::RateLimitError => {
let retry_after = extract_retry_after(response.headers());
Err(OpenRouterError::RateLimit {
retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
})
}
ApiErrorCode::OverloadedError => {
let retry_after = extract_retry_after(response.headers());
Err(OpenRouterError::ServerOverloaded { retry_after })
}
_ => Err(OpenRouterError::ApiError(ApiError {
code: code,
message: error_response.message,
})),
}
}
}
#[derive(Debug)]
pub enum OpenRouterError {
/// Failed to serialize the HTTP request body to JSON
SerializeRequest(serde_json::Error),
/// Failed to construct the HTTP request body
BuildRequestBody(http::Error),
/// Failed to send the HTTP request
HttpSend(anyhow::Error),
/// Failed to deserialize the response from JSON
DeserializeResponse(serde_json::Error),
/// Failed to read from response stream
ReadResponse(io::Error),
/// Rate limit exceeded
RateLimit { retry_after: Duration },
/// Server overloaded
ServerOverloaded { retry_after: Option<Duration> },
/// API returned an error response
ApiError(ApiError),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
pub code: u16,
pub message: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
pub error: OpenRouterErrorBody,
}
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
pub code: ApiErrorCode,
pub message: String,
}
/// An OpenROuter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
/// 400: Bad Request (invalid or missing params, CORS)
InvalidRequestError,
/// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
AuthenticationError,
/// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
PaymentRequiredError,
/// 403: Your chosen model requires moderation and your input was flagged
PermissionError,
/// 408: Your request timed out
RequestTimedOut,
/// 429: You are being rate limited
RateLimitError,
/// 502: Your chosen model is down or we received an invalid response from it
ApiError,
/// 503: There is no available model provider that meets your routing requirements
OverloadedError,
}
impl std::fmt::Display for ApiErrorCode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
ApiErrorCode::InvalidRequestError => "invalid_request_error",
ApiErrorCode::AuthenticationError => "authentication_error",
ApiErrorCode::PaymentRequiredError => "payment_required_error",
ApiErrorCode::PermissionError => "permission_error",
ApiErrorCode::RequestTimedOut => "request_timed_out",
ApiErrorCode::RateLimitError => "rate_limit_error",
ApiErrorCode::ApiError => "api_error",
ApiErrorCode::OverloadedError => "overloaded_error",
};
write!(f, "{s}")
}
}
impl ApiErrorCode {
pub fn from_status(status: u16) -> Self {
match status {
400 => ApiErrorCode::InvalidRequestError,
401 => ApiErrorCode::AuthenticationError,
402 => ApiErrorCode::PaymentRequiredError,
403 => ApiErrorCode::PermissionError,
408 => ApiErrorCode::RequestTimedOut,
429 => ApiErrorCode::RateLimitError,
502 => ApiErrorCode::ApiError,
503 => ApiErrorCode::OverloadedError,
_ => ApiErrorCode::ApiError,
}
}
}