diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index b19eaecd63..3aa08f776a 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -422,7 +422,8 @@ CREATE TABLE IF NOT EXISTS billing_subscriptions ( created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, billing_customer_id INTEGER NOT NULL REFERENCES billing_customers(id), stripe_subscription_id TEXT NOT NULL, - stripe_subscription_status TEXT NOT NULL + stripe_subscription_status TEXT NOT NULL, + last_stripe_event_id TEXT ); CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id); @@ -432,7 +433,8 @@ CREATE TABLE IF NOT EXISTS billing_customers ( id INTEGER PRIMARY KEY AUTOINCREMENT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, user_id INTEGER NOT NULL REFERENCES users(id), - stripe_customer_id TEXT NOT NULL + stripe_customer_id TEXT NOT NULL, + last_stripe_event_id TEXT ); CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id); diff --git a/crates/collab/migrations/20240730122654_add_last_stripe_event_id.sql b/crates/collab/migrations/20240730122654_add_last_stripe_event_id.sql new file mode 100644 index 0000000000..477eadd742 --- /dev/null +++ b/crates/collab/migrations/20240730122654_add_last_stripe_event_id.sql @@ -0,0 +1,2 @@ +ALTER TABLE billing_customers ADD COLUMN last_stripe_event_id TEXT; +ALTER TABLE billing_subscriptions ADD COLUMN last_stripe_event_id TEXT; diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index e6860942d8..17ef748143 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -5,13 +5,14 @@ use std::time::Duration; use anyhow::{anyhow, bail, Context}; use axum::{extract, routing::post, Extension, Json, Router}; use reqwest::StatusCode; +use sea_orm::ActiveValue; use serde::{Deserialize, Serialize}; use stripe::{ BillingPortalSession, CheckoutSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, CreateBillingPortalSessionFlowDataAfterCompletionRedirect, CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents, + CreateCustomer, Customer, CustomerId, EventId, EventObject, EventType, Expandable, ListEvents, SubscriptionStatus, }; use util::ResultExt; @@ -19,7 +20,7 @@ use util::ResultExt; use crate::db::billing_subscription::StripeSubscriptionStatus; use crate::db::{ billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, - CreateBillingSubscriptionParams, + CreateBillingSubscriptionParams, UpdateBillingCustomerParams, UpdateBillingSubscriptionParams, }; use crate::{AppState, Error, Result}; @@ -231,6 +232,7 @@ async fn poll_stripe_events( ) -> anyhow::Result<()> { let event_types = [ EventType::CustomerCreated.to_string(), + EventType::CustomerUpdated.to_string(), EventType::CustomerSubscriptionCreated.to_string(), EventType::CustomerSubscriptionUpdated.to_string(), EventType::CustomerSubscriptionPaused.to_string(), @@ -255,7 +257,7 @@ async fn poll_stripe_events( let events = stripe::Event::list(stripe_client, ¶ms).await?; for event in events.data { match event.type_ { - EventType::CustomerCreated => { + EventType::CustomerCreated | EventType::CustomerUpdated => { handle_customer_event(app, stripe_client, event) .await .log_err(); @@ -283,15 +285,59 @@ async fn poll_stripe_events( async fn handle_customer_event( app: &Arc, - stripe_client: &stripe::Client, + _stripe_client: &stripe::Client, event: stripe::Event, ) -> anyhow::Result<()> { let EventObject::Customer(customer) = event.data.object else { bail!("unexpected event payload for {}", event.id); }; - find_or_create_billing_customer(app, stripe_client, Expandable::Object(Box::new(customer))) - .await?; + log::info!("handling Stripe {} event: {}", event.type_, event.id); + + let Some(email) = customer.email else { + log::info!("Stripe customer has no email: skipping"); + return Ok(()); + }; + + let Some(user) = app.db.get_user_by_email(&email).await? else { + log::info!("no user found for email: skipping"); + return Ok(()); + }; + + if let Some(existing_customer) = app + .db + .get_billing_customer_by_stripe_customer_id(&customer.id) + .await? + { + if should_ignore_event(&event.id, existing_customer.last_stripe_event_id.as_deref()) { + log::info!( + "ignoring Stripe event {} based on last seen event ID", + event.id + ); + return Ok(()); + } + + app.db + .update_billing_customer( + existing_customer.id, + &UpdateBillingCustomerParams { + // For now we just update the last event ID for the customer + // and leave the rest of the information as-is, as it is not + // likely to change. + last_stripe_event_id: ActiveValue::set(Some(event.id.to_string())), + ..Default::default() + }, + ) + .await?; + } else { + app.db + .create_billing_customer(&CreateBillingCustomerParams { + user_id: user.id, + stripe_customer_id: customer.id.to_string(), + last_stripe_event_id: Some(event.id.to_string()), + }) + .await?; + } Ok(()) } @@ -305,18 +351,60 @@ async fn handle_customer_subscription_event( bail!("unexpected event payload for {}", event.id); }; - let billing_customer = - find_or_create_billing_customer(app, stripe_client, subscription.customer) - .await? - .ok_or_else(|| anyhow!("billing customer not found"))?; + log::info!("handling Stripe {} event: {}", event.type_, event.id); - app.db - .upsert_billing_subscription_by_stripe_subscription_id(&CreateBillingSubscriptionParams { - billing_customer_id: billing_customer.id, - stripe_subscription_id: subscription.id.to_string(), - stripe_subscription_status: subscription.status.into(), - }) - .await?; + let billing_customer = find_or_create_billing_customer( + app, + stripe_client, + // Even though we're handling a subscription event, we can still set + // the ID as the last seen event ID on the customer in the event that + // we have to create it. + // + // This is done to avoid any potential rollback in the customer's values + // if we then see an older event that pertains to the customer. + &event.id, + subscription.customer, + ) + .await? + .ok_or_else(|| anyhow!("billing customer not found"))?; + + if let Some(existing_subscription) = app + .db + .get_billing_subscription_by_stripe_subscription_id(&subscription.id) + .await? + { + if should_ignore_event( + &event.id, + existing_subscription.last_stripe_event_id.as_deref(), + ) { + log::info!( + "ignoring Stripe event {} based on last seen event ID", + event.id + ); + return Ok(()); + } + + app.db + .update_billing_subscription( + existing_subscription.id, + &UpdateBillingSubscriptionParams { + billing_customer_id: ActiveValue::set(billing_customer.id), + stripe_subscription_id: ActiveValue::set(subscription.id.to_string()), + stripe_subscription_status: ActiveValue::set(subscription.status.into()), + last_stripe_event_id: ActiveValue::set(Some(event.id.to_string())), + }, + ) + .await?; + } else { + app.db + .create_billing_subscription(&CreateBillingSubscriptionParams { + billing_customer_id: billing_customer.id, + stripe_subscription_id: subscription.id.to_string(), + stripe_subscription_status: subscription.status.into(), + last_stripe_event_id: Some(event.id.to_string()), + }) + .await?; + } Ok(()) } @@ -340,6 +428,7 @@ impl From for StripeSubscriptionStatus { async fn find_or_create_billing_customer( app: &Arc, stripe_client: &stripe::Client, + event_id: &EventId, customer_or_id: Expandable, ) -> anyhow::Result> { let customer_id = match &customer_or_id { @@ -377,8 +466,70 @@ async fn find_or_create_billing_customer( .create_billing_customer(&CreateBillingCustomerParams { user_id: user.id, stripe_customer_id: customer.id.to_string(), + last_stripe_event_id: Some(event_id.to_string()), }) .await?; Ok(Some(billing_customer)) } + +/// Returns whether an [`Event`] should be ignored, based on its ID and the last +/// seen event ID for this object. +#[inline] +fn should_ignore_event(event_id: &EventId, last_event_id: Option<&str>) -> bool { + !should_apply_event(event_id, last_event_id) +} + +/// Returns whether an [`Event`] should be applied, based on its ID and the last +/// seen event ID for this object. +fn should_apply_event(event_id: &EventId, last_event_id: Option<&str>) -> bool { + let Some(last_event_id) = last_event_id else { + return true; + }; + + event_id.as_str() < last_event_id +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_should_apply_event() { + let subscription_created_event = EventId::from_str("evt_1Pi5s9RxOf7d5PNafuZSGsmh").unwrap(); + let subscription_updated_event = EventId::from_str("evt_1Pi5s9RxOf7d5PNa5UZLSsto").unwrap(); + + assert_eq!( + should_apply_event( + &subscription_created_event, + Some(subscription_created_event.as_str()) + ), + false, + "Events should not be applied when the IDs are the same." + ); + + assert_eq!( + should_apply_event( + &subscription_created_event, + Some(subscription_updated_event.as_str()) + ), + false, + "Events should not be applied when the last event ID is newer than the event ID." + ); + + assert_eq!( + should_apply_event(&subscription_created_event, None), + true, + "Events should be applied when we don't have a last event ID." + ); + + assert_eq!( + should_apply_event( + &subscription_updated_event, + Some(subscription_created_event.as_str()) + ), + true, + "Events should be applied when the event ID is newer than the last event ID." + ); + } +} diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index b34de6b326..ef3b0a4903 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -45,8 +45,10 @@ use tokio::sync::{Mutex, OwnedMutexGuard}; pub use tests::TestDb; pub use ids::*; -pub use queries::billing_customers::CreateBillingCustomerParams; -pub use queries::billing_subscriptions::CreateBillingSubscriptionParams; +pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams}; +pub use queries::billing_subscriptions::{ + CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams, +}; pub use queries::contributors::ContributorSelector; pub use sea_orm::ConnectOptions; pub use tables::user::Model as User; diff --git a/crates/collab/src/db/queries/billing_customers.rs b/crates/collab/src/db/queries/billing_customers.rs index fd6dc8e7a1..1d9d287a60 100644 --- a/crates/collab/src/db/queries/billing_customers.rs +++ b/crates/collab/src/db/queries/billing_customers.rs @@ -4,6 +4,14 @@ use super::*; pub struct CreateBillingCustomerParams { pub user_id: UserId, pub stripe_customer_id: String, + pub last_stripe_event_id: Option, +} + +#[derive(Debug, Default)] +pub struct UpdateBillingCustomerParams { + pub user_id: ActiveValue, + pub stripe_customer_id: ActiveValue, + pub last_stripe_event_id: ActiveValue>, } impl Database { @@ -26,6 +34,28 @@ impl Database { .await } + /// Updates the specified billing customer. + pub async fn update_billing_customer( + &self, + id: BillingCustomerId, + params: &UpdateBillingCustomerParams, + ) -> Result<()> { + self.transaction(|tx| async move { + billing_customer::Entity::update(billing_customer::ActiveModel { + id: ActiveValue::set(id), + user_id: params.user_id.clone(), + stripe_customer_id: params.stripe_customer_id.clone(), + last_stripe_event_id: params.last_stripe_event_id.clone(), + ..Default::default() + }) + .exec(&*tx) + .await?; + + Ok(()) + }) + .await + } + /// Returns the billing customer for the user with the specified ID. pub async fn get_billing_customer_by_user_id( &self, diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 85e2766a74..e6af85672f 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -1,3 +1,5 @@ +use sea_orm::IntoActiveValue; + use crate::db::billing_subscription::StripeSubscriptionStatus; use super::*; @@ -7,6 +9,15 @@ pub struct CreateBillingSubscriptionParams { pub billing_customer_id: BillingCustomerId, pub stripe_subscription_id: String, pub stripe_subscription_status: StripeSubscriptionStatus, + pub last_stripe_event_id: Option, +} + +#[derive(Debug, Default)] +pub struct UpdateBillingSubscriptionParams { + pub billing_customer_id: ActiveValue, + pub stripe_subscription_id: ActiveValue, + pub stripe_subscription_status: ActiveValue, + pub last_stripe_event_id: ActiveValue>, } impl Database { @@ -20,6 +31,7 @@ impl Database { billing_customer_id: ActiveValue::set(params.billing_customer_id), stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status), + last_stripe_event_id: params.last_stripe_event_id.clone().into_active_value(), ..Default::default() }) .exec_without_returning(&*tx) @@ -30,24 +42,22 @@ impl Database { .await } - /// Upserts the billing subscription by its Stripe subscription ID. - pub async fn upsert_billing_subscription_by_stripe_subscription_id( + /// Updates the specified billing subscription. + pub async fn update_billing_subscription( &self, - params: &CreateBillingSubscriptionParams, + id: BillingSubscriptionId, + params: &UpdateBillingSubscriptionParams, ) -> Result<()> { self.transaction(|tx| async move { - billing_subscription::Entity::insert(billing_subscription::ActiveModel { - billing_customer_id: ActiveValue::set(params.billing_customer_id), - stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), - stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status), + billing_subscription::Entity::update(billing_subscription::ActiveModel { + id: ActiveValue::set(id), + billing_customer_id: params.billing_customer_id.clone(), + stripe_subscription_id: params.stripe_subscription_id.clone(), + stripe_subscription_status: params.stripe_subscription_status.clone(), + last_stripe_event_id: params.last_stripe_event_id.clone(), ..Default::default() }) - .on_conflict( - OnConflict::columns([billing_subscription::Column::StripeSubscriptionId]) - .update_columns([billing_subscription::Column::StripeSubscriptionStatus]) - .to_owned(), - ) - .exec_with_returning(&*tx) + .exec(&*tx) .await?; Ok(()) @@ -68,6 +78,22 @@ impl Database { .await } + /// Returns the billing subscription with the specified Stripe subscription ID. + pub async fn get_billing_subscription_by_stripe_subscription_id( + &self, + stripe_subscription_id: &str, + ) -> Result> { + self.transaction(|tx| async move { + Ok(billing_subscription::Entity::find() + .filter( + billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id), + ) + .one(&*tx) + .await?) + }) + .await + } + /// Returns all of the billing subscriptions for the user with the specified ID. /// /// Note that this returns the subscriptions regardless of their status. diff --git a/crates/collab/src/db/tables/billing_customer.rs b/crates/collab/src/db/tables/billing_customer.rs index 258a7e0c0c..2e186068b2 100644 --- a/crates/collab/src/db/tables/billing_customer.rs +++ b/crates/collab/src/db/tables/billing_customer.rs @@ -9,6 +9,7 @@ pub struct Model { pub id: BillingCustomerId, pub user_id: UserId, pub stripe_customer_id: String, + pub last_stripe_event_id: Option, pub created_at: DateTime, } diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 4cbde6bec0..3911a094ad 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -10,6 +10,7 @@ pub struct Model { pub billing_customer_id: BillingCustomerId, pub stripe_subscription_id: String, pub stripe_subscription_status: StripeSubscriptionStatus, + pub last_stripe_event_id: Option, pub created_at: DateTime, } diff --git a/crates/collab/src/db/tests/billing_subscription_tests.rs b/crates/collab/src/db/tests/billing_subscription_tests.rs index 19f5463ac2..182a4a9cf7 100644 --- a/crates/collab/src/db/tests/billing_subscription_tests.rs +++ b/crates/collab/src/db/tests/billing_subscription_tests.rs @@ -29,6 +29,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { .create_billing_customer(&CreateBillingCustomerParams { user_id, stripe_customer_id: "cus_active_user".into(), + last_stripe_event_id: None, }) .await .unwrap(); @@ -38,6 +39,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { billing_customer_id: customer.id, stripe_subscription_id: "sub_active_user".into(), stripe_subscription_status: StripeSubscriptionStatus::Active, + last_stripe_event_id: None, }) .await .unwrap(); @@ -63,6 +65,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { .create_billing_customer(&CreateBillingCustomerParams { user_id, stripe_customer_id: "cus_past_due_user".into(), + last_stripe_event_id: None, }) .await .unwrap(); @@ -72,6 +75,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { billing_customer_id: customer.id, stripe_subscription_id: "sub_past_due_user".into(), stripe_subscription_status: StripeSubscriptionStatus::PastDue, + last_stripe_event_id: None, }) .await .unwrap();