diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 46bcad23e8..70b9585736 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -61,6 +61,10 @@ pub fn router() -> Router { "/billing/subscriptions/migrate", post(migrate_to_new_billing), ) + .route( + "/billing/subscriptions/sync", + post(sync_billing_subscription), + ) .route("/billing/monthly_spend", get(get_monthly_spend)) .route("/billing/usage", get(get_current_usage)) } @@ -737,6 +741,73 @@ async fn migrate_to_new_billing( })) } +#[derive(Debug, Deserialize)] +struct SyncBillingSubscriptionBody { + github_user_id: i32, +} + +#[derive(Debug, Serialize)] +struct SyncBillingSubscriptionResponse { + stripe_customer_id: String, +} + +async fn sync_billing_subscription( + Extension(app): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let Some(stripe_client) = app.stripe_client.clone() else { + log::error!("failed to retrieve Stripe client"); + Err(Error::http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + + let user = app + .db + .get_user_by_github_user_id(body.github_user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let billing_customer = app + .db + .get_billing_customer_by_user_id(user.id) + .await? + .ok_or_else(|| anyhow!("billing customer not found"))?; + let stripe_customer_id = billing_customer + .stripe_customer_id + .parse::() + .context("failed to parse Stripe customer ID from database")?; + + let subscriptions = Subscription::list( + &stripe_client, + &stripe::ListSubscriptions { + customer: Some(stripe_customer_id), + // Sync all non-canceled subscriptions. + status: None, + ..Default::default() + }, + ) + .await?; + + for subscription in subscriptions.data { + let subscription_id = subscription.id.clone(); + + sync_subscription(&app, &stripe_client, subscription) + .await + .with_context(|| { + format!( + "failed to sync subscription {subscription_id} for user {}", + user.id, + ) + })?; + } + + Ok(Json(SyncBillingSubscriptionResponse { + stripe_customer_id: billing_customer.stripe_customer_id.clone(), + })) +} + /// The amount of time we wait in between each poll of Stripe events. /// /// This value should strike a balance between: @@ -979,18 +1050,11 @@ async fn handle_customer_event( Ok(()) } -async fn handle_customer_subscription_event( +async fn sync_subscription( app: &Arc, - rpc_server: &Arc, stripe_client: &stripe::Client, - event: stripe::Event, -) -> anyhow::Result<()> { - let EventObject::Subscription(subscription) = event.data.object else { - bail!("unexpected event payload for {}", event.id); - }; - - log::info!("handling Stripe {} event: {}", event.type_, event.id); - + subscription: stripe::Subscription, +) -> anyhow::Result { let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing { stripe_billing .determine_subscription_kind(&subscription) @@ -1102,7 +1166,7 @@ async fn handle_customer_subscription_event( user_id = billing_customer.user_id, subscription_id = subscription.id ); - return Ok(()); + return Ok(billing_customer); } app.db @@ -1121,6 +1185,23 @@ async fn handle_customer_subscription_event( .await?; } + Ok(billing_customer) +} + +async fn handle_customer_subscription_event( + app: &Arc, + rpc_server: &Arc, + stripe_client: &stripe::Client, + event: stripe::Event, +) -> anyhow::Result<()> { + let EventObject::Subscription(subscription) = event.data.object else { + bail!("unexpected event payload for {}", event.id); + }; + + log::info!("handling Stripe {} event: {}", event.type_, event.id); + + let billing_customer = sync_subscription(app, stripe_client, subscription).await?; + // When the user's subscription changes, push down any changes to their plan. rpc_server .update_plan_for_user(billing_customer.user_id)