collab: Fully move StripeBilling over to using StripeClient (#31722)

This PR moves over the last method on `StripeBilling` to use the
`StripeClient` trait, allowing us to fully mock out Stripe behaviors for
`StripeBilling` in tests.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-29 19:49:14 -04:00 committed by GitHub
parent 406d975f39
commit c7047d5f0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 273 additions and 50 deletions

View file

@ -1182,10 +1182,8 @@ async fn sync_subscription(
.has_active_billing_subscription(billing_customer.user_id) .has_active_billing_subscription(billing_customer.user_id)
.await?; .await?;
if !already_has_active_billing_subscription { if !already_has_active_billing_subscription {
let stripe_customer_id = billing_customer let stripe_customer_id =
.stripe_customer_id StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
stripe_billing stripe_billing
.subscribe_to_zed_free(stripe_customer_id) .subscribe_to_zed_free(stripe_customer_id)

View file

@ -5,6 +5,7 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind; use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase; use crate::llm::db::LlmDatabase;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims}; use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
use crate::stripe_client::StripeCustomerId;
use crate::{ use crate::{
AppState, Error, Result, auth, AppState, Error, Result, auth,
db::{ db::{
@ -4055,10 +4056,8 @@ async fn get_llm_api_token(
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? { if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
billing_subscription billing_subscription
} else { } else {
let stripe_customer_id = billing_customer let stripe_customer_id =
.stripe_customer_id StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_subscription = stripe_billing let stripe_subscription = stripe_billing
.subscribe_to_zed_free(stripe_customer_id) .subscribe_to_zed_free(stripe_customer_id)

View file

@ -14,8 +14,9 @@ use crate::stripe_client::{
RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
UpdateSubscriptionParams, UpdateSubscriptionParams,
@ -23,7 +24,6 @@ use crate::stripe_client::{
pub struct StripeBilling { pub struct StripeBilling {
state: RwLock<StripeBillingState>, state: RwLock<StripeBillingState>,
real_client: Arc<stripe::Client>,
client: Arc<dyn StripeClient>, client: Arc<dyn StripeClient>,
} }
@ -38,7 +38,6 @@ impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self { pub fn new(client: Arc<stripe::Client>) -> Self {
Self { Self {
client: Arc::new(RealStripeClient::new(client.clone())), client: Arc::new(RealStripeClient::new(client.clone())),
real_client: client,
state: RwLock::default(), state: RwLock::default(),
} }
} }
@ -46,8 +45,6 @@ impl StripeBilling {
#[cfg(test)] #[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self { pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self { Self {
// This is just temporary until we can remove all usages of the real Stripe client.
real_client: Arc::new(stripe::Client::new("sk_test")),
client, client,
state: RwLock::default(), state: RwLock::default(),
} }
@ -306,40 +303,33 @@ impl StripeBilling {
pub async fn subscribe_to_zed_free( pub async fn subscribe_to_zed_free(
&self, &self,
customer_id: stripe::CustomerId, customer_id: StripeCustomerId,
) -> Result<stripe::Subscription> { ) -> Result<StripeSubscription> {
let zed_free_price_id = self.zed_free_price_id().await?; let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = stripe::Subscription::list( let existing_subscriptions = self
&self.real_client, .client
&stripe::ListSubscriptions { .list_subscriptions_for_customer(&customer_id)
customer: Some(customer_id.clone()), .await?;
status: None,
..Default::default()
},
)
.await?;
let existing_active_subscription = let existing_active_subscription =
existing_subscriptions existing_subscriptions.into_iter().find(|subscription| {
.data subscription.status == SubscriptionStatus::Active
.into_iter() || subscription.status == SubscriptionStatus::Trialing
.find(|subscription| { });
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
if let Some(subscription) = existing_active_subscription { if let Some(subscription) = existing_active_subscription {
return Ok(subscription); return Ok(subscription);
} }
let mut params = stripe::CreateSubscription::new(customer_id); let params = StripeCreateSubscriptionParams {
params.items = Some(vec![stripe::CreateSubscriptionItems { customer: customer_id,
price: Some(zed_free_price_id.to_string()), items: vec![StripeCreateSubscriptionItems {
quantity: Some(1), price: Some(zed_free_price_id),
..Default::default() quantity: Some(1),
}]); }],
};
let subscription = stripe::Subscription::create(&self.real_client, params).await?; let subscription = self.client.create_subscription(params).await?;
Ok(subscription) Ok(subscription)
} }

View file

@ -30,21 +30,38 @@ pub struct CreateCustomerParams<'a> {
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionId(pub Arc<str>); pub struct StripeSubscriptionId(pub Arc<str>);
#[derive(Debug, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscription { pub struct StripeSubscription {
pub id: StripeSubscriptionId, pub id: StripeSubscriptionId,
pub customer: StripeCustomerId,
// TODO: Create our own version of this enum.
pub status: stripe::SubscriptionStatus,
pub current_period_end: i64,
pub current_period_start: i64,
pub items: Vec<StripeSubscriptionItem>, pub items: Vec<StripeSubscriptionItem>,
} }
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionItemId(pub Arc<str>); pub struct StripeSubscriptionItemId(pub Arc<str>);
#[derive(Debug, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionItem { pub struct StripeSubscriptionItem {
pub id: StripeSubscriptionItemId, pub id: StripeSubscriptionItemId,
pub price: Option<StripePrice>, pub price: Option<StripePrice>,
} }
#[derive(Debug)]
pub struct StripeCreateSubscriptionParams {
pub customer: StripeCustomerId,
pub items: Vec<StripeCreateSubscriptionItems>,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionItems {
pub price: Option<StripePriceId>,
pub quantity: Option<u64>,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams { pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>, pub items: Option<Vec<UpdateSubscriptionItems>>,
@ -76,7 +93,7 @@ pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>); pub struct StripePriceId(pub Arc<str>);
#[derive(Debug, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct StripePrice { pub struct StripePrice {
pub id: StripePriceId, pub id: StripePriceId,
pub unit_amount: Option<i64>, pub unit_amount: Option<i64>,
@ -84,7 +101,7 @@ pub struct StripePrice {
pub recurring: Option<StripePriceRecurring>, pub recurring: Option<StripePriceRecurring>,
} }
#[derive(Debug, Clone)] #[derive(Debug, PartialEq, Clone)]
pub struct StripePriceRecurring { pub struct StripePriceRecurring {
pub meter: Option<String>, pub meter: Option<String>,
} }
@ -160,11 +177,21 @@ pub trait StripeClient: Send + Sync {
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>; async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>>;
async fn get_subscription( async fn get_subscription(
&self, &self,
subscription_id: &StripeSubscriptionId, subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription>; ) -> Result<StripeSubscription>;
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription>;
async fn update_subscription( async fn update_subscription(
&self, &self,
subscription_id: &StripeSubscriptionId, subscription_id: &StripeSubscriptionId,

View file

@ -2,6 +2,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{Duration, Utc};
use collections::HashMap; use collections::HashMap;
use parking_lot::Mutex; use parking_lot::Mutex;
use uuid::Uuid; use uuid::Uuid;
@ -10,9 +11,10 @@ use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode, CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
StripeSubscriptionId, UpdateSubscriptionParams, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
StripeSubscriptionItemId, UpdateSubscriptionParams,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -85,6 +87,21 @@ impl StripeClient for FakeStripeClient {
Ok(customer) Ok(customer)
} }
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let subscriptions = self
.subscriptions
.lock()
.values()
.filter(|subscription| subscription.customer == *customer_id)
.cloned()
.collect();
Ok(subscriptions)
}
async fn get_subscription( async fn get_subscription(
&self, &self,
subscription_id: &StripeSubscriptionId, subscription_id: &StripeSubscriptionId,
@ -96,6 +113,37 @@ impl StripeClient for FakeStripeClient {
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}")) .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
} }
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
customer: params.customer,
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: params
.items
.into_iter()
.map(|item| StripeSubscriptionItem {
id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
price: item
.price
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
})
.collect(),
};
self.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
Ok(subscription)
}
async fn update_subscription( async fn update_subscription(
&self, &self,
subscription_id: &StripeSubscriptionId, subscription_id: &StripeSubscriptionId,

View file

@ -20,10 +20,11 @@ use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode, CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient, StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams, StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
}; };
@ -69,6 +70,29 @@ impl StripeClient for RealStripeClient {
Ok(StripeCustomer::from(customer)) Ok(StripeCustomer::from(customer))
} }
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let customer_id = customer_id.try_into()?;
let subscriptions = stripe::Subscription::list(
&self.client,
&stripe::ListSubscriptions {
customer: Some(customer_id),
status: None,
..Default::default()
},
)
.await?;
Ok(subscriptions
.data
.into_iter()
.map(StripeSubscription::from)
.collect())
}
async fn get_subscription( async fn get_subscription(
&self, &self,
subscription_id: &StripeSubscriptionId, subscription_id: &StripeSubscriptionId,
@ -80,6 +104,30 @@ impl StripeClient for RealStripeClient {
Ok(StripeSubscription::from(subscription)) Ok(StripeSubscription::from(subscription))
} }
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let customer_id = params.customer.try_into()?;
let mut create_subscription = stripe::CreateSubscription::new(customer_id);
create_subscription.items = Some(
params
.items
.into_iter()
.map(|item| stripe::CreateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
quantity: item.quantity,
..Default::default()
})
.collect(),
);
let subscription = Subscription::create(&self.client, create_subscription).await?;
Ok(StripeSubscription::from(subscription))
}
async fn update_subscription( async fn update_subscription(
&self, &self,
subscription_id: &StripeSubscriptionId, subscription_id: &StripeSubscriptionId,
@ -220,6 +268,10 @@ impl From<Subscription> for StripeSubscription {
fn from(value: Subscription) -> Self { fn from(value: Subscription) -> Self {
Self { Self {
id: value.id.into(), id: value.id.into(),
customer: value.customer.id().into(),
status: value.status,
current_period_start: value.current_period_start,
current_period_end: value.current_period_end,
items: value.items.data.into_iter().map(Into::into).collect(), items: value.items.data.into_iter().map(Into::into).collect(),
} }
} }

View file

@ -1,5 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
@ -163,8 +164,13 @@ async fn test_subscribe_to_price() {
.lock() .lock()
.insert(price.id.clone(), price.clone()); .insert(price.id.clone(), price.clone());
let now = Utc::now();
let subscription = StripeSubscription { let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()), id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![], items: vec![],
}; };
stripe_client stripe_client
@ -194,8 +200,13 @@ async fn test_subscribe_to_price() {
// Subscribing to a price that is already on the subscription is a no-op. // Subscribing to a price that is already on the subscription is a no-op.
{ {
let now = Utc::now();
let subscription = StripeSubscription { let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()), id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem { items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()), id: StripeSubscriptionItemId("si_test".into()),
price: Some(price.clone()), price: Some(price.clone()),
@ -215,6 +226,104 @@ async fn test_subscribe_to_price() {
} }
} }
#[gpui::test]
async fn test_subscribe_to_zed_free() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let zed_pro_price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(0),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_pro_price.id.clone(), zed_pro_price.clone());
let zed_free_price = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_free_price.id.clone(), zed_free_price.clone());
stripe_billing.initialize().await.unwrap();
// Customer is subscribed to Zed Free when not already subscribed to a plan.
{
let customer_id = StripeCustomerId("cus_no_plan".into());
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
}
// Customer is not subscribed to Zed Free when they already have an active subscription.
{
let customer_id = StripeCustomerId("cus_active_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_active".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
// Customer is not subscribed to Zed Free when they already have a trial subscription.
{
let customer_id = StripeCustomerId("cus_trial_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_trial".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Trialing,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(14)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
}
#[gpui::test] #[gpui::test]
async fn test_bill_model_request_usage() { async fn test_bill_model_request_usage() {
let (stripe_billing, stripe_client) = make_stripe_billing(); let (stripe_billing, stripe_client) = make_stripe_billing();