diff --git a/Cargo.lock b/Cargo.lock index c835b503ad..c223dd6a0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11192,6 +11192,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "util", "workspace-hack", ] diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 8f2abfce35..09dda51126 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -4,7 +4,8 @@ use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AppContext, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, + WhiteSpace, }; use http_client::HttpClient; use language_model::{ @@ -15,7 +16,8 @@ use language_model::{ LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; use open_router::{ - Model, ModelMode as OpenRouterModelMode, ResponseStreamEvent, list_models, stream_completion, + Model, ModelMode as OpenRouterModelMode, Provider, ResponseStreamEvent, list_models, + stream_completion, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -48,6 +50,7 @@ pub struct AvailableModel { pub supports_tools: Option, pub supports_images: Option, pub mode: Option, + pub provider: Option, } #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)] @@ -278,6 +281,7 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { supports_tools: model.supports_tools, supports_images: model.supports_images, mode: model.mode.clone().unwrap_or_default().into(), + provider: model.provider.clone(), }); } @@ -556,6 +560,7 @@ pub fn into_open_router( LanguageModelToolChoice::Any => open_router::ToolChoice::Required, LanguageModelToolChoice::None => open_router::ToolChoice::None, }), + provider: model.provider.clone(), } } diff --git a/crates/open_router/Cargo.toml b/crates/open_router/Cargo.toml index bbc4fe190f..03f2ce29ce 100644 --- a/crates/open_router/Cargo.toml +++ b/crates/open_router/Cargo.toml @@ -23,3 +23,4 @@ schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true workspace-hack.workspace = true +util.workspace = true diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs index b7e6d69d8f..6ca3a02148 100644 --- a/crates/open_router/src/open_router.rs +++ b/crates/open_router/src/open_router.rs @@ -4,6 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::convert::TryFrom; +use util::serde::default_true; pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1"; @@ -11,6 +12,41 @@ fn is_none_or_empty, U>(opt: &Option) -> bool { opt.as_ref().is_none_or(|v| v.as_ref().is_empty()) } +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum DataCollection { + Allow, + Disallow, +} + +impl Default for DataCollection { + fn default() -> Self { + Self::Allow + } +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Provider { + #[serde(skip_serializing_if = "Option::is_none")] + order: Option>, + #[serde(default = "default_true")] + allow_fallbacks: bool, + #[serde(default)] + require_parameters: bool, + #[serde(default)] + data_collection: DataCollection, + #[serde(skip_serializing_if = "Option::is_none")] + only: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + ignore: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + quantizations: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + sort: Option, +} + #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Role { @@ -55,6 +91,7 @@ pub struct Model { pub supports_images: Option, #[serde(default)] pub mode: ModelMode, + pub provider: Option, } #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] @@ -76,6 +113,7 @@ impl Model { Some(true), Some(false), Some(ModelMode::Default), + None, ) } @@ -90,6 +128,7 @@ impl Model { supports_tools: Option, supports_images: Option, mode: Option, + provider: Option, ) -> Self { Self { name: name.to_owned(), @@ -98,6 +137,7 @@ impl Model { supports_tools, supports_images, mode: mode.unwrap_or(ModelMode::Default), + provider, } } @@ -145,6 +185,7 @@ pub struct Request { #[serde(default, skip_serializing_if = "Option::is_none")] pub reasoning: Option, pub usage: RequestUsage, + pub provider: Option, } #[derive(Debug, Default, Serialize, Deserialize)] @@ -632,6 +673,7 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result