use anyhow::{Context as _, Result, anyhow}; use collections::HashMap; use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; use gpui::{ AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, }; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; use open_router::{ Model, ModelMode as OpenRouterModelMode, ResponseStreamEvent, list_models, stream_completion, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr as _; use std::sync::Arc; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; use crate::{AllLanguageModelSettings, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenRouterSettings { pub api_url: String, pub available_models: Vec, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] pub struct AvailableModel { pub name: String, pub display_name: Option, pub max_tokens: u64, pub max_output_tokens: Option, pub max_completion_tokens: Option, pub supports_tools: Option, pub supports_images: Option, pub mode: Option, } #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(tag = "type", rename_all = "lowercase")] pub enum ModelMode { #[default] Default, Thinking { budget_tokens: Option, }, } impl From for OpenRouterModelMode { fn from(value: ModelMode) -> Self { match value { ModelMode::Default => OpenRouterModelMode::Default, ModelMode::Thinking { budget_tokens } => { OpenRouterModelMode::Thinking { budget_tokens } } } } } impl From for ModelMode { fn from(value: OpenRouterModelMode) -> Self { match value { OpenRouterModelMode::Default => ModelMode::Default, OpenRouterModelMode::Thinking { budget_tokens } => { ModelMode::Thinking { budget_tokens } } } } } pub struct OpenRouterLanguageModelProvider { http_client: Arc, state: gpui::Entity, } pub struct State { api_key: Option, api_key_from_env: bool, http_client: Arc, available_models: Vec, fetch_models_task: Option>>, settings: OpenRouterSettings, _subscription: Subscription, } const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY"; impl State { fn is_authenticated(&self) -> bool { self.api_key.is_some() } fn reset_api_key(&self, cx: &mut Context) -> Task> { let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .open_router .api_url .clone(); cx.spawn(async move |this, cx| { credentials_provider .delete_credentials(&api_url, &cx) .await .log_err(); this.update(cx, |this, cx| { this.api_key = None; this.api_key_from_env = false; cx.notify(); }) }) } fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .open_router .api_url .clone(); cx.spawn(async move |this, cx| { credentials_provider .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx) .await .log_err(); this.update(cx, |this, cx| { this.api_key = Some(api_key); this.restart_fetch_models_task(cx); cx.notify(); }) }) } fn authenticate(&self, cx: &mut Context) -> Task> { if self.is_authenticated() { return Task::ready(Ok(())); } let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .open_router .api_url .clone(); cx.spawn(async move |this, cx| { let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) { (api_key, true) } else { let (_, api_key) = credentials_provider .read_credentials(&api_url, &cx) .await? .ok_or(AuthenticateError::CredentialsNotFound)?; ( String::from_utf8(api_key) .context(format!("invalid {} API key", PROVIDER_NAME))?, false, ) }; this.update(cx, |this, cx| { this.api_key = Some(api_key); this.api_key_from_env = from_env; this.restart_fetch_models_task(cx); cx.notify(); })?; Ok(()) }) } fn fetch_models(&mut self, cx: &mut Context) -> Task> { let settings = &AllLanguageModelSettings::get_global(cx).open_router; let http_client = self.http_client.clone(); let api_url = settings.api_url.clone(); cx.spawn(async move |this, cx| { let models = list_models(http_client.as_ref(), &api_url).await?; this.update(cx, |this, cx| { this.available_models = models; cx.notify(); }) }) } fn restart_fetch_models_task(&mut self, cx: &mut Context) { if self.is_authenticated() { let task = self.fetch_models(cx); self.fetch_models_task.replace(task); } } } impl OpenRouterLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { let state = cx.new(|cx| State { api_key: None, api_key_from_env: false, http_client: http_client.clone(), available_models: Vec::new(), fetch_models_task: None, settings: OpenRouterSettings::default(), _subscription: cx.observe_global::(|this: &mut State, cx| { let current_settings = &AllLanguageModelSettings::get_global(cx).open_router; let settings_changed = current_settings != &this.settings; if settings_changed { this.settings = current_settings.clone(); this.restart_fetch_models_task(cx); } cx.notify(); }), }); Self { http_client, state } } fn create_language_model(&self, model: open_router::Model) -> Arc { Arc::new(OpenRouterLanguageModel { id: LanguageModelId::from(model.id().to_string()), model, state: self.state.clone(), http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), }) } } impl LanguageModelProviderState for OpenRouterLanguageModelProvider { type ObservableEntity = State; fn observable_entity(&self) -> Option> { Some(self.state.clone()) } } impl LanguageModelProvider for OpenRouterLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { PROVIDER_NAME } fn icon(&self) -> IconName { IconName::AiOpenRouter } fn default_model(&self, _cx: &App) -> Option> { Some(self.create_language_model(open_router::Model::default())) } fn default_fast_model(&self, _cx: &App) -> Option> { Some(self.create_language_model(open_router::Model::default_fast())) } fn provided_models(&self, cx: &App) -> Vec> { let mut models_from_api = self.state.read(cx).available_models.clone(); let mut settings_models = Vec::new(); for model in &AllLanguageModelSettings::get_global(cx) .open_router .available_models { settings_models.push(open_router::Model { name: model.name.clone(), display_name: model.display_name.clone(), max_tokens: model.max_tokens, supports_tools: model.supports_tools, supports_images: model.supports_images, mode: model.mode.clone().unwrap_or_default().into(), }); } for settings_model in &settings_models { if let Some(pos) = models_from_api .iter() .position(|m| m.name == settings_model.name) { models_from_api[pos] = settings_model.clone(); } else { models_from_api.push(settings_model.clone()); } } models_from_api .into_iter() .map(|model| self.create_language_model(model)) .collect() } fn is_authenticated(&self, cx: &App) -> bool { self.state.read(cx).is_authenticated() } fn authenticate(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.authenticate(cx)) } fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView { cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) .into() } fn reset_credentials(&self, cx: &mut App) -> Task> { self.state.update(cx, |state, cx| state.reset_api_key(cx)) } } pub struct OpenRouterLanguageModel { id: LanguageModelId, model: open_router::Model, state: gpui::Entity, http_client: Arc, request_limiter: RateLimiter, } impl OpenRouterLanguageModel { fn stream_completion( &self, request: open_router::Request, cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { let settings = &AllLanguageModelSettings::get_global(cx).open_router; (state.api_key.clone(), settings.api_url.clone()) }) else { return futures::future::ready(Err(anyhow!( "App state dropped: Unable to read API key or API URL from the application state" ))) .boxed(); }; let future = self.request_limiter.stream(async move { let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?; let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; Ok(response) }); async move { Ok(future.await?.boxed()) }.boxed() } } impl LanguageModel for OpenRouterLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() } fn name(&self) -> LanguageModelName { LanguageModelName::from(self.model.display_name().to_string()) } fn provider_id(&self) -> LanguageModelProviderId { PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { PROVIDER_NAME } fn supports_tools(&self) -> bool { self.model.supports_tool_calls() } fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { let model_id = self.model.id().trim().to_lowercase(); if model_id.contains("gemini") || model_id.contains("grok-4") { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema } } fn telemetry_id(&self) -> String { format!("openrouter/{}", self.model.id()) } fn max_token_count(&self) -> u64 { self.model.max_token_count() } fn max_output_tokens(&self) -> Option { self.model.max_output_tokens() } fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto => true, LanguageModelToolChoice::Any => true, LanguageModelToolChoice::None => true, } } fn supports_images(&self) -> bool { self.model.supports_images.unwrap_or(false) } fn count_tokens( &self, request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { count_open_router_tokens(request, self.model.clone(), cx) } fn stream_completion( &self, request: LanguageModelRequest, cx: &AsyncApp, ) -> BoxFuture< 'static, Result< futures::stream::BoxStream< 'static, Result, >, LanguageModelCompletionError, >, > { let request = into_open_router(request, &self.model, self.max_output_tokens()); let completions = self.stream_completion(request, cx); async move { let mapper = OpenRouterEventMapper::new(); Ok(mapper.map_stream(completions.await?).boxed()) } .boxed() } } pub fn into_open_router( request: LanguageModelRequest, model: &Model, max_output_tokens: Option, ) -> open_router::Request { let mut messages = Vec::new(); for message in request.messages { for content in message.content { match content { MessageContent::Text(text) => add_message_content_part( open_router::MessagePart::Text { text }, message.role, &mut messages, ), MessageContent::Thinking { .. } => {} MessageContent::RedactedThinking(_) => {} MessageContent::Image(image) => { add_message_content_part( open_router::MessagePart::Image { image_url: image.to_base64_url(), }, message.role, &mut messages, ); } MessageContent::ToolUse(tool_use) => { let tool_call = open_router::ToolCall { id: tool_use.id.to_string(), content: open_router::ToolCallContent::Function { function: open_router::FunctionContent { name: tool_use.name.to_string(), arguments: serde_json::to_string(&tool_use.input) .unwrap_or_default(), }, }, }; if let Some(open_router::RequestMessage::Assistant { tool_calls, .. }) = messages.last_mut() { tool_calls.push(tool_call); } else { messages.push(open_router::RequestMessage::Assistant { content: None, tool_calls: vec![tool_call], }); } } MessageContent::ToolResult(tool_result) => { let content = match &tool_result.content { LanguageModelToolResultContent::Text(text) => { vec![open_router::MessagePart::Text { text: text.to_string(), }] } LanguageModelToolResultContent::Image(image) => { vec![open_router::MessagePart::Image { image_url: image.to_base64_url(), }] } }; messages.push(open_router::RequestMessage::Tool { content: content.into(), tool_call_id: tool_result.tool_use_id.to_string(), }); } } } } open_router::Request { model: model.id().into(), messages, stream: true, stop: request.stop, temperature: request.temperature.unwrap_or(0.4), max_tokens: max_output_tokens, parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() { Some(false) } else { None }, usage: open_router::RequestUsage { include: true }, reasoning: if request.thinking_allowed && let OpenRouterModelMode::Thinking { budget_tokens } = model.mode { Some(open_router::Reasoning { effort: None, max_tokens: budget_tokens, exclude: Some(false), enabled: Some(true), }) } else { None }, tools: request .tools .into_iter() .map(|tool| open_router::ToolDefinition::Function { function: open_router::FunctionDefinition { name: tool.name, description: Some(tool.description), parameters: Some(tool.input_schema), }, }) .collect(), tool_choice: request.tool_choice.map(|choice| match choice { LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto, LanguageModelToolChoice::Any => open_router::ToolChoice::Required, LanguageModelToolChoice::None => open_router::ToolChoice::None, }), } } fn add_message_content_part( new_part: open_router::MessagePart, role: Role, messages: &mut Vec, ) { match (role, messages.last_mut()) { (Role::User, Some(open_router::RequestMessage::User { content })) | (Role::System, Some(open_router::RequestMessage::System { content })) => { content.push_part(new_part); } ( Role::Assistant, Some(open_router::RequestMessage::Assistant { content: Some(content), .. }), ) => { content.push_part(new_part); } _ => { messages.push(match role { Role::User => open_router::RequestMessage::User { content: open_router::MessageContent::from(vec![new_part]), }, Role::Assistant => open_router::RequestMessage::Assistant { content: Some(open_router::MessageContent::from(vec![new_part])), tool_calls: Vec::new(), }, Role::System => open_router::RequestMessage::System { content: open_router::MessageContent::from(vec![new_part]), }, }); } } } pub struct OpenRouterEventMapper { tool_calls_by_index: HashMap, } impl OpenRouterEventMapper { pub fn new() -> Self { Self { tool_calls_by_index: HashMap::default(), } } pub fn map_stream( mut self, events: Pin>>>, ) -> impl Stream> { events.flat_map(move |event| { futures::stream::iter(match event { Ok(event) => self.map_event(event), Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], }) }) } pub fn map_event( &mut self, event: ResponseStreamEvent, ) -> Vec> { let Some(choice) = event.choices.first() else { return vec![Err(LanguageModelCompletionError::from(anyhow!( "Response contained no choices" )))]; }; let mut events = Vec::new(); if let Some(reasoning) = choice.delta.reasoning.clone() { events.push(Ok(LanguageModelCompletionEvent::Thinking { text: reasoning, signature: None, })); } if let Some(content) = choice.delta.content.clone() { // OpenRouter send empty content string with the reasoning content // This is a workaround for the OpenRouter API bug if !content.is_empty() { events.push(Ok(LanguageModelCompletionEvent::Text(content))); } } if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { for tool_call in tool_calls { let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); if let Some(tool_id) = tool_call.id.clone() { entry.id = tool_id; } if let Some(function) = tool_call.function.as_ref() { if let Some(name) = function.name.clone() { entry.name = name; } if let Some(arguments) = function.arguments.clone() { entry.arguments.push_str(&arguments); } } } } if let Some(usage) = event.usage { events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { input_tokens: usage.prompt_tokens, output_tokens: usage.completion_tokens, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, }))); } match choice.finish_reason.as_deref() { Some("stop") => { events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); } Some("tool_calls") => { events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { match serde_json::Value::from_str(&tool_call.arguments) { Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id: tool_call.id.clone().into(), name: tool_call.name.as_str().into(), is_input_complete: true, input, raw_input: tool_call.arguments.clone(), }, )), Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { id: tool_call.id.clone().into(), tool_name: tool_call.name.as_str().into(), raw_input: tool_call.arguments.clone().into(), json_parse_error: error.to_string(), }), } })); events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); } Some(stop_reason) => { log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",); events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); } None => {} } events } } #[derive(Default)] struct RawToolCall { id: String, name: String, arguments: String, } pub fn count_open_router_tokens( request: LanguageModelRequest, _model: open_router::Model, cx: &App, ) -> BoxFuture<'static, Result> { cx.background_spawn(async move { let messages = request .messages .into_iter() .map(|message| tiktoken_rs::ChatCompletionRequestMessage { role: match message.role { Role::User => "user".into(), Role::Assistant => "assistant".into(), Role::System => "system".into(), }, content: Some(message.string_contents()), name: None, function_call: None, }) .collect::>(); tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64) }) .boxed() } struct ConfigurationView { api_key_editor: Entity, state: gpui::Entity, load_credentials_task: Option>, } impl ConfigurationView { fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { let api_key_editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); editor .set_placeholder_text("sk_or_000000000000000000000000000000000000000000000000", cx); editor }); cx.observe(&state, |_, _, cx| { cx.notify(); }) .detach(); let load_credentials_task = Some(cx.spawn_in(window, { let state = state.clone(); async move |this, cx| { if let Some(task) = state .update(cx, |state, cx| state.authenticate(cx)) .log_err() { let _ = task.await; } this.update(cx, |this, cx| { this.load_credentials_task = None; cx.notify(); }) .log_err(); } })); Self { api_key_editor, state, load_credentials_task, } } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { let api_key = self.api_key_editor.read(cx).text(cx); if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state .update(cx, |state, cx| state.set_api_key(api_key, cx))? .await }) .detach_and_log_err(cx); cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { self.api_key_editor .update(cx, |editor, cx| editor.set_text("", window, cx)); let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state.update(cx, |state, cx| state.reset_api_key(cx))?.await }) .detach_and_log_err(cx); cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { let settings = ThemeSettings::get_global(cx); let text_style = TextStyle { color: cx.theme().colors().text, font_family: settings.ui_font.family.clone(), font_features: settings.ui_font.features.clone(), font_fallbacks: settings.ui_font.fallbacks.clone(), font_size: rems(0.875).into(), font_weight: settings.ui_font.weight, font_style: FontStyle::Normal, line_height: relative(1.3), white_space: WhiteSpace::Normal, ..Default::default() }; EditorElement::new( &self.api_key_editor, EditorStyle { background: cx.theme().colors().editor_background, local_player: cx.theme().players().local(), text: text_style, ..Default::default() }, ) } fn should_render_editor(&self, cx: &mut Context) -> bool { !self.state.read(cx).is_authenticated() } } impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { let env_var_set = self.state.read(cx).api_key_from_env; if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() } else if self.should_render_editor(cx) { v_flex() .id("open-router-container") .size_full() .on_action(cx.listener(Self::save_api_key)) .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:")) .child( List::new() .child(InstructionListItem::new( "Create an API key by visiting", Some("OpenRouter's console"), Some("https://openrouter.ai/keys"), )) .child(InstructionListItem::text_only( "Ensure your OpenRouter account has credits", )) .child(InstructionListItem::text_only( "Paste your API key below and hit enter to start using the assistant", )), ) .child( h_flex() .id("api-key-editor") .w_full() .my_2() .px_2() .py_1() .bg(cx.theme().colors().editor_background) .border_1() .border_color(cx.theme().colors().border) .rounded_sm() .child(self.render_api_key_editor(cx)), ) .child( Label::new( format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) .into_any() } else { h_flex() .mt_1() .p_1() .justify_between() .rounded_md() .border_1() .border_color(cx.theme().colors().border) .bg(cx.theme().colors().background) .child( h_flex() .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.") } else { "API key configured.".to_string() })), ) .child( Button::new("reset-key", "Reset Key") .label_size(LabelSize::Small) .icon(Some(IconName::Trash)) .icon_size(IconSize::Small) .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) .into_any() } } }