From c9c603b1d1d1f078b89e4a59650f07cae6fbcacc Mon Sep 17 00:00:00 2001 From: Umesh Yadav <23421535+imumesh18@users.noreply.github.com> Date: Tue, 3 Jun 2025 21:29:46 +0530 Subject: [PATCH] Add support for OpenRouter as a language model provider (#29496) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This pull request adds full integration with OpenRouter, allowing users to access a wide variety of language models through a single API key. **Implementation Details:** * **Provider Registration:** Registers OpenRouter as a new language model provider within the application's model registry. This includes UI for API key authentication, token counting, streaming completions, and tool-call handling. * **Dedicated Crate:** Adds a new `open_router` crate to manage interactions with the OpenRouter HTTP API, including model discovery and streaming helpers. * **UI & Configuration:** Extends workspace manifests, the settings schema, icons, and default configurations to surface the OpenRouter provider and its settings within the UI. * **Readability:** Reformats JSON arrays within the settings files for improved readability. **Design Decisions & Discussion Points:** * **Code Reuse:** I leveraged much of the existing logic from the `openai` provider integration due to the significant similarities between the OpenAI and OpenRouter API specifications. * **Default Model:** I set the default model to `openrouter/auto`. This model automatically routes user prompts to the most suitable underlying model on OpenRouter, providing a convenient starting point. * **Model Population Strategy:** * I've implemented dynamic population of available models by querying the OpenRouter API upon initialization. * Currently, this involves three separate API calls: one for all models, one for tool-use models, and one for models good at programming. * The data from the tool-use API call sets a `tool_use` flag for relevant models. 
* The data from the programming models API call is used to sort the list, prioritizing coding-focused models in the dropdown. * **Feedback Welcome:** I acknowledge this multi-call approach is API-intensive. I am open to feedback and alternative implementation suggestions if the team believes this can be optimized. * **Update: Now this has been simplified to one API call.** * **UI/UX Considerations:** * Authentication Method: Currently, I've implemented the standard API key input in settings, similar to other providers like OpenAI/Anthropic. However, OpenRouter also supports OAuth 2.0 with PKCE. This could offer a potentially smoother, more integrated setup experience for users (e.g., clicking a button to authorize instead of copy-pasting a key). Should we prioritize implementing OAuth PKCE now, or perhaps add it as an alternative option later? (PKCE is not straightforward and is complicated, so I'm skipping it for now; we can add support for it later.) * To visually distinguish models better suited for programming, I've considered adding a marker (e.g., `🧠`) next to their names. Thoughts on this proposal? (This will require changes and discussion across model providers. This doesn't fall under the scope of the current PR.) * OpenRouter offers 300+ models. The current implementation loads all of them. **Feedback Needed:** Should we refine this list or implement more sophisticated filtering/categorization for better usability? **Motivation:** This integration directly addresses one of the most highly upvoted feature requests/discussions within the Zed community. Adding OpenRouter support significantly expands the range of AI models accessible to users. I welcome feedback from the Zed team on this implementation and the design choices made. I am eager to refine this feature and make it available to users. ISSUES: https://github.com/zed-industries/zed/discussions/16576 Release Notes: - Added support for OpenRouter as a language model provider. 
--------- Signed-off-by: Umesh Yadav Co-authored-by: Marshall Bowers --- Cargo.lock | 14 + Cargo.toml | 2 + assets/icons/ai_open_router.svg | 8 + assets/settings/default.json | 3 + crates/agent_settings/src/agent_settings.rs | 1 + crates/icons/src/icons.rs | 1 + crates/language_models/Cargo.toml | 1 + crates/language_models/src/language_models.rs | 5 + crates/language_models/src/provider.rs | 1 + .../src/provider/open_router.rs | 788 ++++++++++++++++++ crates/language_models/src/settings.rs | 22 + crates/open_router/Cargo.toml | 25 + crates/open_router/LICENSE-GPL | 1 + crates/open_router/src/open_router.rs | 484 +++++++++++ 14 files changed, 1356 insertions(+) create mode 100644 assets/icons/ai_open_router.svg create mode 100644 crates/language_models/src/provider/open_router.rs create mode 100644 crates/open_router/Cargo.toml create mode 120000 crates/open_router/LICENSE-GPL create mode 100644 crates/open_router/src/open_router.rs diff --git a/Cargo.lock b/Cargo.lock index 88283152ba..9cf69ed1c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8864,6 +8864,7 @@ dependencies = [ "mistral", "ollama", "open_ai", + "open_router", "partial-json-fixer", "project", "proto", @@ -10708,6 +10709,19 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "open_router" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.31", + "http_client", + "schemars", + "serde", + "serde_json", + "workspace-hack", +] + [[package]] name = "opener" version = "0.7.2" diff --git a/Cargo.toml b/Cargo.toml index 9152dfd23c..852e3ba413 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,7 @@ members = [ "crates/notifications", "crates/ollama", "crates/open_ai", + "crates/open_router", "crates/outline", "crates/outline_panel", "crates/panel", @@ -307,6 +308,7 @@ node_runtime = { path = "crates/node_runtime" } notifications = { path = "crates/notifications" } ollama = { path = "crates/ollama" } open_ai = { path = "crates/open_ai" } +open_router = { path = "crates/open_router", 
features = ["schemars"] } outline = { path = "crates/outline" } outline_panel = { path = "crates/outline_panel" } panel = { path = "crates/panel" } diff --git a/assets/icons/ai_open_router.svg b/assets/icons/ai_open_router.svg new file mode 100644 index 0000000000..cc8597729a --- /dev/null +++ b/assets/icons/ai_open_router.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/assets/settings/default.json b/assets/settings/default.json index 3ae4417505..7c0688831d 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1605,6 +1605,9 @@ "version": "1", "api_url": "https://api.openai.com/v1" }, + "open_router": { + "api_url": "https://openrouter.ai/api/v1" + }, "lmstudio": { "api_url": "http://localhost:1234/api/v0" }, diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index ce7bd56047..36480f30d5 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -730,6 +730,7 @@ impl JsonSchema for LanguageModelProviderSetting { "zed.dev".into(), "copilot_chat".into(), "deepseek".into(), + "openrouter".into(), "mistral".into(), ]), ..Default::default() diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 2896a19829..adfbe1e52d 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -18,6 +18,7 @@ pub enum IconName { AiMistral, AiOllama, AiOpenAi, + AiOpenRouter, AiZed, ArrowCircle, ArrowDown, diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 2c5048b910..ab5090e9ba 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -39,6 +39,7 @@ menu.workspace = true mistral = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } +open_router = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true project.workspace = true 
proto.workspace = true diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 61c5dcf642..0224da4e6b 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -19,6 +19,7 @@ use crate::provider::lmstudio::LmStudioLanguageModelProvider; use crate::provider::mistral::MistralLanguageModelProvider; use crate::provider::ollama::OllamaLanguageModelProvider; use crate::provider::open_ai::OpenAiLanguageModelProvider; +use crate::provider::open_router::OpenRouterLanguageModelProvider; pub use crate::settings::*; pub fn init(user_store: Entity, client: Arc, fs: Arc, cx: &mut App) { @@ -72,5 +73,9 @@ fn register_language_model_providers( BedrockLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider( + OpenRouterLanguageModelProvider::new(client.http_client(), cx), + cx, + ); registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx); } diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index 6b183292f3..4f2ea9cc09 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -8,3 +8,4 @@ pub mod lmstudio; pub mod mistral; pub mod ollama; pub mod open_ai; +pub mod open_router; diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs new file mode 100644 index 0000000000..7af265544a --- /dev/null +++ b/crates/language_models/src/provider/open_router.rs @@ -0,0 +1,788 @@ +use anyhow::{Context as _, Result, anyhow}; +use collections::HashMap; +use credentials_provider::CredentialsProvider; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; +use gpui::{ + AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, +}; +use http_client::HttpClient; +use language_model::{ + 
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + RateLimiter, Role, StopReason, +}; +use open_router::{Model, ResponseStreamEvent, list_models, stream_completion}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::pin::Pin; +use std::str::FromStr as _; +use std::sync::Arc; +use theme::ThemeSettings; +use ui::{Icon, IconName, List, Tooltip, prelude::*}; +use util::ResultExt; + +use crate::{AllLanguageModelSettings, ui::InstructionListItem}; + +const PROVIDER_ID: &str = "openrouter"; +const PROVIDER_NAME: &str = "OpenRouter"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct OpenRouterSettings { + pub api_url: String, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option, + pub max_tokens: usize, + pub max_output_tokens: Option, + pub max_completion_tokens: Option, +} + +pub struct OpenRouterLanguageModelProvider { + http_client: Arc, + state: gpui::Entity, +} + +pub struct State { + api_key: Option, + api_key_from_env: bool, + http_client: Arc, + available_models: Vec, + fetch_models_task: Option>>, + _subscription: Subscription, +} + +const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY"; + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .open_router + .api_url + .clone(); + cx.spawn(async move |this, cx| { + credentials_provider + .delete_credentials(&api_url, &cx) + 
.await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .open_router + .api_url + .clone(); + cx.spawn(async move |this, cx| { + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .open_router + .api_url + .clone(); + cx.spawn(async move |this, cx| { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) + .await? 
+ .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key) + .context(format!("invalid {} API key", PROVIDER_NAME))?, + false, + ) + }; + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) + } + + fn fetch_models(&mut self, cx: &mut Context) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).open_router; + let http_client = self.http_client.clone(); + let api_url = settings.api_url.clone(); + + cx.spawn(async move |this, cx| { + let models = list_models(http_client.as_ref(), &api_url).await?; + + this.update(cx, |this, cx| { + this.available_models = models; + cx.notify(); + }) + }) + } + + fn restart_fetch_models_task(&mut self, cx: &mut Context) { + let task = self.fetch_models(cx); + self.fetch_models_task.replace(task); + } +} + +impl OpenRouterLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + http_client: http_client.clone(), + available_models: Vec::new(), + fetch_models_task: None, + _subscription: cx.observe_global::(|this: &mut State, cx| { + this.restart_fetch_models_task(cx); + cx.notify(); + }), + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: open_router::Model) -> Arc { + Arc::new(OpenRouterLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for OpenRouterLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for OpenRouterLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + 
LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn icon(&self) -> IconName { + IconName::AiOpenRouter + } + + fn default_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(open_router::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(open_router::Model::default_fast())) + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models_from_api = self.state.read(cx).available_models.clone(); + let mut settings_models = Vec::new(); + + for model in &AllLanguageModelSettings::get_global(cx) + .open_router + .available_models + { + settings_models.push(open_router::Model { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + supports_tools: Some(false), + }); + } + + for settings_model in &settings_models { + if let Some(pos) = models_from_api + .iter() + .position(|m| m.name == settings_model.name) + { + models_from_api[pos] = settings_model.clone(); + } else { + models_from_api.push(settings_model.clone()); + } + } + + models_from_api + .into_iter() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct OpenRouterLanguageModel { + id: LanguageModelId, + model: open_router::Model, + state: gpui::Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +impl OpenRouterLanguageModel { + fn stream_completion( + &self, + request: open_router::Request, + cx: &AsyncApp, + ) -> BoxFuture<'static, 
Result>>> + { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).open_router; + (state.api_key.clone(), settings.api_url.clone()) + }) else { + return futures::future::ready(Err(anyhow!( + "App state dropped: Unable to read API key or API URL from the application state" + ))) + .boxed(); + }; + + let future = self.request_limiter.stream(async move { + let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?; + let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for OpenRouterLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn supports_tools(&self) -> bool { + self.model.supports_tool_calls() + } + + fn telemetry_id(&self) -> String { + format!("openrouter/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + self.model.max_output_tokens() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto => true, + LanguageModelToolChoice::Any => true, + LanguageModelToolChoice::None => true, + } + } + + fn supports_images(&self) -> bool { + false + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + count_open_router_tokens(request, self.model.clone(), cx) + } + + fn stream_completion( + 
&self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + >, + > { + let request = into_open_router(request, &self.model, self.max_output_tokens()); + let completions = self.stream_completion(request, cx); + async move { + let mapper = OpenRouterEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) + } + .boxed() + } +} + +pub fn into_open_router( + request: LanguageModelRequest, + model: &Model, + max_output_tokens: Option, +) -> open_router::Request { + let mut messages = Vec::new(); + for req_message in request.messages { + for content in req_message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages + .push(match req_message.role { + Role::User => open_router::RequestMessage::User { content: text }, + Role::Assistant => open_router::RequestMessage::Assistant { + content: Some(text), + tool_calls: Vec::new(), + }, + Role::System => open_router::RequestMessage::System { content: text }, + }), + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(_) => {} + MessageContent::ToolUse(tool_use) => { + let tool_call = open_router::ToolCall { + id: tool_use.id.to_string(), + content: open_router::ToolCallContent::Function { + function: open_router::FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + }, + }, + }; + + if let Some(open_router::RequestMessage::Assistant { tool_calls, .. 
}) = + messages.last_mut() + { + tool_calls.push(tool_call); + } else { + messages.push(open_router::RequestMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + }); + } + } + MessageContent::ToolResult(tool_result) => { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + text.to_string() + } + LanguageModelToolResultContent::Image(_) => { + "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string() + } + }; + + messages.push(open_router::RequestMessage::Tool { + content: content, + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } + } + } + } + + open_router::Request { + model: model.id().into(), + messages, + stream: true, + stop: request.stop, + temperature: request.temperature.unwrap_or(0.4), + max_tokens: max_output_tokens, + parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() { + Some(false) + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| open_router::ToolDefinition::Function { + function: open_router::FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto, + LanguageModelToolChoice::Any => open_router::ToolChoice::Required, + LanguageModelToolChoice::None => open_router::ToolChoice::None, + }), + } +} + +pub struct OpenRouterEventMapper { + tool_calls_by_index: HashMap, +} + +impl OpenRouterEventMapper { + pub fn new() -> Self { + Self { + tool_calls_by_index: HashMap::default(), + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], + }) + }) + } + + pub 
fn map_event( + &mut self, + event: ResponseStreamEvent, + ) -> Vec> { + let Some(choice) = event.choices.first() else { + return vec![Err(LanguageModelCompletionError::Other(anyhow!( + "Response contained no choices" + )))]; + }; + + let mut events = Vec::new(); + if let Some(content) = choice.delta.content.clone() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + + if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + } + } + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match serde_json::Value::from_str(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + }, + )), + Err(error) => Err(LanguageModelCompletionError::BadInputJson { + id: tool_call.id.into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }), + } + })); + + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} + } + + events + } +} + +#[derive(Default)] +struct RawToolCall { + id: 
String, + name: String, + arguments: String, +} + +pub fn count_open_router_tokens( + request: LanguageModelRequest, + _model: open_router::Model, + cx: &App, +) -> BoxFuture<'static, Result> { + cx.background_spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages) + }) + .boxed() +} + +struct ConfigurationView { + api_key_editor: Entity, + state: gpui::Entity, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + let api_key_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor + .set_placeholder_text("sk_or_000000000000000000000000000000000000000000000000", cx); + editor + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + let _ = task.await; + } + + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(api_key, cx))? 
+ .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state.update(cx, |state, cx| state.reset_api_key(cx))?.await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + white_space: WhiteSpace::Normal, + ..Default::default() + }; + EditorElement::new( + &self.api_key_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_from_env; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials...")).into_any() + } else if self.should_render_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's assistant with OpenRouter, you need to add an API key. 
Follow these steps:")) + .child( + List::new() + .child(InstructionListItem::new( + "Create an API key by visiting", + Some("OpenRouter's console"), + Some("https://openrouter.ai/keys"), + )) + .child(InstructionListItem::text_only( + "Ensure your OpenRouter account has credits", + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the assistant", + )), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .border_1() + .border_color(cx.theme().colors().border) + .rounded_sm() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."), + ) + .size(LabelSize::Small).color(Color::Muted), + ) + .into_any() + } else { + h_flex() + .mt_1() + .p_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background) + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.") + } else { + "API key configured.".to_string() + })), + ) + .child( + Button::new("reset-key", "Reset Key") + .label_size(LabelSize::Small) + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .disabled(env_var_set) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + .into_any() + } + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index abbb237b4f..2cf549c8f6 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -20,6 +20,7 @@ use crate::provider::{ 
mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings, + open_router::OpenRouterSettings, }; /// Initializes the language model settings. @@ -61,6 +62,7 @@ pub struct AllLanguageModelSettings { pub bedrock: AmazonBedrockSettings, pub ollama: OllamaSettings, pub openai: OpenAiSettings, + pub open_router: OpenRouterSettings, pub zed_dot_dev: ZedDotDevSettings, pub google: GoogleSettings, pub copilot_chat: CopilotChatSettings, @@ -76,6 +78,7 @@ pub struct AllLanguageModelSettingsContent { pub ollama: Option, pub lmstudio: Option, pub openai: Option, + pub open_router: Option, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, pub google: Option, @@ -271,6 +274,12 @@ pub struct ZedDotDevSettingsContent { #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct CopilotChatSettingsContent {} +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct OpenRouterSettingsContent { + pub api_url: Option, + pub available_models: Option>, +} + impl settings::Settings for AllLanguageModelSettings { const KEY: Option<&'static str> = Some("language_models"); @@ -409,6 +418,19 @@ impl settings::Settings for AllLanguageModelSettings { &mut settings.mistral.available_models, mistral.as_ref().and_then(|s| s.available_models.clone()), ); + + // OpenRouter + let open_router = value.open_router.clone(); + merge( + &mut settings.open_router.api_url, + open_router.as_ref().and_then(|s| s.api_url.clone()), + ); + merge( + &mut settings.open_router.available_models, + open_router + .as_ref() + .and_then(|s| s.available_models.clone()), + ); } Ok(settings) diff --git a/crates/open_router/Cargo.toml b/crates/open_router/Cargo.toml new file mode 100644 index 0000000000..bbc4fe190f --- /dev/null +++ b/crates/open_router/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "open_router" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = 
true + +[lib] +path = "src/open_router.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +futures.workspace = true +http_client.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true +workspace-hack.workspace = true diff --git a/crates/open_router/LICENSE-GPL b/crates/open_router/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/open_router/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs new file mode 100644 index 0000000000..f0fe071503 --- /dev/null +++ b/crates/open_router/src/open_router.rs @@ -0,0 +1,484 @@ +use anyhow::{Context, Result, anyhow}; +use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::convert::TryFrom; + +pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1"; + +fn is_none_or_empty, U>(opt: &Option) -> bool { + opt.as_ref().map_or(true, |v| v.as_ref().is_empty()) +} + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, + Tool, +} + +impl TryFrom for Role { + type Error = anyhow::Error; + + fn try_from(value: String) -> Result { + match value.as_str() { + "user" => Ok(Self::User), + "assistant" => Ok(Self::Assistant), + "system" => Ok(Self::System), + "tool" => Ok(Self::Tool), + _ => Err(anyhow!("invalid role '{value}'")), + } + } +} + +impl From for String { + fn from(val: Role) -> Self { + match val { + Role::User => "user".to_owned(), + Role::Assistant => "assistant".to_owned(), + Role::System => "system".to_owned(), + Role::Tool => "tool".to_owned(), + } + } +} + 
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct Model { + pub name: String, + pub display_name: Option, + pub max_tokens: usize, + pub supports_tools: Option, +} + +impl Model { + pub fn default_fast() -> Self { + Self::new( + "openrouter/auto", + Some("Auto Router"), + Some(2000000), + Some(true), + ) + } + + pub fn default() -> Self { + Self::default_fast() + } + + pub fn new( + name: &str, + display_name: Option<&str>, + max_tokens: Option, + supports_tools: Option, + ) -> Self { + Self { + name: name.to_owned(), + display_name: display_name.map(|s| s.to_owned()), + max_tokens: max_tokens.unwrap_or(2000000), + supports_tools, + } + } + + pub fn id(&self) -> &str { + &self.name + } + + pub fn display_name(&self) -> &str { + self.display_name.as_ref().unwrap_or(&self.name) + } + + pub fn max_token_count(&self) -> usize { + self.max_tokens + } + + pub fn max_output_tokens(&self) -> Option { + None + } + + pub fn supports_tool_calls(&self) -> bool { + self.supports_tools.unwrap_or(false) + } + + pub fn supports_parallel_tool_calls(&self) -> bool { + false + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + pub model: String, + pub messages: Vec, + pub stream: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub stop: Vec, + pub temperature: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Auto, + Required, + None, + Other(ToolDefinition), +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Deserialize, 
Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolDefinition { + #[allow(dead_code)] + Function { function: FunctionDefinition }, +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FunctionDefinition { + pub name: String, + pub description: Option, + pub parameters: Option, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "role", rename_all = "lowercase")] +pub enum RequestMessage { + Assistant { + content: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + tool_calls: Vec, + }, + User { + content: String, + }, + System { + content: String, + }, + Tool { + content: String, + tool_call_id: String, + }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ToolCall { + pub id: String, + #[serde(flatten)] + pub content: ToolCallContent, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum ToolCallContent { + Function { function: FunctionContent }, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionContent { + pub name: String, + pub arguments: String, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessageDelta { + pub role: Option, + pub content: Option, + #[serde(default, skip_serializing_if = "is_none_or_empty")] + pub tool_calls: Option>, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ToolCallChunk { + pub index: usize, + pub id: Option, + pub function: Option, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct FunctionChunk { + pub name: Option, + pub arguments: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ChoiceDelta { + pub index: u32, + 
pub delta: ResponseMessageDelta, + pub finish_reason: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ResponseStreamEvent { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id: Option, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Response { + pub id: String, + pub object: String, + pub created: u64, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Choice { + pub index: u32, + pub message: RequestMessage, + pub finish_reason: Option, +} + +#[derive(Default, Debug, Clone, PartialEq, Deserialize)] +pub struct ListModelsResponse { + pub data: Vec, +} + +#[derive(Default, Debug, Clone, PartialEq, Deserialize)] +pub struct ModelEntry { + pub id: String, + pub name: String, + pub created: usize, + pub description: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub context_length: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub supported_parameters: Vec, +} + +pub async fn complete( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result { + let uri = format!("{api_url}/chat/completions"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .header("HTTP-Referer", "https://zed.dev") + .header("X-Title", "Zed Editor"); + + let mut request_body = request; + request_body.stream = false; + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?; + let mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: Response = serde_json::from_str(&body)?; + Ok(response) + } else { + let mut body = 
String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenRouterResponse { + error: OpenRouterError, + } + + #[derive(Deserialize)] + struct OpenRouterError { + message: String, + #[serde(default)] + code: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => { + let error_message = if !response.error.code.is_empty() { + format!("{}: {}", response.error.code, response.error.message) + } else { + response.error.message + }; + + Err(anyhow!( + "Failed to connect to OpenRouter API: {}", + error_message + )) + } + _ => Err(anyhow!( + "Failed to connect to OpenRouter API: {} {}", + response.status(), + body, + )), + } + } +} + +pub async fn stream_completion( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result>> { + let uri = format!("{api_url}/chat/completions"); + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .header("HTTP-Referer", "https://zed.dev") + .header("X-Title", "Zed Editor"); + + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + if line.starts_with(':') { + return None; + } + + let line = line.strip_prefix("data: ")?; + if line == "[DONE]" { + None + } else { + match serde_json::from_str::(line) { + Ok(response) => Some(Ok(response)), + Err(error) => { + #[derive(Deserialize)] + struct ErrorResponse { + error: String, + } + + match serde_json::from_str::(line) { + Ok(err_response) => Some(Err(anyhow!(err_response.error))), + Err(_) => { + if line.trim().is_empty() { + None + } else { + Some(Err(anyhow!( + 
"Failed to parse response: {}. Original content: '{}'", + error, line + ))) + } + } + } + } + } + } + } + Err(error) => Some(Err(anyhow!(error))), + } + }) + .boxed()) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenRouterResponse { + error: OpenRouterError, + } + + #[derive(Deserialize)] + struct OpenRouterError { + message: String, + #[serde(default)] + code: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => { + let error_message = if !response.error.code.is_empty() { + format!("{}: {}", response.error.code, response.error.message) + } else { + response.error.message + }; + + Err(anyhow!( + "Failed to connect to OpenRouter API: {}", + error_message + )) + } + _ => Err(anyhow!( + "Failed to connect to OpenRouter API: {} {}", + response.status(), + body, + )), + } + } +} + +pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result> { + let uri = format!("{api_url}/models"); + let request_builder = HttpRequest::builder() + .method(Method::GET) + .uri(uri) + .header("Accept", "application/json"); + + let request = request_builder.body(AsyncBody::default())?; + let mut response = client.send(request).await?; + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + if response.status().is_success() { + let response: ListModelsResponse = + serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?; + + let models = response + .data + .into_iter() + .map(|entry| Model { + name: entry.id, + // OpenRouter returns display names in the format "provider_name: model_name". + // When displayed in the UI, these names can get truncated from the right. 
+ // Since users typically already know the provider, we extract just the model name + // portion (after the colon) to create a more concise and user-friendly label + // for the model dropdown in the agent panel. + display_name: Some( + entry + .name + .split(':') + .next_back() + .unwrap_or(&entry.name) + .trim() + .to_string(), + ), + max_tokens: entry.context_length.unwrap_or(2000000), + supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())), + }) + .collect(); + + Ok(models) + } else { + Err(anyhow!( + "Failed to connect to OpenRouter API: {} {}", + response.status(), + body, + )) + } +}