diff --git a/Cargo.lock b/Cargo.lock index 88283152ba..9cf69ed1c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8864,6 +8864,7 @@ dependencies = [ "mistral", "ollama", "open_ai", + "open_router", "partial-json-fixer", "project", "proto", @@ -10708,6 +10709,19 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "open_router" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.31", + "http_client", + "schemars", + "serde", + "serde_json", + "workspace-hack", +] + [[package]] name = "opener" version = "0.7.2" diff --git a/Cargo.toml b/Cargo.toml index 9152dfd23c..852e3ba413 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,7 @@ members = [ "crates/notifications", "crates/ollama", "crates/open_ai", + "crates/open_router", "crates/outline", "crates/outline_panel", "crates/panel", @@ -307,6 +308,7 @@ node_runtime = { path = "crates/node_runtime" } notifications = { path = "crates/notifications" } ollama = { path = "crates/ollama" } open_ai = { path = "crates/open_ai" } +open_router = { path = "crates/open_router", features = ["schemars"] } outline = { path = "crates/outline" } outline_panel = { path = "crates/outline_panel" } panel = { path = "crates/panel" } diff --git a/assets/icons/ai_open_router.svg b/assets/icons/ai_open_router.svg new file mode 100644 index 0000000000..cc8597729a --- /dev/null +++ b/assets/icons/ai_open_router.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/assets/settings/default.json b/assets/settings/default.json index 3ae4417505..7c0688831d 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1605,6 +1605,9 @@ "version": "1", "api_url": "https://api.openai.com/v1" }, + "open_router": { + "api_url": "https://openrouter.ai/api/v1" + }, "lmstudio": { "api_url": "http://localhost:1234/api/v0" }, diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index ce7bd56047..36480f30d5 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ 
b/crates/agent_settings/src/agent_settings.rs @@ -730,6 +730,7 @@ impl JsonSchema for LanguageModelProviderSetting { "zed.dev".into(), "copilot_chat".into(), "deepseek".into(), + "openrouter".into(), "mistral".into(), ]), ..Default::default() diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 2896a19829..adfbe1e52d 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -18,6 +18,7 @@ pub enum IconName { AiMistral, AiOllama, AiOpenAi, + AiOpenRouter, AiZed, ArrowCircle, ArrowDown, diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 2c5048b910..ab5090e9ba 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -39,6 +39,7 @@ menu.workspace = true mistral = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } +open_router = { workspace = true, features = ["schemars"] } partial-json-fixer.workspace = true project.workspace = true proto.workspace = true diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 61c5dcf642..0224da4e6b 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -19,6 +19,7 @@ use crate::provider::lmstudio::LmStudioLanguageModelProvider; use crate::provider::mistral::MistralLanguageModelProvider; use crate::provider::ollama::OllamaLanguageModelProvider; use crate::provider::open_ai::OpenAiLanguageModelProvider; +use crate::provider::open_router::OpenRouterLanguageModelProvider; pub use crate::settings::*; pub fn init(user_store: Entity, client: Arc, fs: Arc, cx: &mut App) { @@ -72,5 +73,9 @@ fn register_language_model_providers( BedrockLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider( + OpenRouterLanguageModelProvider::new(client.http_client(), cx), + cx, + ); 
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx); } diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index 6b183292f3..4f2ea9cc09 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -8,3 +8,4 @@ pub mod lmstudio; pub mod mistral; pub mod ollama; pub mod open_ai; +pub mod open_router; diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs new file mode 100644 index 0000000000..7af265544a --- /dev/null +++ b/crates/language_models/src/provider/open_router.rs @@ -0,0 +1,788 @@ +use anyhow::{Context as _, Result, anyhow}; +use collections::HashMap; +use credentials_provider::CredentialsProvider; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; +use gpui::{ + AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, +}; +use http_client::HttpClient; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + RateLimiter, Role, StopReason, +}; +use open_router::{Model, ResponseStreamEvent, list_models, stream_completion}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings::{Settings, SettingsStore}; +use std::pin::Pin; +use std::str::FromStr as _; +use std::sync::Arc; +use theme::ThemeSettings; +use ui::{Icon, IconName, List, Tooltip, prelude::*}; +use util::ResultExt; + +use crate::{AllLanguageModelSettings, ui::InstructionListItem}; + +const PROVIDER_ID: &str = "openrouter"; +const PROVIDER_NAME: &str = "OpenRouter"; + +#[derive(Default, Clone, 
Debug, PartialEq)] +pub struct OpenRouterSettings { + pub api_url: String, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + pub name: String, + pub display_name: Option, + pub max_tokens: usize, + pub max_output_tokens: Option, + pub max_completion_tokens: Option, +} + +pub struct OpenRouterLanguageModelProvider { + http_client: Arc, + state: gpui::Entity, +} + +pub struct State { + api_key: Option, + api_key_from_env: bool, + http_client: Arc, + available_models: Vec, + fetch_models_task: Option>>, + _subscription: Subscription, +} + +const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY"; + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .open_router + .api_url + .clone(); + cx.spawn(async move |this, cx| { + credentials_provider + .delete_credentials(&api_url, &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .open_router + .api_url + .clone(); + cx.spawn(async move |this, cx| { + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + + fn authenticate(&self, cx: &mut Context) -> Task> { + if self.is_authenticated() { + return Task::ready(Ok(())); + } + + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .open_router + .api_url + .clone(); + cx.spawn(async move |this, cx| { + let (api_key, from_env) = if let Ok(api_key) = 
std::env::var(OPENROUTER_API_KEY_VAR) { + (api_key, true) + } else { + let (_, api_key) = credentials_provider + .read_credentials(&api_url, &cx) + .await? + .ok_or(AuthenticateError::CredentialsNotFound)?; + ( + String::from_utf8(api_key) + .context(format!("invalid {} API key", PROVIDER_NAME))?, + false, + ) + }; + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + this.api_key_from_env = from_env; + cx.notify(); + })?; + + Ok(()) + }) + } + + fn fetch_models(&mut self, cx: &mut Context) -> Task> { + let settings = &AllLanguageModelSettings::get_global(cx).open_router; + let http_client = self.http_client.clone(); + let api_url = settings.api_url.clone(); + + cx.spawn(async move |this, cx| { + let models = list_models(http_client.as_ref(), &api_url).await?; + + this.update(cx, |this, cx| { + this.available_models = models; + cx.notify(); + }) + }) + } + + fn restart_fetch_models_task(&mut self, cx: &mut Context) { + let task = self.fetch_models(cx); + self.fetch_models_task.replace(task); + } +} + +impl OpenRouterLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + http_client: http_client.clone(), + available_models: Vec::new(), + fetch_models_task: None, + _subscription: cx.observe_global::(|this: &mut State, cx| { + this.restart_fetch_models_task(cx); + cx.notify(); + }), + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: open_router::Model) -> Arc { + Arc::new(OpenRouterLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for OpenRouterLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for 
OpenRouterLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn icon(&self) -> IconName { + IconName::AiOpenRouter + } + + fn default_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(open_router::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(open_router::Model::default_fast())) + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models_from_api = self.state.read(cx).available_models.clone(); + let mut settings_models = Vec::new(); + + for model in &AllLanguageModelSettings::get_global(cx) + .open_router + .available_models + { + settings_models.push(open_router::Model { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + supports_tools: Some(false), + }); + } + + for settings_model in &settings_models { + if let Some(pos) = models_from_api + .iter() + .position(|m| m.name == settings_model.name) + { + models_from_api[pos] = settings_model.clone(); + } else { + models_from_api.push(settings_model.clone()); + } + } + + models_from_api + .into_iter() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.reset_api_key(cx)) + } +} + +pub struct OpenRouterLanguageModel { + id: LanguageModelId, + model: open_router::Model, + state: gpui::Entity, + http_client: Arc, + 
request_limiter: RateLimiter, +} + +impl OpenRouterLanguageModel { + fn stream_completion( + &self, + request: open_router::Request, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result>>> + { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).open_router; + (state.api_key.clone(), settings.api_url.clone()) + }) else { + return futures::future::ready(Err(anyhow!( + "App state dropped: Unable to read API key or API URL from the application state" + ))) + .boxed(); + }; + + let future = self.request_limiter.stream(async move { + let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?; + let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for OpenRouterLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn supports_tools(&self) -> bool { + self.model.supports_tool_calls() + } + + fn telemetry_id(&self) -> String { + format!("openrouter/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + self.model.max_output_tokens() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto => true, + LanguageModelToolChoice::Any => true, + LanguageModelToolChoice::None => true, + } + } + + fn supports_images(&self) -> bool { + false + } + + fn count_tokens( + 
&self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + count_open_router_tokens(request, self.model.clone(), cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + >, + > { + let request = into_open_router(request, &self.model, self.max_output_tokens()); + let completions = self.stream_completion(request, cx); + async move { + let mapper = OpenRouterEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) + } + .boxed() + } +} + +pub fn into_open_router( + request: LanguageModelRequest, + model: &Model, + max_output_tokens: Option, +) -> open_router::Request { + let mut messages = Vec::new(); + for req_message in request.messages { + for content in req_message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages + .push(match req_message.role { + Role::User => open_router::RequestMessage::User { content: text }, + Role::Assistant => open_router::RequestMessage::Assistant { + content: Some(text), + tool_calls: Vec::new(), + }, + Role::System => open_router::RequestMessage::System { content: text }, + }), + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(_) => {} + MessageContent::ToolUse(tool_use) => { + let tool_call = open_router::ToolCall { + id: tool_use.id.to_string(), + content: open_router::ToolCallContent::Function { + function: open_router::FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + }, + }, + }; + + if let Some(open_router::RequestMessage::Assistant { tool_calls, .. 
}) = + messages.last_mut() + { + tool_calls.push(tool_call); + } else { + messages.push(open_router::RequestMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + }); + } + } + MessageContent::ToolResult(tool_result) => { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + text.to_string() + } + LanguageModelToolResultContent::Image(_) => { + "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string() + } + }; + + messages.push(open_router::RequestMessage::Tool { + content: content, + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } + } + } + } + + open_router::Request { + model: model.id().into(), + messages, + stream: true, + stop: request.stop, + temperature: request.temperature.unwrap_or(0.4), + max_tokens: max_output_tokens, + parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() { + Some(false) + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| open_router::ToolDefinition::Function { + function: open_router::FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto, + LanguageModelToolChoice::Any => open_router::ToolChoice::Required, + LanguageModelToolChoice::None => open_router::ToolChoice::None, + }), + } +} + +pub struct OpenRouterEventMapper { + tool_calls_by_index: HashMap, +} + +impl OpenRouterEventMapper { + pub fn new() -> Self { + Self { + tool_calls_by_index: HashMap::default(), + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], + }) + }) + } + + pub 
fn map_event( + &mut self, + event: ResponseStreamEvent, + ) -> Vec> { + let Some(choice) = event.choices.first() else { + return vec![Err(LanguageModelCompletionError::Other(anyhow!( + "Response contained no choices" + )))]; + }; + + let mut events = Vec::new(); + if let Some(content) = choice.delta.content.clone() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + + if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + } + } + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match serde_json::Value::from_str(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + }, + )), + Err(error) => Err(LanguageModelCompletionError::BadInputJson { + id: tool_call.id.into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }), + } + })); + + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} + } + + events + } +} + +#[derive(Default)] +struct RawToolCall { + id: 
String, + name: String, + arguments: String, +} + +pub fn count_open_router_tokens( + request: LanguageModelRequest, + _model: open_router::Model, + cx: &App, +) -> BoxFuture<'static, Result> { + cx.background_spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages) + }) + .boxed() +} + +struct ConfigurationView { + api_key_editor: Entity, + state: gpui::Entity, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + let api_key_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor + .set_placeholder_text("sk_or_000000000000000000000000000000000000000000000000", cx); + editor + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + let _ = task.await; + } + + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(api_key, cx))? 
+ .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state.update(cx, |state, cx| state.reset_api_key(cx))?.await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + white_space: WhiteSpace::Normal, + ..Default::default() + }; + EditorElement::new( + &self.api_key_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_from_env; + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials...")).into_any() + } else if self.should_render_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's assistant with OpenRouter, you need to add an API key. 
Follow these steps:")) + .child( + List::new() + .child(InstructionListItem::new( + "Create an API key by visiting", + Some("OpenRouter's console"), + Some("https://openrouter.ai/keys"), + )) + .child(InstructionListItem::text_only( + "Ensure your OpenRouter account has credits", + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the assistant", + )), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .border_1() + .border_color(cx.theme().colors().border) + .rounded_sm() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."), + ) + .size(LabelSize::Small).color(Color::Muted), + ) + .into_any() + } else { + h_flex() + .mt_1() + .p_1() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().background) + .child( + h_flex() + .gap_1() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new(if env_var_set { + format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.") + } else { + "API key configured.".to_string() + })), + ) + .child( + Button::new("reset-key", "Reset Key") + .label_size(LabelSize::Small) + .icon(Some(IconName::Trash)) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .disabled(env_var_set) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + .into_any() + } + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index abbb237b4f..2cf549c8f6 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -20,6 +20,7 @@ use crate::provider::{ 
mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings, + open_router::OpenRouterSettings, }; /// Initializes the language model settings. @@ -61,6 +62,7 @@ pub struct AllLanguageModelSettings { pub bedrock: AmazonBedrockSettings, pub ollama: OllamaSettings, pub openai: OpenAiSettings, + pub open_router: OpenRouterSettings, pub zed_dot_dev: ZedDotDevSettings, pub google: GoogleSettings, pub copilot_chat: CopilotChatSettings, @@ -76,6 +78,7 @@ pub struct AllLanguageModelSettingsContent { pub ollama: Option, pub lmstudio: Option, pub openai: Option, + pub open_router: Option, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, pub google: Option, @@ -271,6 +274,12 @@ pub struct ZedDotDevSettingsContent { #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct CopilotChatSettingsContent {} +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct OpenRouterSettingsContent { + pub api_url: Option, + pub available_models: Option>, +} + impl settings::Settings for AllLanguageModelSettings { const KEY: Option<&'static str> = Some("language_models"); @@ -409,6 +418,19 @@ impl settings::Settings for AllLanguageModelSettings { &mut settings.mistral.available_models, mistral.as_ref().and_then(|s| s.available_models.clone()), ); + + // OpenRouter + let open_router = value.open_router.clone(); + merge( + &mut settings.open_router.api_url, + open_router.as_ref().and_then(|s| s.api_url.clone()), + ); + merge( + &mut settings.open_router.available_models, + open_router + .as_ref() + .and_then(|s| s.available_models.clone()), + ); } Ok(settings) diff --git a/crates/open_router/Cargo.toml b/crates/open_router/Cargo.toml new file mode 100644 index 0000000000..bbc4fe190f --- /dev/null +++ b/crates/open_router/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "open_router" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = 
true + +[lib] +path = "src/open_router.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +futures.workspace = true +http_client.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true +workspace-hack.workspace = true diff --git a/crates/open_router/LICENSE-GPL b/crates/open_router/LICENSE-GPL new file mode 120000 index 0000000000..89e542f750 --- /dev/null +++ b/crates/open_router/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs new file mode 100644 index 0000000000..f0fe071503 --- /dev/null +++ b/crates/open_router/src/open_router.rs @@ -0,0 +1,484 @@ +use anyhow::{Context, Result, anyhow}; +use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::convert::TryFrom; + +pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1"; + +fn is_none_or_empty, U>(opt: &Option) -> bool { + opt.as_ref().map_or(true, |v| v.as_ref().is_empty()) +} + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, + Tool, +} + +impl TryFrom for Role { + type Error = anyhow::Error; + + fn try_from(value: String) -> Result { + match value.as_str() { + "user" => Ok(Self::User), + "assistant" => Ok(Self::Assistant), + "system" => Ok(Self::System), + "tool" => Ok(Self::Tool), + _ => Err(anyhow!("invalid role '{value}'")), + } + } +} + +impl From for String { + fn from(val: Role) -> Self { + match val { + Role::User => "user".to_owned(), + Role::Assistant => "assistant".to_owned(), + Role::System => "system".to_owned(), + Role::Tool => "tool".to_owned(), + } + } +} + 
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub struct Model { + pub name: String, + pub display_name: Option, + pub max_tokens: usize, + pub supports_tools: Option, +} + +impl Model { + pub fn default_fast() -> Self { + Self::new( + "openrouter/auto", + Some("Auto Router"), + Some(2000000), + Some(true), + ) + } + + pub fn default() -> Self { + Self::default_fast() + } + + pub fn new( + name: &str, + display_name: Option<&str>, + max_tokens: Option, + supports_tools: Option, + ) -> Self { + Self { + name: name.to_owned(), + display_name: display_name.map(|s| s.to_owned()), + max_tokens: max_tokens.unwrap_or(2000000), + supports_tools, + } + } + + pub fn id(&self) -> &str { + &self.name + } + + pub fn display_name(&self) -> &str { + self.display_name.as_ref().unwrap_or(&self.name) + } + + pub fn max_token_count(&self) -> usize { + self.max_tokens + } + + pub fn max_output_tokens(&self) -> Option { + None + } + + pub fn supports_tool_calls(&self) -> bool { + self.supports_tools.unwrap_or(false) + } + + pub fn supports_parallel_tool_calls(&self) -> bool { + false + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Request { + pub model: String, + pub messages: Vec, + pub stream: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub stop: Vec, + pub temperature: f32, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Auto, + Required, + None, + Other(ToolDefinition), +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Deserialize, 
Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON Schema describing the function's arguments.
    pub parameters: Option<serde_json::Value>,
}

/// A single message in a chat-completion request, tagged by `role`
/// ("assistant" / "user" / "system" / "tool") per the OpenAI-style wire format.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        // Ties the tool result back to the assistant's originating tool call.
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Raw JSON-encoded argument string, exactly as the model produced it.
    pub arguments: String,
}

/// Incremental message content carried by one streaming chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    // NOTE(review): `Role` is assumed to be declared earlier in this file — confirm.
    pub role: Option<Role>,
    pub content: Option<String>,
    // `is_none_or_empty` is a helper assumed to be declared earlier in this file.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A partial tool call inside a streaming delta; chunks with the same `index`
/// are accumulated by the caller into one complete call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

/// One server-sent event of a streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Only present on the final event of the stream.
    pub usage: Option<Usage>,
}

/// A complete (non-streaming) chat-completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}

#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

/// One entry from OpenRouter's `GET /models` listing.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    pub id: String,
    pub name: String,
    pub created: usize,
    pub description: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<usize>,
    /// Parameter names the model supports (e.g. "tools"); used below to
    /// detect tool-call capability.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
}

/// Performs a non-streaming chat completion against the OpenRouter API.
///
/// Forces `request.stream = false` before sending, so callers may pass a
/// request originally built for streaming. Returns the parsed [`Response`]
/// on success; on a non-2xx status, attempts to decode OpenRouter's
/// `{ "error": { "message", "code" } }` envelope for a readable error.
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<Response> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        // OpenRouter uses these two headers for app attribution/rankings.
        .header("HTTP-Referer", "https://zed.dev")
        .header("X-Title", "Zed Editor");

    let mut request_body = request;
    request_body.stream = false;

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: Response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenRouterResponse {
            error: OpenRouterError,
        }

        #[derive(Deserialize)]
        struct OpenRouterError {
            message: String,
            #[serde(default)]
            code: String,
        }

        match serde_json::from_str::<OpenRouterResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => {
                let error_message = if !response.error.code.is_empty() {
                    format!("{}: {}", response.error.code, response.error.message)
                } else {
                    response.error.message
                };

                Err(anyhow!(
                    "Failed to connect to OpenRouter API: {}",
                    error_message
                ))
            }
            // Body wasn't the expected error envelope — surface status + raw body.
            _ => Err(anyhow!(
                "Failed to connect to OpenRouter API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

/// Performs a streaming chat completion, yielding one [`ResponseStreamEvent`]
/// per server-sent-event `data:` line.
///
/// SSE comment lines (starting with ':') and the terminal `[DONE]` sentinel
/// are filtered out. Lines that parse as neither an event nor an error
/// payload produce an `Err` item carrying the original content, so malformed
/// chunks are visible to the caller rather than silently dropped.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        // OpenRouter uses these two headers for app attribution/rankings.
        .header("HTTP-Referer", "https://zed.dev")
        .header("X-Title", "Zed Editor");

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE comment/keep-alive lines begin with ':'.
                        if line.starts_with(':') {
                            return None;
                        }

                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str::<ResponseStreamEvent>(line) {
                                Ok(response) => Some(Ok(response)),
                                Err(error) => {
                                    #[derive(Deserialize)]
                                    struct ErrorResponse {
                                        error: String,
                                    }

                                    match serde_json::from_str::<ErrorResponse>(line) {
                                        Ok(err_response) => Some(Err(anyhow!(err_response.error))),
                                        Err(_) => {
                                            if line.trim().is_empty() {
                                                None
                                            } else {
                                                Some(Err(anyhow!(
                                                    "Failed to parse response: {}. Original content: '{}'",
                                                    error, line
                                                )))
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenRouterResponse {
            error: OpenRouterError,
        }

        #[derive(Deserialize)]
        struct OpenRouterError {
            message: String,
            #[serde(default)]
            code: String,
        }

        match serde_json::from_str::<OpenRouterResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => {
                let error_message = if !response.error.code.is_empty() {
                    format!("{}: {}", response.error.code, response.error.message)
                } else {
                    response.error.message
                };

                Err(anyhow!(
                    "Failed to connect to OpenRouter API: {}",
                    error_message
                ))
            }
            // Body wasn't the expected error envelope — surface status + raw body.
            _ => Err(anyhow!(
                "Failed to connect to OpenRouter API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

/// Fetches the available models from OpenRouter's `GET /models` endpoint
/// and converts each entry into the crate's [`Model`] type.
///
/// No authentication header is sent; the models listing is public.
pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
    let uri = format!("{api_url}/models");
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::default())?;
    let mut response = client.send(request).await?;

    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if response.status().is_success() {
        let response: ListModelsResponse =
            serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;

        let models = response
            .data
            .into_iter()
            .map(|entry| Model {
                name: entry.id,
                // OpenRouter returns display names in the format "provider_name: model_name".
                // When displayed in the UI, these names can get truncated from the right.
                // Since users typically already know the provider, we extract just the model name
                // portion (after the colon) to create a more concise and user-friendly label
                // for the model dropdown in the agent panel.
                display_name: Some(
                    entry
                        .name
                        .split(':')
                        .next_back()
                        .unwrap_or(&entry.name)
                        .trim()
                        .to_string(),
                ),
                // Fall back to a generous default when the API omits context_length.
                max_tokens: entry.context_length.unwrap_or(2000000),
                supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
            })
            .collect();

        Ok(models)
    } else {
        Err(anyhow!(
            "Failed to connect to OpenRouter API: {} {}",
            response.status(),
            body,
        ))
    }
}