collab: Introduce StripeClient trait to abstract over Stripe interactions (#31615)

This PR introduces a new `StripeClient` trait to abstract over
interacting with the Stripe API.

This will allow us to more easily test our billing code.

This initial cut is small and focuses just on making
`StripeBilling::find_or_create_customer_by_email` testable. I'll follow
up with using the `StripeClient` in more places.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-28 14:34:44 -04:00 committed by GitHub
parent 68724ea99e
commit 361ceee72b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 257 additions and 33 deletions

View file

@ -20,6 +20,7 @@ test-support = ["sqlite"]
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
async-stripe.workspace = true async-stripe.workspace = true
async-trait.workspace = true
async-tungstenite.workspace = true async-tungstenite.workspace = true
aws-config = { version = "1.1.5" } aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" } aws-sdk-s3 = { version = "1.15.0" }

View file

@ -344,6 +344,7 @@ async fn create_billing_subscription(
stripe_billing stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref()) .find_or_create_customer_by_email(user.email_address.as_deref())
.await? .await?
.try_into()?
}; };
let success_url = format!( let success_url = format!(

View file

@ -9,6 +9,7 @@ pub mod migrations;
pub mod rpc; pub mod rpc;
pub mod seed; pub mod seed;
pub mod stripe_billing; pub mod stripe_billing;
pub mod stripe_client;
pub mod user_backfiller; pub mod user_backfiller;
#[cfg(test)] #[cfg(test)]

View file

@ -4039,7 +4039,8 @@ async fn get_llm_api_token(
} else { } else {
let customer_id = stripe_billing let customer_id = stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref()) .find_or_create_customer_by_email(user.email_address.as_deref())
.await?; .await?
.try_into()?;
find_or_create_billing_customer( find_or_create_billing_customer(
&session.app_state, &session.app_state,

View file

@ -1,19 +1,22 @@
use std::sync::Arc; use std::sync::Arc;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use anyhow::{Context as _, anyhow}; use anyhow::{Context as _, anyhow};
use chrono::Utc; use chrono::Utc;
use collections::HashMap; use collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus}; use stripe::{PriceId, SubscriptionStatus};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use uuid::Uuid; use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
pub struct StripeBilling { pub struct StripeBilling {
state: RwLock<StripeBillingState>, state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>, real_client: Arc<stripe::Client>,
client: Arc<dyn StripeClient>,
} }
#[derive(Default)] #[derive(Default)]
@ -26,6 +29,17 @@ struct StripeBillingState {
impl StripeBilling { impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self { pub fn new(client: Arc<stripe::Client>) -> Self {
Self { Self {
client: Arc::new(RealStripeClient::new(client.clone())),
real_client: client,
state: RwLock::default(),
}
}
#[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self {
// This is just temporary until we can remove all usages of the real Stripe client.
real_client: Arc::new(stripe::Client::new("sk_test")),
client, client,
state: RwLock::default(), state: RwLock::default(),
} }
@ -37,9 +51,9 @@ impl StripeBilling {
let mut state = self.state.write().await; let mut state = self.state.write().await;
let (meters, prices) = futures::try_join!( let (meters, prices) = futures::try_join!(
StripeMeter::list(&self.client), StripeMeter::list(&self.real_client),
stripe::Price::list( stripe::Price::list(
&self.client, &self.real_client,
&stripe::ListPrices { &stripe::ListPrices {
limit: Some(100), limit: Some(100),
..Default::default() ..Default::default()
@ -129,18 +143,11 @@ impl StripeBilling {
pub async fn find_or_create_customer_by_email( pub async fn find_or_create_customer_by_email(
&self, &self,
email_address: Option<&str>, email_address: Option<&str>,
) -> Result<CustomerId> { ) -> Result<StripeCustomerId> {
let existing_customer = if let Some(email) = email_address { let existing_customer = if let Some(email) = email_address {
let customers = Customer::list( let customers = self.client.list_customers_by_email(email).await?;
&self.client,
&stripe::ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
customers.data.first().cloned() customers.first().cloned()
} else { } else {
None None
}; };
@ -148,14 +155,12 @@ impl StripeBilling {
let customer_id = if let Some(existing_customer) = existing_customer { let customer_id = if let Some(existing_customer) = existing_customer {
existing_customer.id existing_customer.id
} else { } else {
let customer = Customer::create( let customer = self
&self.client, .client
CreateCustomer { .create_customer(crate::stripe_client::CreateCustomerParams {
email: email_address, email: email_address,
..Default::default() })
}, .await?;
)
.await?;
customer.id customer.id
}; };
@ -169,7 +174,7 @@ impl StripeBilling {
price: &stripe::Price, price: &stripe::Price,
) -> Result<()> { ) -> Result<()> {
let subscription = let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?; stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, &price.id) { if subscription_contains_price(&subscription, &price.id) {
return Ok(()); return Ok(());
@ -181,7 +186,7 @@ impl StripeBilling {
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit; let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
stripe::Subscription::update( stripe::Subscription::update(
&self.client, &self.real_client,
subscription_id, subscription_id,
stripe::UpdateSubscription { stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems { items: Some(vec![stripe::UpdateSubscriptionItems {
@ -211,7 +216,7 @@ impl StripeBilling {
let idempotency_key = Uuid::new_v4(); let idempotency_key = Uuid::new_v4();
StripeMeterEvent::create( StripeMeterEvent::create(
&self.client, &self.real_client,
StripeCreateMeterEventParams { StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key), identifier: &format!("model_requests/{}", idempotency_key),
event_name, event_name,
@ -246,7 +251,7 @@ impl StripeBilling {
}]); }]);
params.success_url = Some(success_url); params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?; let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?) Ok(session.url.context("no checkout session URL")?)
} }
@ -300,7 +305,7 @@ impl StripeBilling {
}]); }]);
params.success_url = Some(success_url); params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?; let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?) Ok(session.url.context("no checkout session URL")?)
} }
@ -311,7 +316,7 @@ impl StripeBilling {
let zed_free_price_id = self.zed_free_price_id().await?; let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = stripe::Subscription::list( let existing_subscriptions = stripe::Subscription::list(
&self.client, &self.real_client,
&stripe::ListSubscriptions { &stripe::ListSubscriptions {
customer: Some(customer_id.clone()), customer: Some(customer_id.clone()),
status: None, status: None,
@ -339,7 +344,7 @@ impl StripeBilling {
..Default::default() ..Default::default()
}]); }]);
let subscription = stripe::Subscription::create(&self.client, params).await?; let subscription = stripe::Subscription::create(&self.real_client, params).await?;
Ok(subscription) Ok(subscription)
} }
@ -365,7 +370,7 @@ impl StripeBilling {
}]); }]);
params.success_url = Some(success_url); params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?; let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
Ok(session.url.context("no checkout session URL")?) Ok(session.url.context("no checkout session URL")?)
} }
} }

View file

@ -0,0 +1,33 @@
#[cfg(test)]
mod fake_stripe_client;
mod real_stripe_client;
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
#[cfg(test)]
pub use fake_stripe_client::*;
pub use real_stripe_client::*;
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct StripeCustomerId(pub Arc<str>);
#[derive(Debug, Clone)]
pub struct StripeCustomer {
pub id: StripeCustomerId,
pub email: Option<String>,
}
#[derive(Debug)]
pub struct CreateCustomerParams<'a> {
pub email: Option<&'a str>,
}
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
}

View file

@ -0,0 +1,47 @@
use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
}
impl FakeStripeClient {
pub fn new() -> Self {
Self {
customers: Arc::new(Mutex::new(HashMap::default())),
}
}
}
#[async_trait]
impl StripeClient for FakeStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
Ok(self
.customers
.lock()
.values()
.filter(|customer| customer.email.as_deref() == Some(email))
.cloned()
.collect())
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = StripeCustomer {
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
email: params.email.map(|email| email.to_string()),
};
self.customers
.lock()
.insert(customer.id.clone(), customer.clone());
Ok(customer)
}
}

View file

@ -0,0 +1,74 @@
use std::str::FromStr as _;
use std::sync::Arc;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers};
use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
pub struct RealStripeClient {
client: Arc<stripe::Client>,
}
impl RealStripeClient {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self { client }
}
}
#[async_trait]
impl StripeClient for RealStripeClient {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
let response = Customer::list(
&self.client,
&ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
Ok(response
.data
.into_iter()
.map(StripeCustomer::from)
.collect())
}
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
let customer = Customer::create(
&self.client,
CreateCustomer {
email: params.email,
..Default::default()
},
)
.await?;
Ok(StripeCustomer::from(customer))
}
}
impl From<CustomerId> for StripeCustomerId {
fn from(value: CustomerId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl From<Customer> for StripeCustomer {
fn from(value: Customer) -> Self {
StripeCustomer {
id: value.id.into(),
email: value.email,
}
}
}

View file

@ -18,6 +18,7 @@ mod random_channel_buffer_tests;
mod random_project_collaboration_tests; mod random_project_collaboration_tests;
mod randomized_test_helpers; mod randomized_test_helpers;
mod remote_editing_collaboration_tests; mod remote_editing_collaboration_tests;
mod stripe_billing_tests;
mod test_server; mod test_server;
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};

View file

@ -0,0 +1,60 @@
use std::sync::Arc;
use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::FakeStripeClient;
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
let stripe_billing = StripeBilling::test(stripe_client.clone());
(stripe_billing, stripe_client)
}
#[gpui::test]
async fn test_find_or_create_customer_by_email() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Create a customer with an email that doesn't yet correspond to a customer.
{
let email = "user@example.com";
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
// Create a customer with an email that corresponds to an existing customer.
{
let email = "user2@example.com";
let existing_customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
let customer_id = stripe_billing
.find_or_create_customer_by_email(Some(email))
.await
.unwrap();
assert_eq!(customer_id, existing_customer_id);
let customer = stripe_client
.customers
.lock()
.get(&customer_id)
.unwrap()
.clone();
assert_eq!(customer.email.as_deref(), Some(email));
}
}