diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index c454f4884b..bc7052252a 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,17 +1,29 @@ use std::str::FromStr; use std::sync::Arc; -use anyhow::anyhow; +use anyhow::{anyhow, Context}; use axum::{extract, routing::post, Extension, Json, Router}; use collections::HashSet; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; -use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId}; +use stripe::{ + BillingPortalSession, CheckoutSession, CreateBillingPortalSession, + CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, + CreateBillingPortalSessionFlowDataAfterCompletionRedirect, + CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems, + CustomerId, +}; +use crate::db::BillingSubscriptionId; use crate::{AppState, Error, Result}; pub fn router() -> Router { - Router::new().route("/billing/subscriptions", post(create_billing_subscription)) + Router::new() + .route("/billing/subscriptions", post(create_billing_subscription)) + .route( + "/billing/subscriptions/manage", + post(manage_billing_subscription), + ) } #[derive(Debug, Deserialize)] @@ -61,7 +73,7 @@ async fn create_billing_subscription( distinct_customer_ids .into_iter() .next() - .map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err))) + .map(|id| CustomerId::from_str(id).context("failed to parse customer ID")) .transpose() }?; @@ -86,3 +98,96 @@ async fn create_billing_subscription( .ok_or_else(|| anyhow!("no checkout session URL"))?, })) } + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +enum ManageSubscriptionIntent { + /// The user intends to cancel their subscription. + Cancel, +} + +#[derive(Debug, Deserialize)] +struct ManageBillingSubscriptionBody { + github_user_id: i32, + intent: ManageSubscriptionIntent, + /// The ID of the subscription to manage. + /// + /// If not provided, we will try to use the active subscription (if there is only one). + subscription_id: Option, +} + +#[derive(Debug, Serialize)] +struct ManageBillingSubscriptionResponse { + billing_portal_session_url: String, +} + +/// Initiates a Stripe customer portal session for managing a billing subscription. +async fn manage_billing_subscription( + Extension(app): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let user = app + .db + .get_user_by_github_user_id(body.github_user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let Some(stripe_client) = app.stripe_client.clone() else { + log::error!("failed to retrieve Stripe client"); + Err(Error::Http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + + let subscription = if let Some(subscription_id) = body.subscription_id { + app.db + .get_billing_subscription_by_id(subscription_id) + .await? + .ok_or_else(|| anyhow!("subscription not found"))? + } else { + // If no subscription ID was provided, try to find the only active subscription ID. + let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?; + if subscriptions.len() > 1 { + Err(anyhow!("user has multiple active subscriptions"))?; + } + + subscriptions + .into_iter() + .next() + .ok_or_else(|| anyhow!("user has no active subscriptions"))? + }; + + let customer_id = CustomerId::from_str(&subscription.stripe_customer_id) + .context("failed to parse customer ID")?; + + let flow = match body.intent { + ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData { + type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel, + after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { + type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect, + redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect { + return_url: "https://zed.dev/billing".into(), + }), + ..Default::default() + }), + subscription_cancel: Some( + stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel { + subscription: subscription.stripe_subscription_id, + retention: None, + }, + ), + ..Default::default() + }, + }; + + let mut params = CreateBillingPortalSession::new(customer_id); + params.flow_data = Some(flow); + params.return_url = Some("https://zed.dev/billing"); + + let session = BillingPortalSession::create(&stripe_client, params).await?; + + Ok(Json(ManageBillingSubscriptionResponse { + billing_portal_session_url: session.url, + })) +} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index fcacf7ee22..42d1a4f180 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -32,6 +32,19 @@ impl Database { .await } + /// Returns the billing subscription with the specified ID. + pub async fn get_billing_subscription_by_id( + &self, + id: BillingSubscriptionId, + ) -> Result> { + self.transaction(|tx| async move { + Ok(billing_subscription::Entity::find_by_id(id) + .one(&*tx) + .await?) + }) + .await + } + /// Returns all of the billing subscriptions for the user with the specified ID. /// /// Note that this returns the subscriptions regardless of their status. @@ -44,6 +57,7 @@ impl Database { self.transaction(|tx| async move { let subscriptions = billing_subscription::Entity::find() .filter(billing_subscription::Column::UserId.eq(user_id)) + .order_by_asc(billing_subscription::Column::Id) .all(&*tx) .await?; @@ -65,6 +79,7 @@ impl Database { .eq(StripeSubscriptionStatus::Active), ), ) + .order_by_asc(billing_subscription::Column::Id) .all(&*tx) .await?;