This commit is contained in:
Umesh Yadav 2025-08-26 15:08:05 +05:30 committed by GitHub
commit bacf7396a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 332 additions and 150 deletions

3
Cargo.lock generated
View file

@ -9110,6 +9110,7 @@ dependencies = [
"icons", "icons",
"image", "image",
"log", "log",
"open_router",
"parking_lot", "parking_lot",
"proto", "proto",
"schemars", "schemars",
@ -11192,6 +11193,8 @@ dependencies = [
"schemars", "schemars",
"serde", "serde",
"serde_json", "serde_json",
"strum 0.27.1",
"thiserror 2.0.12",
"workspace-hack", "workspace-hack",
] ]

View file

@ -17,6 +17,7 @@ test-support = []
[dependencies] [dependencies]
anthropic = { workspace = true, features = ["schemars"] } anthropic = { workspace = true, features = ["schemars"] }
open_router.workspace = true
anyhow.workspace = true anyhow.workspace = true
base64.workspace = true base64.workspace = true
client.workspace = true client.workspace = true

View file

@ -17,6 +17,7 @@ use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http}; use http_client::{StatusCode, http};
use icons::IconName; use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex; use parking_lot::Mutex;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Serialize, de::DeserializeOwned};
@ -347,6 +348,72 @@ impl From<anthropic::ApiError> for LanguageModelCompletionError {
} }
} }
impl From<OpenRouterError> for LanguageModelCompletionError {
fn from(error: OpenRouterError) -> Self {
let provider = LanguageModelProviderName::new("OpenRouter");
match error {
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
OpenRouterError::DeserializeResponse(error) => {
Self::DeserializeResponse { provider, error }
}
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after,
},
OpenRouterError::ApiError(api_error) => api_error.into(),
}
}
}
impl From<open_router::ApiError> for LanguageModelCompletionError {
fn from(error: open_router::ApiError) -> Self {
use open_router::ApiErrorCode::*;
let provider = LanguageModelProviderName::new("OpenRouter");
match error.code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PaymentRequiredError => Self::AuthenticationError {
provider,
message: format!("Payment required: {}", error.message),
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
RequestTimedOut => Self::HttpResponseError {
provider,
status_code: StatusCode::REQUEST_TIMEOUT,
message: error.message,
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
}
}
}
/// Indicates the format used to define the input schema for a language model tool. /// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat { pub enum LanguageModelToolSchemaFormat {

View file

@ -152,6 +152,7 @@ impl State {
.open_router .open_router
.api_url .api_url
.clone(); .clone();
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) { let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
(api_key, true) (api_key, true)
@ -161,11 +162,11 @@ impl State {
.await? .await?
.ok_or(AuthenticateError::CredentialsNotFound)?; .ok_or(AuthenticateError::CredentialsNotFound)?;
( (
String::from_utf8(api_key) String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
.context(format!("invalid {} API key", PROVIDER_NAME))?,
false, false,
) )
}; };
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.api_key = Some(api_key); this.api_key = Some(api_key);
this.api_key_from_env = from_env; this.api_key_from_env = from_env;
@ -183,7 +184,9 @@ impl State {
let api_url = settings.api_url.clone(); let api_url = settings.api_url.clone();
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let models = list_models(http_client.as_ref(), &api_url).await?; let models = list_models(http_client.as_ref(), &api_url)
.await
.map_err(|e| anyhow::anyhow!("OpenRouter error: {:?}", e))?;
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.available_models = models; this.available_models = models;
@ -334,27 +337,37 @@ impl OpenRouterLanguageModel {
&self, &self,
request: open_router::Request, request: open_router::Request,
cx: &AsyncApp, cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>> ) -> BoxFuture<
{ 'static,
Result<
futures::stream::BoxStream<
'static,
Result<ResponseStreamEvent, open_router::OpenRouterError>,
>,
LanguageModelCompletionError,
>,
> {
let http_client = self.http_client.clone(); let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).open_router; let settings = &AllLanguageModelSettings::get_global(cx).open_router;
(state.api_key.clone(), settings.api_url.clone()) (state.api_key.clone(), settings.api_url.clone())
}) else { }) else {
return futures::future::ready(Err(anyhow!( return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
"App state dropped: Unable to read API key or API URL from the application state" "App state dropped"
))) ))))
.boxed(); .boxed();
}; };
let future = self.request_limiter.stream(async move { async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?; let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
};
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?; request.await.map_err(Into::into)
Ok(response) }
}); .boxed()
async move { Ok(future.await?.boxed()) }.boxed()
} }
} }
@ -435,12 +448,12 @@ impl LanguageModel for OpenRouterLanguageModel {
>, >,
> { > {
let request = into_open_router(request, &self.model, self.max_output_tokens()); let request = into_open_router(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx); let request = self.stream_completion(request, cx);
async move { let future = self.request_limiter.stream(async move {
let mapper = OpenRouterEventMapper::new(); let response = request.await?;
Ok(mapper.map_stream(completions.await?).boxed()) Ok(OpenRouterEventMapper::new().map_stream(response))
} });
.boxed() async move { Ok(future.await?.boxed()) }.boxed()
} }
} }
@ -608,13 +621,17 @@ impl OpenRouterEventMapper {
pub fn map_stream( pub fn map_stream(
mut self, mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>, events: Pin<
Box<
dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
>,
>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{ {
events.flat_map(move |event| { events.flat_map(move |event| {
futures::stream::iter(match event { futures::stream::iter(match event {
Ok(event) => self.map_event(event), Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], Err(error) => vec![Err(error.into())],
}) })
}) })
} }

View file

@ -22,4 +22,6 @@ http_client.workspace = true
schemars = { workspace = true, optional = true } schemars = { workspace = true, optional = true }
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
thiserror.workspace = true
strum.workspace = true
workspace-hack.workspace = true workspace-hack.workspace = true

View file

@ -1,12 +1,31 @@
use anyhow::{Context, Result, anyhow}; use anyhow::{Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::convert::TryFrom; use std::{convert::TryFrom, io, time::Duration};
use strum::EnumString;
use thiserror::Error;
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1"; pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
if let Some(reset) = headers.get("X-RateLimit-Reset") {
if let Ok(s) = reset.to_str() {
if let Ok(epoch_ms) = s.parse::<u64>() {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
if epoch_ms > now {
return Some(std::time::Duration::from_millis(epoch_ms - now));
}
}
}
}
None
}
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool { fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
opt.as_ref().is_none_or(|v| v.as_ref().is_empty()) opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
} }
@ -413,76 +432,12 @@ pub struct ModelArchitecture {
pub input_modalities: Vec<String>, pub input_modalities: Vec<String>,
} }
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
) -> Result<Response> {
let uri = format!("{api_url}/chat/completions");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.header("HTTP-Referer", "https://zed.dev")
.header("X-Title", "Zed Editor");
let mut request_body = request;
request_body.stream = false;
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
let response: Response = serde_json::from_str(&body)?;
Ok(response)
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
#[derive(Deserialize)]
struct OpenRouterResponse {
error: OpenRouterError,
}
#[derive(Deserialize)]
struct OpenRouterError {
message: String,
#[serde(default)]
code: String,
}
match serde_json::from_str::<OpenRouterResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => {
let error_message = if !response.error.code.is_empty() {
format!("{}: {}", response.error.code, response.error.message)
} else {
response.error.message
};
Err(anyhow!(
"Failed to connect to OpenRouter API: {}",
error_message
))
}
_ => Err(anyhow!(
"Failed to connect to OpenRouter API: {} {}",
response.status(),
body,
)),
}
}
}
pub async fn stream_completion( pub async fn stream_completion(
client: &dyn HttpClient, client: &dyn HttpClient,
api_url: &str, api_url: &str,
api_key: &str, api_key: &str,
request: Request, request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> { ) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
let uri = format!("{api_url}/chat/completions"); let uri = format!("{api_url}/chat/completions");
let request_builder = HttpRequest::builder() let request_builder = HttpRequest::builder()
.method(Method::POST) .method(Method::POST)
@ -492,8 +447,15 @@ pub async fn stream_completion(
.header("HTTP-Referer", "https://zed.dev") .header("HTTP-Referer", "https://zed.dev")
.header("X-Title", "Zed Editor"); .header("X-Title", "Zed Editor");
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let request = request_builder
let mut response = client.send(request).await?; .body(AsyncBody::from(
serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
))
.map_err(OpenRouterError::BuildRequestBody)?;
let mut response = client
.send(request)
.await
.map_err(OpenRouterError::HttpSend)?;
if response.status().is_success() { if response.status().is_success() {
let reader = BufReader::new(response.into_body()); let reader = BufReader::new(response.into_body());
@ -513,86 +475,85 @@ pub async fn stream_completion(
match serde_json::from_str::<ResponseStreamEvent>(line) { match serde_json::from_str::<ResponseStreamEvent>(line) {
Ok(response) => Some(Ok(response)), Ok(response) => Some(Ok(response)),
Err(error) => { Err(error) => {
#[derive(Deserialize)] if line.trim().is_empty() {
struct ErrorResponse { None
error: String, } else {
} Some(Err(OpenRouterError::DeserializeResponse(error)))
match serde_json::from_str::<ErrorResponse>(line) {
Ok(err_response) => Some(Err(anyhow!(err_response.error))),
Err(_) => {
if line.trim().is_empty() {
None
} else {
Some(Err(anyhow!(
"Failed to parse response: {}. Original content: '{}'",
error, line
)))
}
}
} }
} }
} }
} }
} }
Err(error) => Some(Err(anyhow!(error))), Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
} }
}) })
.boxed()) .boxed())
} else { } else {
let code = ApiErrorCode::from_status(response.status().as_u16());
let mut body = String::new(); let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?; response
.body_mut()
.read_to_string(&mut body)
.await
.map_err(OpenRouterError::ReadResponse)?;
#[derive(Deserialize)] let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
struct OpenRouterResponse { Ok(OpenRouterErrorResponse { error }) => error,
error: OpenRouterError, Err(_) => OpenRouterErrorBody {
} code: response.status().as_u16(),
message: body,
metadata: None,
},
};
#[derive(Deserialize)] match code {
struct OpenRouterError { ApiErrorCode::RateLimitError => {
message: String, let retry_after = extract_retry_after(response.headers());
#[serde(default)] Err(OpenRouterError::RateLimit {
code: String, retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
} })
match serde_json::from_str::<OpenRouterResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => {
let error_message = if !response.error.code.is_empty() {
format!("{}: {}", response.error.code, response.error.message)
} else {
response.error.message
};
Err(anyhow!(
"Failed to connect to OpenRouter API: {}",
error_message
))
} }
_ => Err(anyhow!( ApiErrorCode::OverloadedError => {
"Failed to connect to OpenRouter API: {} {}", let retry_after = extract_retry_after(response.headers());
response.status(), Err(OpenRouterError::ServerOverloaded { retry_after })
body, }
)), _ => Err(OpenRouterError::ApiError(ApiError {
code: code,
message: error_response.message,
})),
} }
} }
} }
pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> { pub async fn list_models(
client: &dyn HttpClient,
api_url: &str,
) -> Result<Vec<Model>, OpenRouterError> {
let uri = format!("{api_url}/models"); let uri = format!("{api_url}/models");
let request_builder = HttpRequest::builder() let request_builder = HttpRequest::builder()
.method(Method::GET) .method(Method::GET)
.uri(uri) .uri(uri)
.header("Accept", "application/json"); .header("Accept", "application/json");
let request = request_builder.body(AsyncBody::default())?; let request = request_builder
let mut response = client.send(request).await?; .body(AsyncBody::default())
.map_err(OpenRouterError::BuildRequestBody)?;
let mut response = client
.send(request)
.await
.map_err(OpenRouterError::HttpSend)?;
let mut body = String::new(); let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?; response
.body_mut()
.read_to_string(&mut body)
.await
.map_err(OpenRouterError::ReadResponse)?;
if response.status().is_success() { if response.status().is_success() {
let response: ListModelsResponse = let response: ListModelsResponse =
serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?; serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
let models = response let models = response
.data .data
@ -637,10 +598,141 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<M
Ok(models) Ok(models)
} else { } else {
Err(anyhow!( let code = ApiErrorCode::from_status(response.status().as_u16());
"Failed to connect to OpenRouter API: {} {}",
response.status(), let mut body = String::new();
body, response
)) .body_mut()
.read_to_string(&mut body)
.await
.map_err(OpenRouterError::ReadResponse)?;
let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
Ok(OpenRouterErrorResponse { error }) => error,
Err(_) => OpenRouterErrorBody {
code: response.status().as_u16(),
message: body,
metadata: None,
},
};
match code {
ApiErrorCode::RateLimitError => {
let retry_after = extract_retry_after(response.headers());
Err(OpenRouterError::RateLimit {
retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
})
}
ApiErrorCode::OverloadedError => {
let retry_after = extract_retry_after(response.headers());
Err(OpenRouterError::ServerOverloaded { retry_after })
}
_ => Err(OpenRouterError::ApiError(ApiError {
code: code,
message: error_response.message,
})),
}
}
}
#[derive(Debug)]
pub enum OpenRouterError {
/// Failed to serialize the HTTP request body to JSON
SerializeRequest(serde_json::Error),
/// Failed to construct the HTTP request body
BuildRequestBody(http::Error),
/// Failed to send the HTTP request
HttpSend(anyhow::Error),
/// Failed to deserialize the response from JSON
DeserializeResponse(serde_json::Error),
/// Failed to read from response stream
ReadResponse(io::Error),
/// Rate limit exceeded
RateLimit { retry_after: Duration },
/// Server overloaded
ServerOverloaded { retry_after: Option<Duration> },
/// API returned an error response
ApiError(ApiError),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
pub code: u16,
pub message: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
pub error: OpenRouterErrorBody,
}
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
pub code: ApiErrorCode,
pub message: String,
}
/// An OpenROuter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
/// 400: Bad Request (invalid or missing params, CORS)
InvalidRequestError,
/// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
AuthenticationError,
/// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
PaymentRequiredError,
/// 403: Your chosen model requires moderation and your input was flagged
PermissionError,
/// 408: Your request timed out
RequestTimedOut,
/// 429: You are being rate limited
RateLimitError,
/// 502: Your chosen model is down or we received an invalid response from it
ApiError,
/// 503: There is no available model provider that meets your routing requirements
OverloadedError,
}
impl std::fmt::Display for ApiErrorCode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
ApiErrorCode::InvalidRequestError => "invalid_request_error",
ApiErrorCode::AuthenticationError => "authentication_error",
ApiErrorCode::PaymentRequiredError => "payment_required_error",
ApiErrorCode::PermissionError => "permission_error",
ApiErrorCode::RequestTimedOut => "request_timed_out",
ApiErrorCode::RateLimitError => "rate_limit_error",
ApiErrorCode::ApiError => "api_error",
ApiErrorCode::OverloadedError => "overloaded_error",
};
write!(f, "{s}")
}
}
impl ApiErrorCode {
pub fn from_status(status: u16) -> Self {
match status {
400 => ApiErrorCode::InvalidRequestError,
401 => ApiErrorCode::AuthenticationError,
402 => ApiErrorCode::PaymentRequiredError,
403 => ApiErrorCode::PermissionError,
408 => ApiErrorCode::RequestTimedOut,
429 => ApiErrorCode::RateLimitError,
502 => ApiErrorCode::ApiError,
503 => ApiErrorCode::OverloadedError,
_ => ApiErrorCode::ApiError,
}
} }
} }