collab: Lay groundwork for reconciling with Stripe using the events API (#15459)
This PR lays the initial groundwork for using the Stripe events API to reconcile the data in our system with what's in Stripe. We're using the events API over webhooks so that we don't need to stand up the associated infrastructure needed to handle webhooks effectively (namely an asynchronous job queue). Since we haven't configured the Stripe API keys yet, we won't actually spawn the reconciliation background task yet, so this is currently a no-op. Release Notes: - N/A
This commit is contained in:
parent
28c14cdee4
commit
d93891ba63
7 changed files with 257 additions and 4 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -840,6 +840,7 @@ version = "0.37.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2f14b5943a52cf051bbbbb68538e93a69d1e291934174121e769f4b181113f5"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"futures-util",
|
||||
"http-types",
|
||||
"hyper",
|
||||
|
|
14
Cargo.toml
14
Cargo.toml
|
@ -309,7 +309,6 @@ async-dispatcher = "0.1"
|
|||
async-fs = "1.6"
|
||||
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" }
|
||||
async-recursion = "1.0.0"
|
||||
async-stripe = { version = "0.37", default-features = false, features = ["runtime-tokio-hyper-rustls", "billing", "checkout"] }
|
||||
async-tar = "0.4.2"
|
||||
async-trait = "0.1"
|
||||
async-tungstenite = "0.23"
|
||||
|
@ -461,6 +460,19 @@ wasmtime-wasi = "21.0.1"
|
|||
which = "6.0.0"
|
||||
wit-component = "0.201"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
version = "0.37"
|
||||
default-features = false
|
||||
features = [
|
||||
"runtime-tokio-hyper-rustls",
|
||||
"billing",
|
||||
"checkout",
|
||||
"events",
|
||||
# The features below are only enabled to get the `events` feature to build.
|
||||
"chrono",
|
||||
"connect",
|
||||
]
|
||||
|
||||
[workspace.dependencies.windows]
|
||||
version = "0.58"
|
||||
features = [
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use axum::{extract, routing::post, Extension, Json, Router};
|
||||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -10,10 +11,16 @@ use stripe::{
|
|||
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
||||
CreateCustomer, Customer, CustomerId,
|
||||
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
|
||||
SubscriptionStatus,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::db::BillingSubscriptionId;
|
||||
use crate::db::billing_subscription::StripeSubscriptionStatus;
|
||||
use crate::db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams,
|
||||
};
|
||||
use crate::{AppState, Error, Result};
|
||||
|
||||
pub fn router() -> Router {
|
||||
|
@ -194,3 +201,184 @@ async fn manage_billing_subscription(
|
|||
billing_portal_session_url: session.url,
|
||||
}))
|
||||
}
|
||||
|
||||
const POLL_EVENTS_INTERVAL: Duration = Duration::from_secs(5 * 60);
|
||||
|
||||
/// Polls the Stripe events API periodically to reconcile the records in our
|
||||
/// database with the data in Stripe.
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
};
|
||||
|
||||
let executor = app.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
poll_stripe_events(&app, &stripe_client).await.log_err();
|
||||
|
||||
executor.sleep(POLL_EVENTS_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn poll_stripe_events(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &stripe::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
let event_types = [
|
||||
EventType::CustomerCreated.to_string(),
|
||||
EventType::CustomerSubscriptionCreated.to_string(),
|
||||
EventType::CustomerSubscriptionUpdated.to_string(),
|
||||
EventType::CustomerSubscriptionPaused.to_string(),
|
||||
EventType::CustomerSubscriptionResumed.to_string(),
|
||||
EventType::CustomerSubscriptionDeleted.to_string(),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|event_type| {
|
||||
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
|
||||
// so we need to unquote it.
|
||||
event_type.trim_matches('"').to_string()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
loop {
|
||||
log::info!("retrieving events from Stripe: {}", event_types.join(", "));
|
||||
|
||||
let mut params = ListEvents::new();
|
||||
params.types = Some(event_types.clone());
|
||||
params.limit = Some(100);
|
||||
|
||||
let events = stripe::Event::list(stripe_client, ¶ms).await?;
|
||||
for event in events.data {
|
||||
match event.type_ {
|
||||
EventType::CustomerCreated => {
|
||||
handle_customer_event(app, stripe_client, event)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
EventType::CustomerSubscriptionCreated
|
||||
| EventType::CustomerSubscriptionUpdated
|
||||
| EventType::CustomerSubscriptionPaused
|
||||
| EventType::CustomerSubscriptionResumed
|
||||
| EventType::CustomerSubscriptionDeleted => {
|
||||
handle_customer_subscription_event(app, stripe_client, event)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if !events.has_more {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_customer_event(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &stripe::Client,
|
||||
event: stripe::Event,
|
||||
) -> anyhow::Result<()> {
|
||||
let EventObject::Customer(customer) = event.data.object else {
|
||||
bail!("unexpected event payload for {}", event.id);
|
||||
};
|
||||
|
||||
find_or_create_billing_customer(app, stripe_client, Expandable::Object(Box::new(customer)))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_customer_subscription_event(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &stripe::Client,
|
||||
event: stripe::Event,
|
||||
) -> anyhow::Result<()> {
|
||||
let EventObject::Subscription(subscription) = event.data.object else {
|
||||
bail!("unexpected event payload for {}", event.id);
|
||||
};
|
||||
|
||||
let billing_customer =
|
||||
find_or_create_billing_customer(app, stripe_client, subscription.customer)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("billing customer not found"))?;
|
||||
|
||||
app.db
|
||||
.upsert_billing_subscription_by_stripe_subscription_id(&CreateBillingSubscriptionParams {
|
||||
billing_customer_id: billing_customer.id,
|
||||
stripe_subscription_id: subscription.id.to_string(),
|
||||
stripe_subscription_status: subscription.status.into(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
|
||||
fn from(value: SubscriptionStatus) -> Self {
|
||||
match value {
|
||||
SubscriptionStatus::Incomplete => Self::Incomplete,
|
||||
SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
|
||||
SubscriptionStatus::Trialing => Self::Trialing,
|
||||
SubscriptionStatus::Active => Self::Active,
|
||||
SubscriptionStatus::PastDue => Self::PastDue,
|
||||
SubscriptionStatus::Canceled => Self::Canceled,
|
||||
SubscriptionStatus::Unpaid => Self::Unpaid,
|
||||
SubscriptionStatus::Paused => Self::Paused,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finds or creates a billing customer using the provided customer.
|
||||
async fn find_or_create_billing_customer(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &stripe::Client,
|
||||
customer_or_id: Expandable<Customer>,
|
||||
) -> anyhow::Result<Option<billing_customer::Model>> {
|
||||
let customer_id = match &customer_or_id {
|
||||
Expandable::Id(id) => id,
|
||||
Expandable::Object(customer) => customer.id.as_ref(),
|
||||
};
|
||||
|
||||
// If we already have a billing customer record associated with the Stripe customer,
|
||||
// there's nothing more we need to do.
|
||||
if let Some(billing_customer) = app
|
||||
.db
|
||||
.get_billing_customer_by_stripe_customer_id(&customer_id)
|
||||
.await?
|
||||
{
|
||||
return Ok(Some(billing_customer));
|
||||
}
|
||||
|
||||
// If all we have is a customer ID, resolve it to a full customer record by
|
||||
// hitting the Stripe API.
|
||||
let customer = match customer_or_id {
|
||||
Expandable::Id(id) => Customer::retrieve(&stripe_client, &id, &[]).await?,
|
||||
Expandable::Object(customer) => *customer,
|
||||
};
|
||||
|
||||
let Some(email) = customer.email else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let Some(user) = app.db.get_user_by_email(&email).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let billing_customer = app
|
||||
.db
|
||||
.create_billing_customer(&CreateBillingCustomerParams {
|
||||
user_id: user.id,
|
||||
stripe_customer_id: customer.id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(Some(billing_customer))
|
||||
}
|
||||
|
|
|
@ -39,4 +39,18 @@ impl Database {
|
|||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing customer for the user with the specified Stripe customer ID.
|
||||
pub async fn get_billing_customer_by_stripe_customer_id(
|
||||
&self,
|
||||
stripe_customer_id: &str,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,31 @@ impl Database {
|
|||
.await
|
||||
}
|
||||
|
||||
/// Upserts the billing subscription by its Stripe subscription ID.
|
||||
pub async fn upsert_billing_subscription_by_stripe_subscription_id(
|
||||
&self,
|
||||
params: &CreateBillingSubscriptionParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||
billing_customer_id: ActiveValue::set(params.billing_customer_id),
|
||||
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
|
||||
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
|
||||
..Default::default()
|
||||
})
|
||||
.on_conflict(
|
||||
OnConflict::columns([billing_subscription::Column::StripeSubscriptionId])
|
||||
.update_columns([billing_subscription::Column::StripeSubscriptionStatus])
|
||||
.to_owned(),
|
||||
)
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing subscription with the specified ID.
|
||||
pub async fn get_billing_subscription_by_id(
|
||||
&self,
|
||||
|
|
|
@ -61,6 +61,17 @@ impl Database {
|
|||
.await
|
||||
}
|
||||
|
||||
/// Returns a user by email address. There are no access checks here, so this should only be used internally.
|
||||
pub async fn get_user_by_email(&self, email: &str) -> Result<Option<User>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(user::Entity::find()
|
||||
.filter(user::Column::EmailAddress.eq(email))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns a user by GitHub user ID. There are no access checks here, so this should only be used internally.
|
||||
pub async fn get_user_by_github_user_id(&self, github_user_id: i32) -> Result<Option<User>> {
|
||||
self.transaction(|tx| async move {
|
||||
|
|
|
@ -5,6 +5,7 @@ use axum::{
|
|||
routing::get,
|
||||
Extension, Router,
|
||||
};
|
||||
use collab::api::billing::poll_stripe_events_periodically;
|
||||
use collab::{
|
||||
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
|
||||
rpc::ResultExt, AppState, Config, RateLimiter, Result,
|
||||
|
@ -95,6 +96,7 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
if is_api {
|
||||
poll_stripe_events_periodically(state.clone());
|
||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue