collab: Use StripeClient in sync_subscription (#31761)

This PR updates the `sync_subscription` function to use the
`StripeClient` trait instead of using `stripe::Client` directly.

Release Notes:

- N/A

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Marshall Bowers 2025-05-30 12:08:58 -04:00 committed by GitHub
parent 07436b4284
commit f725b5e248
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 177 additions and 75 deletions

View file

@ -17,8 +17,8 @@ use stripe::{
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::{ResultExt, maybe};
@ -29,7 +29,10 @@ use crate::db::billing_subscription::{
use crate::llm::db::subscription_usage_meter::CompletionMode;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId};
use crate::stripe_client::{
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
StripeSubscriptionId,
};
use crate::{AppState, Error, Result};
use crate::{db::UserId, llm::db::LlmDatabase};
use crate::{
@ -426,7 +429,7 @@ async fn manage_billing_subscription(
.await?
.context("user not found")?;
let Some(stripe_client) = app.stripe_client.clone() else {
let Some(stripe_client) = app.real_stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
@ -644,7 +647,7 @@ async fn migrate_to_new_billing(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<MigrateToNewBillingBody>,
) -> Result<Json<MigrateToNewBillingResponse>> {
let Some(stripe_client) = app.stripe_client.clone() else {
let Some(stripe_client) = app.real_stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
@ -723,6 +726,13 @@ async fn sync_billing_subscription(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
) -> Result<Json<SyncBillingSubscriptionResponse>> {
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
@ -748,7 +758,7 @@ async fn sync_billing_subscription(
.context("failed to parse Stripe customer ID from database")?;
let subscriptions = Subscription::list(
&stripe_client,
&real_stripe_client,
&stripe::ListSubscriptions {
customer: Some(stripe_customer_id),
// Sync all non-canceled subscriptions.
@ -761,7 +771,7 @@ async fn sync_billing_subscription(
for subscription in subscriptions.data {
let subscription_id = subscription.id.clone();
sync_subscription(&app, &stripe_client, subscription)
sync_subscription(&app, &stripe_client, subscription.into())
.await
.with_context(|| {
format!(
@ -806,6 +816,10 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
/// Polls the Stripe events API periodically to reconcile the records in our
/// database with the data in Stripe.
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
@ -816,7 +830,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
let executor = executor.clone();
async move {
loop {
poll_stripe_events(&app, &rpc_server, &stripe_client)
poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
.await
.log_err();
@ -829,7 +843,8 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
async fn poll_stripe_events(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &stripe::Client,
stripe_client: &Arc<dyn StripeClient>,
real_stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
fn event_type_to_string(event_type: EventType) -> String {
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
@ -861,7 +876,7 @@ async fn poll_stripe_events(
params.types = Some(event_types.clone());
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
let mut event_pages = stripe::Event::list(&stripe_client, &params)
let mut event_pages = stripe::Event::list(&real_stripe_client, &params)
.await?
.paginate(params);
@ -905,7 +920,7 @@ async fn poll_stripe_events(
break;
} else {
log::info!("Stripe events: retrieving next page");
event_pages = event_pages.next(&stripe_client).await?;
event_pages = event_pages.next(&real_stripe_client).await?;
}
} else {
break;
@ -945,7 +960,7 @@ async fn poll_stripe_events(
let process_result = match event.type_ {
EventType::CustomerCreated | EventType::CustomerUpdated => {
handle_customer_event(app, stripe_client, event).await
handle_customer_event(app, real_stripe_client, event).await
}
EventType::CustomerSubscriptionCreated
| EventType::CustomerSubscriptionUpdated
@ -1020,8 +1035,8 @@ async fn handle_customer_event(
async fn sync_subscription(
app: &Arc<AppState>,
stripe_client: &stripe::Client,
subscription: stripe::Subscription,
stripe_client: &Arc<dyn StripeClient>,
subscription: StripeSubscription,
) -> anyhow::Result<billing_customer::Model> {
let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
stripe_billing
@ -1032,7 +1047,7 @@ async fn sync_subscription(
};
let billing_customer =
find_or_create_billing_customer(app, stripe_client, subscription.customer)
find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
.await?
.context("billing customer not found")?;
@ -1060,7 +1075,7 @@ async fn sync_subscription(
.as_ref()
.and_then(|details| details.reason)
.map_or(false, |reason| {
reason == CancellationDetailsReason::PaymentFailed
reason == StripeCancellationDetailsReason::PaymentFailed
});
if was_canceled_due_to_payment_failure {
@ -1077,7 +1092,7 @@ async fn sync_subscription(
if let Some(existing_subscription) = app
.db
.get_billing_subscription_by_stripe_subscription_id(&subscription.id)
.get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
.await?
{
app.db
@ -1118,20 +1133,13 @@ async fn sync_subscription(
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
{
let stripe_subscription_id = existing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
let stripe_subscription_id = StripeSubscriptionId(
existing_subscription.stripe_subscription_id.clone().into(),
);
Subscription::cancel(
&stripe_client,
&stripe_subscription_id,
stripe::CancelSubscription {
invoice_now: None,
..Default::default()
},
)
.await?;
stripe_client
.cancel_subscription(&stripe_subscription_id)
.await?;
} else {
// If the user already has an active billing subscription, ignore the
// event and return an `Ok` to signal that it was processed
@ -1198,7 +1206,7 @@ async fn sync_subscription(
async fn handle_customer_subscription_event(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &stripe::Client,
stripe_client: &Arc<dyn StripeClient>,
event: stripe::Event,
) -> anyhow::Result<()> {
let EventObject::Subscription(subscription) = event.data.object else {
@ -1207,7 +1215,7 @@ async fn handle_customer_subscription_event(
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
// When the user's subscription changes, push down any changes to their plan.
rpc_server
@ -1403,30 +1411,20 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
/// Finds or creates a billing customer using the provided customer.
pub async fn find_or_create_billing_customer(
app: &Arc<AppState>,
stripe_client: &stripe::Client,
customer_or_id: Expandable<Customer>,
stripe_client: &dyn StripeClient,
customer_id: &StripeCustomerId,
) -> anyhow::Result<Option<billing_customer::Model>> {
let customer_id = match &customer_or_id {
Expandable::Id(id) => id,
Expandable::Object(customer) => customer.id.as_ref(),
};
// If we already have a billing customer record associated with the Stripe customer,
// there's nothing more we need to do.
if let Some(billing_customer) = app
.db
.get_billing_customer_by_stripe_customer_id(customer_id)
.get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
.await?
{
return Ok(Some(billing_customer));
}
// If all we have is a customer ID, resolve it to a full customer record by
// hitting the Stripe API.
let customer = match customer_or_id {
Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
Expandable::Object(customer) => *customer,
};
let customer = stripe_client.get_customer(customer_id).await?;
let Some(email) = customer.email else {
return Ok(None);