collab: Add endpoint for managing a billing subscription (#15455)

This PR adds a new `POST /billing/subscriptions/manage` endpoint that
can be used to manage a billing subscription.

The endpoint accepts a `github_user_id` to identify the user, as well as
an optional `subscription_id` for managing a specific subscription. If
`subscription_id` is not provided, it try and use the active
subscription, if there is only one.

Right now the endpoint only supports cancelling an active subscription.
This is done by passing `"intent": "cancel"` in the request body.

The endpoint will return the URL to a Stripe customer portal session,
which the caller can redirect the user to.

Here's an example of how to call it:

```sh
curl -X POST "http://localhost:8080/billing/subscriptions/manage" \
     -H "Authorization: <ADMIN_TOKEN>" \
     -H "Content-Type: application/json" \
     -d '{"github_user_id": 12345, "intent": "cancel"}'
```

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-07-29 20:05:17 -04:00 committed by GitHub
parent 4d8ad7ae42
commit 66121fa0e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 124 additions and 4 deletions

View file

@ -1,17 +1,29 @@
use std::str::FromStr;
use std::sync::Arc;
use anyhow::anyhow;
use anyhow::{anyhow, Context};
use axum::{extract, routing::post, Extension, Json, Router};
use collections::HashSet;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId};
use stripe::{
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
CustomerId,
};
use crate::db::BillingSubscriptionId;
use crate::{AppState, Error, Result};
pub fn router() -> Router {
Router::new().route("/billing/subscriptions", post(create_billing_subscription))
Router::new()
.route("/billing/subscriptions", post(create_billing_subscription))
.route(
"/billing/subscriptions/manage",
post(manage_billing_subscription),
)
}
#[derive(Debug, Deserialize)]
@ -61,7 +73,7 @@ async fn create_billing_subscription(
distinct_customer_ids
.into_iter()
.next()
.map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err)))
.map(|id| CustomerId::from_str(id).context("failed to parse customer ID"))
.transpose()
}?;
@ -86,3 +98,96 @@ async fn create_billing_subscription(
.ok_or_else(|| anyhow!("no checkout session URL"))?,
}))
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ManageSubscriptionIntent {
/// The user intends to cancel their subscription.
Cancel,
}
#[derive(Debug, Deserialize)]
struct ManageBillingSubscriptionBody {
github_user_id: i32,
intent: ManageSubscriptionIntent,
/// The ID of the subscription to manage.
///
/// If not provided, we will try to use the active subscription (if there is only one).
subscription_id: Option<BillingSubscriptionId>,
}
#[derive(Debug, Serialize)]
struct ManageBillingSubscriptionResponse {
billing_portal_session_url: String,
}
/// Initiates a Stripe customer portal session for managing a billing subscription.
async fn manage_billing_subscription(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<ManageBillingSubscriptionBody>,
) -> Result<Json<ManageBillingSubscriptionResponse>> {
let user = app
.db
.get_user_by_github_user_id(body.github_user_id)
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::Http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let subscription = if let Some(subscription_id) = body.subscription_id {
app.db
.get_billing_subscription_by_id(subscription_id)
.await?
.ok_or_else(|| anyhow!("subscription not found"))?
} else {
// If no subscription ID was provided, try to find the only active subscription ID.
let subscriptions = app.db.get_active_billing_subscriptions(user.id).await?;
if subscriptions.len() > 1 {
Err(anyhow!("user has multiple active subscriptions"))?;
}
subscriptions
.into_iter()
.next()
.ok_or_else(|| anyhow!("user has no active subscriptions"))?
};
let customer_id = CustomerId::from_str(&subscription.stripe_customer_id)
.context("failed to parse customer ID")?;
let flow = match body.intent {
ManageSubscriptionIntent::Cancel => CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
return_url: "https://zed.dev/billing".into(),
}),
..Default::default()
}),
subscription_cancel: Some(
stripe::CreateBillingPortalSessionFlowDataSubscriptionCancel {
subscription: subscription.stripe_subscription_id,
retention: None,
},
),
..Default::default()
},
};
let mut params = CreateBillingPortalSession::new(customer_id);
params.flow_data = Some(flow);
params.return_url = Some("https://zed.dev/billing");
let session = BillingPortalSession::create(&stripe_client, params).await?;
Ok(Json(ManageBillingSubscriptionResponse {
billing_portal_session_url: session.url,
}))
}

View file

@ -32,6 +32,19 @@ impl Database {
.await
}
/// Returns the billing subscription with the specified ID.
pub async fn get_billing_subscription_by_id(
&self,
id: BillingSubscriptionId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?)
})
.await
}
/// Returns all of the billing subscriptions for the user with the specified ID.
///
/// Note that this returns the subscriptions regardless of their status.
@ -44,6 +57,7 @@ impl Database {
self.transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.filter(billing_subscription::Column::UserId.eq(user_id))
.order_by_asc(billing_subscription::Column::Id)
.all(&*tx)
.await?;
@ -65,6 +79,7 @@ impl Database {
.eq(StripeSubscriptionStatus::Active),
),
)
.order_by_asc(billing_subscription::Column::Id)
.all(&*tx)
.await?;