Add support for OpenRouter as a language model provider (#29496)
This pull request adds full integration with OpenRouter, allowing users to access a wide variety of language models through a single API key. **Implementation Details:** * **Provider Registration:** Registers OpenRouter as a new language model provider within the application's model registry. This includes UI for API key authentication, token counting, streaming completions, and tool-call handling. * **Dedicated Crate:** Adds a new `open_router` crate to manage interactions with the OpenRouter HTTP API, including model discovery and streaming helpers. * **UI & Configuration:** Extends workspace manifests, the settings schema, icons, and default configurations to surface the OpenRouter provider and its settings within the UI. * **Readability:** Reformats JSON arrays within the settings files for improved readability. **Design Decisions & Discussion Points:** * **Code Reuse:** I leveraged much of the existing logic from the `openai` provider integration due to the significant similarities between the OpenAI and OpenRouter API specifications. * **Default Model:** I set the default model to `openrouter/auto`. This model automatically routes user prompts to the most suitable underlying model on OpenRouter, providing a convenient starting point. * **Model Population Strategy:** * <strike>I've implemented dynamic population of available models by querying the OpenRouter API upon initialization. * Currently, this involves three separate API calls: one for all models, one for tool-use models, and one for models good at programming. * The data from the tool-use API call sets a `tool_use` flag for relevant models. * The data from the programming models API call is used to sort the list, prioritizing coding-focused models in the dropdown.</strike> * <strike>**Feedback Welcome:** I acknowledge this multi-call approach is API-intensive. I am open to feedback and alternative implementation suggestions if the team believes this can be optimized.</strike> * **Update: Now this has been simplified to one api call.** * **UI/UX Considerations:** * <strike>Authentication Method: Currently, I've implemented the standard API key input in settings, similar to other providers like OpenAI/Anthropic. However, OpenRouter also supports OAuth 2.0 with PKCE. This could offer a potentially smoother, more integrated setup experience for users (e.g., clicking a button to authorize instead of copy-pasting a key). Should we prioritize implementing OAuth PKCE now, or perhaps add it as an alternative option later?</strike>(PKCE is not straight forward and complicated so skipping this for now. So that we can add the support and work on this later.) * <strike>To visually distinguish models better suited for programming, I've considered adding a marker (e.g., `</>` or `🧠`) next to their names. Thoughts on this proposal?</strike>. (This will require a changes and discussion across model provider. This doesn't fall under the scope of current PR). * OpenRouter offers 300+ models. The current implementation loads all of them. **Feedback Needed:** Should we refine this list or implement more sophisticated filtering/categorization for better usability? **Motivation:** This integration directly addresses one of the most highly upvoted feature requests/discussions within the Zed community. Adding OpenRouter support significantly expands the range of AI models accessible to users. I welcome feedback from the Zed team on this implementation and the design choices made. I am eager to refine this feature and make it available to users. ISSUES: https://github.com/zed-industries/zed/discussions/16576 Release Notes: - Added support for OpenRouter as a language model provider. --------- Signed-off-by: Umesh Yadav <umesh4257@gmail.com> Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent
e13b494c9e
commit
c9c603b1d1
14 changed files with 1356 additions and 0 deletions
25
crates/open_router/Cargo.toml
Normal file
25
crates/open_router/Cargo.toml
Normal file
|
@ -0,0 +1,25 @@
|
|||
[package]
|
||||
name = "open_router"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/open_router.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
workspace-hack.workspace = true
|
1
crates/open_router/LICENSE-GPL
Symbolic link
1
crates/open_router/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
484
crates/open_router/src/open_router.rs
Normal file
484
crates/open_router/src/open_router.rs
Normal file
|
@ -0,0 +1,484 @@
|
|||
use anyhow::{Context, Result, anyhow};
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
|
||||
opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
Tool,
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Role {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
match value.as_str() {
|
||||
"user" => Ok(Self::User),
|
||||
"assistant" => Ok(Self::Assistant),
|
||||
"system" => Ok(Self::System),
|
||||
"tool" => Ok(Self::Tool),
|
||||
_ => Err(anyhow!("invalid role '{value}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for String {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => "user".to_owned(),
|
||||
Role::Assistant => "assistant".to_owned(),
|
||||
Role::System => "system".to_owned(),
|
||||
Role::Tool => "tool".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Model {
|
||||
pub name: String,
|
||||
pub display_name: Option<String>,
|
||||
pub max_tokens: usize,
|
||||
pub supports_tools: Option<bool>,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn default_fast() -> Self {
|
||||
Self::new(
|
||||
"openrouter/auto",
|
||||
Some("Auto Router"),
|
||||
Some(2000000),
|
||||
Some(true),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn default() -> Self {
|
||||
Self::default_fast()
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
name: &str,
|
||||
display_name: Option<&str>,
|
||||
max_tokens: Option<usize>,
|
||||
supports_tools: Option<bool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_owned(),
|
||||
display_name: display_name.map(|s| s.to_owned()),
|
||||
max_tokens: max_tokens.unwrap_or(2000000),
|
||||
supports_tools,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
self.display_name.as_ref().unwrap_or(&self.name)
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
self.max_tokens
|
||||
}
|
||||
|
||||
pub fn max_output_tokens(&self) -> Option<u32> {
|
||||
None
|
||||
}
|
||||
|
||||
pub fn supports_tool_calls(&self) -> bool {
|
||||
self.supports_tools.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn supports_parallel_tool_calls(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Request {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<ToolDefinition>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Required,
|
||||
None,
|
||||
Other(ToolDefinition),
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ToolDefinition {
|
||||
#[allow(dead_code)]
|
||||
Function { function: FunctionDefinition },
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct FunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub parameters: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum RequestMessage {
|
||||
Assistant {
|
||||
content: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
User {
|
||||
content: String,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
},
|
||||
Tool {
|
||||
content: String,
|
||||
tool_call_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(flatten)]
|
||||
pub content: ToolCallContent,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ToolCallContent {
|
||||
Function { function: FunctionContent },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct FunctionContent {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessageDelta {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "is_none_or_empty")]
|
||||
pub tool_calls: Option<Vec<ToolCallChunk>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCallChunk {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub function: Option<FunctionChunk>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct FunctionChunk {
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessageDelta,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ResponseStreamEvent {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
pub created: u32,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChoiceDelta>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Response {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: RequestMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct ListModelsResponse {
|
||||
pub data: Vec<ModelEntry>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct ModelEntry {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub created: usize,
|
||||
pub description: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub context_length: Option<usize>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub supported_parameters: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn complete(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<Response> {
|
||||
let uri = format!("{api_url}/chat/completions");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("HTTP-Referer", "https://zed.dev")
|
||||
.header("X-Title", "Zed Editor");
|
||||
|
||||
let mut request_body = request;
|
||||
request_body.stream = false;
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
let response: Response = serde_json::from_str(&body)?;
|
||||
Ok(response)
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenRouterResponse {
|
||||
error: OpenRouterError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenRouterError {
|
||||
message: String,
|
||||
#[serde(default)]
|
||||
code: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenRouterResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => {
|
||||
let error_message = if !response.error.code.is_empty() {
|
||||
format!("{}: {}", response.error.code, response.error.message)
|
||||
} else {
|
||||
response.error.message
|
||||
};
|
||||
|
||||
Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {}",
|
||||
error_message
|
||||
))
|
||||
}
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
|
||||
let uri = format!("{api_url}/chat/completions");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("HTTP-Referer", "https://zed.dev")
|
||||
.header("X-Title", "Zed Editor");
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
Ok(reader
|
||||
.lines()
|
||||
.filter_map(|line| async move {
|
||||
match line {
|
||||
Ok(line) => {
|
||||
if line.starts_with(':') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let line = line.strip_prefix("data: ")?;
|
||||
if line == "[DONE]" {
|
||||
None
|
||||
} else {
|
||||
match serde_json::from_str::<ResponseStreamEvent>(line) {
|
||||
Ok(response) => Some(Ok(response)),
|
||||
Err(error) => {
|
||||
#[derive(Deserialize)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<ErrorResponse>(line) {
|
||||
Ok(err_response) => Some(Err(anyhow!(err_response.error))),
|
||||
Err(_) => {
|
||||
if line.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Err(anyhow!(
|
||||
"Failed to parse response: {}. Original content: '{}'",
|
||||
error, line
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenRouterResponse {
|
||||
error: OpenRouterError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OpenRouterError {
|
||||
message: String,
|
||||
#[serde(default)]
|
||||
code: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<OpenRouterResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => {
|
||||
let error_message = if !response.error.code.is_empty() {
|
||||
format!("{}: {}", response.error.code, response.error.message)
|
||||
} else {
|
||||
response.error.message
|
||||
};
|
||||
|
||||
Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {}",
|
||||
error_message
|
||||
))
|
||||
}
|
||||
_ => Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
|
||||
let uri = format!("{api_url}/models");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
.uri(uri)
|
||||
.header("Accept", "application/json");
|
||||
|
||||
let request = request_builder.body(AsyncBody::default())?;
|
||||
let mut response = client.send(request).await?;
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let response: ListModelsResponse =
|
||||
serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
|
||||
|
||||
let models = response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|entry| Model {
|
||||
name: entry.id,
|
||||
// OpenRouter returns display names in the format "provider_name: model_name".
|
||||
// When displayed in the UI, these names can get truncated from the right.
|
||||
// Since users typically already know the provider, we extract just the model name
|
||||
// portion (after the colon) to create a more concise and user-friendly label
|
||||
// for the model dropdown in the agent panel.
|
||||
display_name: Some(
|
||||
entry
|
||||
.name
|
||||
.split(':')
|
||||
.next_back()
|
||||
.unwrap_or(&entry.name)
|
||||
.trim()
|
||||
.to_string(),
|
||||
),
|
||||
max_tokens: entry.context_length.unwrap_or(2000000),
|
||||
supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(models)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Failed to connect to OpenRouter API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
))
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue