diff --git a/Cargo.lock b/Cargo.lock
index 88283152ba..9cf69ed1c4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8864,6 +8864,7 @@ dependencies = [
"mistral",
"ollama",
"open_ai",
+ "open_router",
"partial-json-fixer",
"project",
"proto",
@@ -10708,6 +10709,19 @@ dependencies = [
"workspace-hack",
]
+[[package]]
+name = "open_router"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "http_client",
+ "schemars",
+ "serde",
+ "serde_json",
+ "workspace-hack",
+]
+
[[package]]
name = "opener"
version = "0.7.2"
diff --git a/Cargo.toml b/Cargo.toml
index 9152dfd23c..852e3ba413 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -100,6 +100,7 @@ members = [
"crates/notifications",
"crates/ollama",
"crates/open_ai",
+ "crates/open_router",
"crates/outline",
"crates/outline_panel",
"crates/panel",
@@ -307,6 +308,7 @@ node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
ollama = { path = "crates/ollama" }
open_ai = { path = "crates/open_ai" }
+open_router = { path = "crates/open_router", features = ["schemars"] }
outline = { path = "crates/outline" }
outline_panel = { path = "crates/outline_panel" }
panel = { path = "crates/panel" }
diff --git a/assets/icons/ai_open_router.svg b/assets/icons/ai_open_router.svg
new file mode 100644
index 0000000000..cc8597729a
--- /dev/null
+++ b/assets/icons/ai_open_router.svg
@@ -0,0 +1,8 @@
+
diff --git a/assets/settings/default.json b/assets/settings/default.json
index 3ae4417505..7c0688831d 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -1605,6 +1605,9 @@
"version": "1",
"api_url": "https://api.openai.com/v1"
},
+ "open_router": {
+ "api_url": "https://openrouter.ai/api/v1"
+ },
"lmstudio": {
"api_url": "http://localhost:1234/api/v0"
},
diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs
index ce7bd56047..36480f30d5 100644
--- a/crates/agent_settings/src/agent_settings.rs
+++ b/crates/agent_settings/src/agent_settings.rs
@@ -730,6 +730,7 @@ impl JsonSchema for LanguageModelProviderSetting {
"zed.dev".into(),
"copilot_chat".into(),
"deepseek".into(),
+ "openrouter".into(),
"mistral".into(),
]),
..Default::default()
diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs
index 2896a19829..adfbe1e52d 100644
--- a/crates/icons/src/icons.rs
+++ b/crates/icons/src/icons.rs
@@ -18,6 +18,7 @@ pub enum IconName {
AiMistral,
AiOllama,
AiOpenAi,
+ AiOpenRouter,
AiZed,
ArrowCircle,
ArrowDown,
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index 2c5048b910..ab5090e9ba 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -39,6 +39,7 @@ menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
+open_router = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
project.workspace = true
proto.workspace = true
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs
index 61c5dcf642..0224da4e6b 100644
--- a/crates/language_models/src/language_models.rs
+++ b/crates/language_models/src/language_models.rs
@@ -19,6 +19,7 @@ use crate::provider::lmstudio::LmStudioLanguageModelProvider;
use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
+use crate::provider::open_router::OpenRouterLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity, client: Arc, fs: Arc, cx: &mut App) {
@@ -72,5 +73,9 @@ fn register_language_model_providers(
BedrockLanguageModelProvider::new(client.http_client(), cx),
cx,
);
+ registry.register_provider(
+ OpenRouterLanguageModelProvider::new(client.http_client(), cx),
+ cx,
+ );
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}
diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs
index 6b183292f3..4f2ea9cc09 100644
--- a/crates/language_models/src/provider.rs
+++ b/crates/language_models/src/provider.rs
@@ -8,3 +8,4 @@ pub mod lmstudio;
pub mod mistral;
pub mod ollama;
pub mod open_ai;
+pub mod open_router;
diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs
new file mode 100644
index 0000000000..7af265544a
--- /dev/null
+++ b/crates/language_models/src/provider/open_router.rs
@@ -0,0 +1,788 @@
+use anyhow::{Context as _, Result, anyhow};
+use collections::HashMap;
+use credentials_provider::CredentialsProvider;
+use editor::{Editor, EditorElement, EditorStyle};
+use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
+use gpui::{
+ AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
+};
+use http_client::HttpClient;
+use language_model::{
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+ RateLimiter, Role, StopReason,
+};
+use open_router::{Model, ResponseStreamEvent, list_models, stream_completion};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr as _;
+use std::sync::Arc;
+use theme::ThemeSettings;
+use ui::{Icon, IconName, List, Tooltip, prelude::*};
+use util::ResultExt;
+
+use crate::{AllLanguageModelSettings, ui::InstructionListItem};
+
+const PROVIDER_ID: &str = "openrouter";
+const PROVIDER_NAME: &str = "OpenRouter";
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct OpenRouterSettings {
+ pub api_url: String,
+ pub available_models: Vec,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+ pub name: String,
+ pub display_name: Option,
+ pub max_tokens: usize,
+ pub max_output_tokens: Option,
+ pub max_completion_tokens: Option,
+}
+
+pub struct OpenRouterLanguageModelProvider {
+ http_client: Arc,
+ state: gpui::Entity,
+}
+
+pub struct State {
+ api_key: Option,
+ api_key_from_env: bool,
+ http_client: Arc,
+ available_models: Vec,
+ fetch_models_task: Option>>,
+ _subscription: Subscription,
+}
+
+const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY";
+
+impl State {
+ fn is_authenticated(&self) -> bool {
+ self.api_key.is_some()
+ }
+
+ fn reset_api_key(&self, cx: &mut Context) -> Task> {
+ let credentials_provider = ::global(cx);
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .open_router
+ .api_url
+ .clone();
+ cx.spawn(async move |this, cx| {
+ credentials_provider
+ .delete_credentials(&api_url, &cx)
+ .await
+ .log_err();
+ this.update(cx, |this, cx| {
+ this.api_key = None;
+ this.api_key_from_env = false;
+ cx.notify();
+ })
+ })
+ }
+
+ fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> {
+ let credentials_provider = ::global(cx);
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .open_router
+ .api_url
+ .clone();
+ cx.spawn(async move |this, cx| {
+ credentials_provider
+ .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
+ .await
+ .log_err();
+ this.update(cx, |this, cx| {
+ this.api_key = Some(api_key);
+ cx.notify();
+ })
+ })
+ }
+
+ fn authenticate(&self, cx: &mut Context) -> Task> {
+ if self.is_authenticated() {
+ return Task::ready(Ok(()));
+ }
+
+ let credentials_provider = ::global(cx);
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .open_router
+ .api_url
+ .clone();
+ cx.spawn(async move |this, cx| {
+ let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
+ (api_key, true)
+ } else {
+ let (_, api_key) = credentials_provider
+ .read_credentials(&api_url, &cx)
+ .await?
+ .ok_or(AuthenticateError::CredentialsNotFound)?;
+ (
+ String::from_utf8(api_key)
+ .context(format!("invalid {} API key", PROVIDER_NAME))?,
+ false,
+ )
+ };
+ this.update(cx, |this, cx| {
+ this.api_key = Some(api_key);
+ this.api_key_from_env = from_env;
+ cx.notify();
+ })?;
+
+ Ok(())
+ })
+ }
+
+ fn fetch_models(&mut self, cx: &mut Context) -> Task> {
+ let settings = &AllLanguageModelSettings::get_global(cx).open_router;
+ let http_client = self.http_client.clone();
+ let api_url = settings.api_url.clone();
+
+ cx.spawn(async move |this, cx| {
+ let models = list_models(http_client.as_ref(), &api_url).await?;
+
+ this.update(cx, |this, cx| {
+ this.available_models = models;
+ cx.notify();
+ })
+ })
+ }
+
+ fn restart_fetch_models_task(&mut self, cx: &mut Context) {
+ let task = self.fetch_models(cx);
+ self.fetch_models_task.replace(task);
+ }
+}
+
+impl OpenRouterLanguageModelProvider {
+ pub fn new(http_client: Arc, cx: &mut App) -> Self {
+ let state = cx.new(|cx| State {
+ api_key: None,
+ api_key_from_env: false,
+ http_client: http_client.clone(),
+ available_models: Vec::new(),
+ fetch_models_task: None,
+ _subscription: cx.observe_global::(|this: &mut State, cx| {
+ this.restart_fetch_models_task(cx);
+ cx.notify();
+ }),
+ });
+
+ Self { http_client, state }
+ }
+
+ fn create_language_model(&self, model: open_router::Model) -> Arc {
+ Arc::new(OpenRouterLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+}
+
+impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for OpenRouterLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ LanguageModelProviderName(PROVIDER_NAME.into())
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::AiOpenRouter
+ }
+
+ fn default_model(&self, _cx: &App) -> Option> {
+ Some(self.create_language_model(open_router::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option> {
+ Some(self.create_language_model(open_router::Model::default_fast()))
+ }
+
+ fn provided_models(&self, cx: &App) -> Vec> {
+ let mut models_from_api = self.state.read(cx).available_models.clone();
+ let mut settings_models = Vec::new();
+
+ for model in &AllLanguageModelSettings::get_global(cx)
+ .open_router
+ .available_models
+ {
+ settings_models.push(open_router::Model {
+ name: model.name.clone(),
+ display_name: model.display_name.clone(),
+ max_tokens: model.max_tokens,
+ supports_tools: Some(false),
+ });
+ }
+
+ for settings_model in &settings_models {
+ if let Some(pos) = models_from_api
+ .iter()
+ .position(|m| m.name == settings_model.name)
+ {
+ models_from_api[pos] = settings_model.clone();
+ } else {
+ models_from_api.push(settings_model.clone());
+ }
+ }
+
+ models_from_api
+ .into_iter()
+ .map(|model| self.create_language_model(model))
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+ fn authenticate(&self, cx: &mut App) -> Task> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &mut App) -> Task> {
+ self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ }
+}
+
+pub struct OpenRouterLanguageModel {
+ id: LanguageModelId,
+ model: open_router::Model,
+ state: gpui::Entity,
+ http_client: Arc,
+ request_limiter: RateLimiter,
+}
+
+impl OpenRouterLanguageModel {
+ fn stream_completion(
+ &self,
+ request: open_router::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<'static, Result>>>
+ {
+ let http_client = self.http_client.clone();
+ let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).open_router;
+ (state.api_key.clone(), settings.api_url.clone())
+ }) else {
+ return futures::future::ready(Err(anyhow!(
+ "App state dropped: Unable to read API key or API URL from the application state"
+ )))
+ .boxed();
+ };
+
+ let future = self.request_limiter.stream(async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?;
+ let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
+impl LanguageModel for OpenRouterLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ LanguageModelProviderName(PROVIDER_NAME.into())
+ }
+
+ fn supports_tools(&self) -> bool {
+ self.model.supports_tool_calls()
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("openrouter/{}", self.model.id())
+ }
+
+ fn max_token_count(&self) -> usize {
+ self.model.max_token_count()
+ }
+
+ fn max_output_tokens(&self) -> Option {
+ self.model.max_output_tokens()
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto => true,
+ LanguageModelToolChoice::Any => true,
+ LanguageModelToolChoice::None => true,
+ }
+ }
+
+ fn supports_images(&self) -> bool {
+ false
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result> {
+ count_open_router_tokens(request, self.model.clone(), cx)
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result,
+ >,
+ >,
+ > {
+ let request = into_open_router(request, &self.model, self.max_output_tokens());
+ let completions = self.stream_completion(request, cx);
+ async move {
+ let mapper = OpenRouterEventMapper::new();
+ Ok(mapper.map_stream(completions.await?).boxed())
+ }
+ .boxed()
+ }
+}
+
+pub fn into_open_router(
+ request: LanguageModelRequest,
+ model: &Model,
+ max_output_tokens: Option,
+) -> open_router::Request {
+ let mut messages = Vec::new();
+ for req_message in request.messages {
+ for content in req_message.content {
+ match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+ .push(match req_message.role {
+ Role::User => open_router::RequestMessage::User { content: text },
+ Role::Assistant => open_router::RequestMessage::Assistant {
+ content: Some(text),
+ tool_calls: Vec::new(),
+ },
+ Role::System => open_router::RequestMessage::System { content: text },
+ }),
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(_) => {}
+ MessageContent::ToolUse(tool_use) => {
+ let tool_call = open_router::ToolCall {
+ id: tool_use.id.to_string(),
+ content: open_router::ToolCallContent::Function {
+ function: open_router::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ },
+ },
+ };
+
+ if let Some(open_router::RequestMessage::Assistant { tool_calls, .. }) =
+ messages.last_mut()
+ {
+ tool_calls.push(tool_call);
+ } else {
+ messages.push(open_router::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![tool_call],
+ });
+ }
+ }
+ MessageContent::ToolResult(tool_result) => {
+ let content = match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ text.to_string()
+ }
+ LanguageModelToolResultContent::Image(_) => {
+ "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string()
+ }
+ };
+
+ messages.push(open_router::RequestMessage::Tool {
+ content: content,
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
+ }
+ }
+ }
+ }
+
+ open_router::Request {
+ model: model.id().into(),
+ messages,
+ stream: true,
+ stop: request.stop,
+ temperature: request.temperature.unwrap_or(0.4),
+ max_tokens: max_output_tokens,
+ parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
+ Some(false)
+ } else {
+ None
+ },
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| open_router::ToolDefinition::Function {
+ function: open_router::FunctionDefinition {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ },
+ })
+ .collect(),
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => open_router::ToolChoice::Required,
+ LanguageModelToolChoice::None => open_router::ToolChoice::None,
+ }),
+ }
+}
+
+pub struct OpenRouterEventMapper {
+ tool_calls_by_index: HashMap,
+}
+
+impl OpenRouterEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_calls_by_index: HashMap::default(),
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin>>>,
+ ) -> impl Stream- >
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: ResponseStreamEvent,
+ ) -> Vec> {
+ let Some(choice) = event.choices.first() else {
+ return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ "Response contained no choices"
+ )))];
+ };
+
+ let mut events = Vec::new();
+ if let Some(content) = choice.delta.content.clone() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+
+ if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+ for tool_call in tool_calls {
+ let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
+ }
+ }
+ }
+ }
+
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ Some("tool_calls") => {
+ events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+ match serde_json::Value::from_str(&tool_call.arguments) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.clone().into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments.clone(),
+ },
+ )),
+ Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.as_str().into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ }),
+ }
+ }));
+
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ }
+ Some(stop_reason) => {
+ log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ None => {}
+ }
+
+ events
+ }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+}
+
+pub fn count_open_router_tokens(
+ request: LanguageModelRequest,
+ _model: open_router::Model,
+ cx: &App,
+) -> BoxFuture<'static, Result> {
+ cx.background_spawn(async move {
+ let messages = request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::>();
+
+ tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
+ })
+ .boxed()
+}
+
+struct ConfigurationView {
+ api_key_editor: Entity,
+ state: gpui::Entity,
+ load_credentials_task: Option>,
+}
+
+impl ConfigurationView {
+ fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self {
+ let api_key_editor = cx.new(|cx| {
+ let mut editor = Editor::single_line(window, cx);
+ editor
+ .set_placeholder_text("sk_or_000000000000000000000000000000000000000000000000", cx);
+ editor
+ });
+
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn_in(window, {
+ let state = state.clone();
+ async move |this, cx| {
+ if let Some(task) = state
+ .update(cx, |state, cx| state.authenticate(cx))
+ .log_err()
+ {
+ let _ = task.await;
+ }
+
+ this.update(cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor,
+ state,
+ load_credentials_task,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) {
+ let api_key = self.api_key_editor.read(cx).text(cx);
+ if api_key.is_empty() {
+ return;
+ }
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
+
+ fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) {
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
+
+ fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement {
+ let settings = ThemeSettings::get_global(cx);
+ let text_style = TextStyle {
+ color: cx.theme().colors().text,
+ font_family: settings.ui_font.family.clone(),
+ font_features: settings.ui_font.features.clone(),
+ font_fallbacks: settings.ui_font.fallbacks.clone(),
+ font_size: rems(0.875).into(),
+ font_weight: settings.ui_font.weight,
+ font_style: FontStyle::Normal,
+ line_height: relative(1.3),
+ white_space: WhiteSpace::Normal,
+ ..Default::default()
+ };
+ EditorElement::new(
+ &self.api_key_editor,
+ EditorStyle {
+ background: cx.theme().colors().editor_background,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ ..Default::default()
+ },
+ )
+ }
+
+ fn should_render_editor(&self, cx: &mut Context) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_from_env;
+
+ if self.load_credentials_task.is_some() {
+ div().child(Label::new("Loading credentials...")).into_any()
+ } else if self.should_render_editor(cx) {
+ v_flex()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new("To use Zed's assistant with OpenRouter, you need to add an API key. Follow these steps:"))
+ .child(
+ List::new()
+ .child(InstructionListItem::new(
+ "Create an API key by visiting",
+ Some("OpenRouter's console"),
+ Some("https://openrouter.ai/keys"),
+ ))
+ .child(InstructionListItem::text_only(
+ "Ensure your OpenRouter account has credits",
+ ))
+ .child(InstructionListItem::text_only(
+ "Paste your API key below and hit enter to start using the assistant",
+ )),
+ )
+ .child(
+ h_flex()
+ .w_full()
+ .my_2()
+ .px_2()
+ .py_1()
+ .bg(cx.theme().colors().editor_background)
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .rounded_sm()
+ .child(self.render_api_key_editor(cx)),
+ )
+ .child(
+ Label::new(
+ format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."),
+ )
+ .size(LabelSize::Small).color(Color::Muted),
+ )
+ .into_any()
+ } else {
+ h_flex()
+ .mt_1()
+ .p_1()
+ .justify_between()
+ .rounded_md()
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().background)
+ .child(
+ h_flex()
+ .gap_1()
+ .child(Icon::new(IconName::Check).color(Color::Success))
+ .child(Label::new(if env_var_set {
+ format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.")
+ } else {
+ "API key configured.".to_string()
+ })),
+ )
+ .child(
+ Button::new("reset-key", "Reset Key")
+ .label_size(LabelSize::Small)
+ .icon(Some(IconName::Trash))
+ .icon_size(IconSize::Small)
+ .icon_position(IconPosition::Start)
+ .disabled(env_var_set)
+ .when(env_var_set, |this| {
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable.")))
+ })
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
+ )
+ .into_any()
+ }
+ }
+}
diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs
index abbb237b4f..2cf549c8f6 100644
--- a/crates/language_models/src/settings.rs
+++ b/crates/language_models/src/settings.rs
@@ -20,6 +20,7 @@ use crate::provider::{
mistral::MistralSettings,
ollama::OllamaSettings,
open_ai::OpenAiSettings,
+ open_router::OpenRouterSettings,
};
/// Initializes the language model settings.
@@ -61,6 +62,7 @@ pub struct AllLanguageModelSettings {
pub bedrock: AmazonBedrockSettings,
pub ollama: OllamaSettings,
pub openai: OpenAiSettings,
+ pub open_router: OpenRouterSettings,
pub zed_dot_dev: ZedDotDevSettings,
pub google: GoogleSettings,
pub copilot_chat: CopilotChatSettings,
@@ -76,6 +78,7 @@ pub struct AllLanguageModelSettingsContent {
pub ollama: Option,
pub lmstudio: Option,
pub openai: Option,
+ pub open_router: Option,
#[serde(rename = "zed.dev")]
pub zed_dot_dev: Option,
pub google: Option,
@@ -271,6 +274,12 @@ pub struct ZedDotDevSettingsContent {
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct CopilotChatSettingsContent {}
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct OpenRouterSettingsContent {
+ pub api_url: Option,
+ pub available_models: Option>,
+}
+
impl settings::Settings for AllLanguageModelSettings {
const KEY: Option<&'static str> = Some("language_models");
@@ -409,6 +418,19 @@ impl settings::Settings for AllLanguageModelSettings {
&mut settings.mistral.available_models,
mistral.as_ref().and_then(|s| s.available_models.clone()),
);
+
+ // OpenRouter
+ let open_router = value.open_router.clone();
+ merge(
+ &mut settings.open_router.api_url,
+ open_router.as_ref().and_then(|s| s.api_url.clone()),
+ );
+ merge(
+ &mut settings.open_router.available_models,
+ open_router
+ .as_ref()
+ .and_then(|s| s.available_models.clone()),
+ );
}
Ok(settings)
diff --git a/crates/open_router/Cargo.toml b/crates/open_router/Cargo.toml
new file mode 100644
index 0000000000..bbc4fe190f
--- /dev/null
+++ b/crates/open_router/Cargo.toml
@@ -0,0 +1,25 @@
+[package]
+name = "open_router"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/open_router.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+http_client.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+workspace-hack.workspace = true
diff --git a/crates/open_router/LICENSE-GPL b/crates/open_router/LICENSE-GPL
new file mode 120000
index 0000000000..89e542f750
--- /dev/null
+++ b/crates/open_router/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs
new file mode 100644
index 0000000000..f0fe071503
--- /dev/null
+++ b/crates/open_router/src/open_router.rs
@@ -0,0 +1,484 @@
+use anyhow::{Context, Result, anyhow};
+use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::convert::TryFrom;
+
+pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
+
+fn is_none_or_empty, U>(opt: &Option) -> bool {
+ opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
+}
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+ Tool,
+}
+
+impl TryFrom for Role {
+ type Error = anyhow::Error;
+
+ fn try_from(value: String) -> Result {
+ match value.as_str() {
+ "user" => Ok(Self::User),
+ "assistant" => Ok(Self::Assistant),
+ "system" => Ok(Self::System),
+ "tool" => Ok(Self::Tool),
+ _ => Err(anyhow!("invalid role '{value}'")),
+ }
+ }
+}
+
+impl From for String {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => "user".to_owned(),
+ Role::Assistant => "assistant".to_owned(),
+ Role::System => "system".to_owned(),
+ Role::Tool => "tool".to_owned(),
+ }
+ }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct Model {
+ pub name: String,
+ pub display_name: Option,
+ pub max_tokens: usize,
+ pub supports_tools: Option,
+}
+
+impl Model {
+ pub fn default_fast() -> Self {
+ Self::new(
+ "openrouter/auto",
+ Some("Auto Router"),
+ Some(2000000),
+ Some(true),
+ )
+ }
+
+ pub fn default() -> Self {
+ Self::default_fast()
+ }
+
+ pub fn new(
+ name: &str,
+ display_name: Option<&str>,
+ max_tokens: Option,
+ supports_tools: Option,
+ ) -> Self {
+ Self {
+ name: name.to_owned(),
+ display_name: display_name.map(|s| s.to_owned()),
+ max_tokens: max_tokens.unwrap_or(2000000),
+ supports_tools,
+ }
+ }
+
+ pub fn id(&self) -> &str {
+ &self.name
+ }
+
+ pub fn display_name(&self) -> &str {
+ self.display_name.as_ref().unwrap_or(&self.name)
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ self.max_tokens
+ }
+
+ pub fn max_output_tokens(&self) -> Option {
+ None
+ }
+
+ pub fn supports_tool_calls(&self) -> bool {
+ self.supports_tools.unwrap_or(false)
+ }
+
+ pub fn supports_parallel_tool_calls(&self) -> bool {
+ false
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Request {
+ pub model: String,
+ pub messages: Vec,
+ pub stream: bool,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub max_tokens: Option,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub stop: Vec,
+ pub temperature: f32,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub tool_choice: Option,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub parallel_tool_calls: Option,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub tools: Vec,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+ Auto,
+ Required,
+ None,
+ Other(ToolDefinition),
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Deserialize, Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ToolDefinition {
+ #[allow(dead_code)]
+ Function { function: FunctionDefinition },
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+ pub name: String,
+ pub description: Option,
+ pub parameters: Option,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum RequestMessage {
+ Assistant {
+ content: Option,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ tool_calls: Vec,
+ },
+ User {
+ content: String,
+ },
+ System {
+ content: String,
+ },
+ Tool {
+ content: String,
+ tool_call_id: String,
+ },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCall {
+ pub id: String,
+ #[serde(flatten)]
+ pub content: ToolCallContent,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolCallContent {
+ Function { function: FunctionContent },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionContent {
+ pub name: String,
+ pub arguments: String,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessageDelta {
+ pub role: Option,
+ pub content: Option,
+ #[serde(default, skip_serializing_if = "is_none_or_empty")]
+ pub tool_calls: Option>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCallChunk {
+ pub index: usize,
+ pub id: Option,
+ pub function: Option,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionChunk {
+ pub name: Option,
+ pub arguments: Option,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Usage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ChoiceDelta {
+ pub index: u32,
+ pub delta: ResponseMessageDelta,
+ pub finish_reason: Option,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ResponseStreamEvent {
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub id: Option,
+ pub created: u32,
+ pub model: String,
+ pub choices: Vec,
+ pub usage: Option,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Response {
+ pub id: String,
+ pub object: String,
+ pub created: u64,
+ pub model: String,
+ pub choices: Vec,
+ pub usage: Usage,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Choice {
+ pub index: u32,
+ pub message: RequestMessage,
+ pub finish_reason: Option,
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
+pub struct ListModelsResponse {
+ pub data: Vec,
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
+pub struct ModelEntry {
+ pub id: String,
+ pub name: String,
+ pub created: usize,
+ pub description: String,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub context_length: Option,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub supported_parameters: Vec,
+}
+
+pub async fn complete(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+) -> Result {
+ let uri = format!("{api_url}/chat/completions");
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .header("HTTP-Referer", "https://zed.dev")
+ .header("X-Title", "Zed Editor");
+
+ let mut request_body = request;
+ request_body.stream = false;
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
+ let mut response = client.send(request).await?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ let response: Response = serde_json::from_str(&body)?;
+ Ok(response)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenRouterResponse {
+ error: OpenRouterError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenRouterError {
+ message: String,
+ #[serde(default)]
+ code: String,
+ }
+
+ match serde_json::from_str::(&body) {
+ Ok(response) if !response.error.message.is_empty() => {
+ let error_message = if !response.error.code.is_empty() {
+ format!("{}: {}", response.error.code, response.error.message)
+ } else {
+ response.error.message
+ };
+
+ Err(anyhow!(
+ "Failed to connect to OpenRouter API: {}",
+ error_message
+ ))
+ }
+ _ => Err(anyhow!(
+ "Failed to connect to OpenRouter API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
+pub async fn stream_completion(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+) -> Result>> {
+ let uri = format!("{api_url}/chat/completions");
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .header("HTTP-Referer", "https://zed.dev")
+ .header("X-Title", "Zed Editor");
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let mut response = client.send(request).await?;
+
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ if line.starts_with(':') {
+ return None;
+ }
+
+ let line = line.strip_prefix("data: ")?;
+ if line == "[DONE]" {
+ None
+ } else {
+ match serde_json::from_str::(line) {
+ Ok(response) => Some(Ok(response)),
+ Err(error) => {
+ #[derive(Deserialize)]
+ struct ErrorResponse {
+ error: String,
+ }
+
+ match serde_json::from_str::(line) {
+ Ok(err_response) => Some(Err(anyhow!(err_response.error))),
+ Err(_) => {
+ if line.trim().is_empty() {
+ None
+ } else {
+ Some(Err(anyhow!(
+ "Failed to parse response: {}. Original content: '{}'",
+ error, line
+ )))
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ })
+ .boxed())
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenRouterResponse {
+ error: OpenRouterError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenRouterError {
+ message: String,
+ #[serde(default)]
+ code: String,
+ }
+
+ match serde_json::from_str::(&body) {
+ Ok(response) if !response.error.message.is_empty() => {
+ let error_message = if !response.error.code.is_empty() {
+ format!("{}: {}", response.error.code, response.error.message)
+ } else {
+ response.error.message
+ };
+
+ Err(anyhow!(
+ "Failed to connect to OpenRouter API: {}",
+ error_message
+ ))
+ }
+ _ => Err(anyhow!(
+ "Failed to connect to OpenRouter API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
+pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result> {
+ let uri = format!("{api_url}/models");
+ let request_builder = HttpRequest::builder()
+ .method(Method::GET)
+ .uri(uri)
+ .header("Accept", "application/json");
+
+ let request = request_builder.body(AsyncBody::default())?;
+ let mut response = client.send(request).await?;
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ if response.status().is_success() {
+ let response: ListModelsResponse =
+ serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
+
+ let models = response
+ .data
+ .into_iter()
+ .map(|entry| Model {
+ name: entry.id,
+ // OpenRouter returns display names in the format "provider_name: model_name".
+ // When displayed in the UI, these names can get truncated from the right.
+ // Since users typically already know the provider, we extract just the model name
+ // portion (after the colon) to create a more concise and user-friendly label
+ // for the model dropdown in the agent panel.
+ display_name: Some(
+ entry
+ .name
+ .split(':')
+ .next_back()
+ .unwrap_or(&entry.name)
+ .trim()
+ .to_string(),
+ ),
+ max_tokens: entry.context_length.unwrap_or(2000000),
+ supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
+ })
+ .collect();
+
+ Ok(models)
+ } else {
+ Err(anyhow!(
+ "Failed to connect to OpenRouter API: {} {}",
+ response.status(),
+ body,
+ ))
+ }
+}