collab: Introduce StripeClient trait to abstract over Stripe interactions (#31615)

This PR introduces a new `StripeClient` trait to abstract over
interacting with the Stripe API.

This will allow us to more easily test our billing code.

This initial cut is small and focuses just on making
`StripeBilling::find_or_create_customer_by_email` testable. I'll follow
up with using the `StripeClient` in more places.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-28 14:34:44 -04:00 committed by GitHub
parent 68724ea99e
commit 361ceee72b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 257 additions and 33 deletions

View file

@ -1,19 +1,22 @@
use std::sync::Arc;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
use stripe::{PriceId, SubscriptionStatus};
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>,
real_client: Arc<stripe::Client>,
client: Arc<dyn StripeClient>,
}
#[derive(Default)]
@ -26,6 +29,17 @@ struct StripeBillingState {
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client: Arc::new(RealStripeClient::new(client.clone())),
real_client: client,
state: RwLock::default(),
}
}
#[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self {
// This is just temporary until we can remove all usages of the real Stripe client.
real_client: Arc::new(stripe::Client::new("sk_test")),
client,
state: RwLock::default(),
}
@ -37,9 +51,9 @@ impl StripeBilling {
let mut state = self.state.write().await;
let (meters, prices) = futures::try_join!(
StripeMeter::list(&self.client),
StripeMeter::list(&self.real_client),
stripe::Price::list(
&self.client,
&self.real_client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
@ -129,18 +143,11 @@ impl StripeBilling {
pub async fn find_or_create_customer_by_email(
&self,
email_address: Option<&str>,
) -> Result<CustomerId> {
) -> Result<StripeCustomerId> {
let existing_customer = if let Some(email) = email_address {
let customers = Customer::list(
&self.client,
&stripe::ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
let customers = self.client.list_customers_by_email(email).await?;
customers.data.first().cloned()
customers.first().cloned()
} else {
None
};
@ -148,14 +155,12 @@ impl StripeBilling {
let customer_id = if let Some(existing_customer) = existing_customer {
existing_customer.id
} else {
let customer = Customer::create(
&self.client,
CreateCustomer {
let customer = self
.client
.create_customer(crate::stripe_client::CreateCustomerParams {
email: email_address,
..Default::default()
},
)
.await?;
})
.await?;
customer.id
};
@ -169,7 +174,7 @@ impl StripeBilling {
price: &stripe::Price,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, &price.id) {
return Ok(());
@ -181,7 +186,7 @@ impl StripeBilling {
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
stripe::Subscription::update(
&self.client,
&self.real_client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
@ -211,7 +216,7 @@ impl StripeBilling {
let idempotency_key = Uuid::new_v4();
StripeMeterEvent::create(
&self.client,
&self.real_client,
StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
@ -246,7 +251,7 @@ impl StripeBilling {
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
@ -300,7 +305,7 @@ impl StripeBilling {
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
@ -311,7 +316,7 @@ impl StripeBilling {
let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = stripe::Subscription::list(
&self.client,
&self.real_client,
&stripe::ListSubscriptions {
customer: Some(customer_id.clone()),
status: None,
@ -339,7 +344,7 @@ impl StripeBilling {
..Default::default()
}]);
let subscription = stripe::Subscription::create(&self.client, params).await?;
let subscription = stripe::Subscription::create(&self.real_client, params).await?;
Ok(subscription)
}
@ -365,7 +370,7 @@ impl StripeBilling {
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
}