collab: Fully move StripeBilling over to using StripeClient (#31722)

This PR moves over the last method on `StripeBilling` to use the
`StripeClient` trait, allowing us to fully mock out Stripe behaviors for
`StripeBilling` in tests.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-29 19:49:14 -04:00 committed by GitHub
parent 406d975f39
commit c7047d5f0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 273 additions and 50 deletions

View file

@ -1182,10 +1182,8 @@ async fn sync_subscription(
.has_active_billing_subscription(billing_customer.user_id)
.await?;
if !already_has_active_billing_subscription {
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
stripe_billing
.subscribe_to_zed_free(stripe_customer_id)

View file

@ -5,6 +5,7 @@ use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
use crate::stripe_client::StripeCustomerId;
use crate::{
AppState, Error, Result, auth,
db::{
@ -4055,10 +4056,8 @@ async fn get_llm_api_token(
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
billing_subscription
} else {
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription = stripe_billing
.subscribe_to_zed_free(stripe_customer_id)

View file

@ -14,8 +14,9 @@ use crate::stripe_client::{
RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings,
StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
UpdateSubscriptionParams,
@ -23,7 +24,6 @@ use crate::stripe_client::{
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
real_client: Arc<stripe::Client>,
client: Arc<dyn StripeClient>,
}
@ -38,7 +38,6 @@ impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client: Arc::new(RealStripeClient::new(client.clone())),
real_client: client,
state: RwLock::default(),
}
}
@ -46,8 +45,6 @@ impl StripeBilling {
#[cfg(test)]
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
Self {
// This is just temporary until we can remove all usages of the real Stripe client.
real_client: Arc::new(stripe::Client::new("sk_test")),
client,
state: RwLock::default(),
}
@ -306,40 +303,33 @@ impl StripeBilling {
pub async fn subscribe_to_zed_free(
&self,
customer_id: stripe::CustomerId,
) -> Result<stripe::Subscription> {
customer_id: StripeCustomerId,
) -> Result<StripeSubscription> {
let zed_free_price_id = self.zed_free_price_id().await?;
let existing_subscriptions = stripe::Subscription::list(
&self.real_client,
&stripe::ListSubscriptions {
customer: Some(customer_id.clone()),
status: None,
..Default::default()
},
)
.await?;
let existing_subscriptions = self
.client
.list_subscriptions_for_customer(&customer_id)
.await?;
let existing_active_subscription =
existing_subscriptions
.data
.into_iter()
.find(|subscription| {
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
existing_subscriptions.into_iter().find(|subscription| {
subscription.status == SubscriptionStatus::Active
|| subscription.status == SubscriptionStatus::Trialing
});
if let Some(subscription) = existing_active_subscription {
return Ok(subscription);
}
let mut params = stripe::CreateSubscription::new(customer_id);
params.items = Some(vec![stripe::CreateSubscriptionItems {
price: Some(zed_free_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
let params = StripeCreateSubscriptionParams {
customer: customer_id,
items: vec![StripeCreateSubscriptionItems {
price: Some(zed_free_price_id),
quantity: Some(1),
}],
};
let subscription = stripe::Subscription::create(&self.real_client, params).await?;
let subscription = self.client.create_subscription(params).await?;
Ok(subscription)
}

View file

@ -30,21 +30,38 @@ pub struct CreateCustomerParams<'a> {
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionId(pub Arc<str>);
#[derive(Debug, Clone)]
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscription {
pub id: StripeSubscriptionId,
pub customer: StripeCustomerId,
// TODO: Create our own version of this enum.
pub status: stripe::SubscriptionStatus,
pub current_period_end: i64,
pub current_period_start: i64,
pub items: Vec<StripeSubscriptionItem>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeSubscriptionItemId(pub Arc<str>);
#[derive(Debug, Clone)]
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionItem {
pub id: StripeSubscriptionItemId,
pub price: Option<StripePrice>,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionParams {
pub customer: StripeCustomerId,
pub items: Vec<StripeCreateSubscriptionItems>,
}
#[derive(Debug)]
pub struct StripeCreateSubscriptionItems {
pub price: Option<StripePriceId>,
pub quantity: Option<u64>,
}
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>,
@ -76,7 +93,7 @@ pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>);
#[derive(Debug, Clone)]
#[derive(Debug, PartialEq, Clone)]
pub struct StripePrice {
pub id: StripePriceId,
pub unit_amount: Option<i64>,
@ -84,7 +101,7 @@ pub struct StripePrice {
pub recurring: Option<StripePriceRecurring>,
}
#[derive(Debug, Clone)]
#[derive(Debug, PartialEq, Clone)]
pub struct StripePriceRecurring {
pub meter: Option<String>,
}
@ -160,11 +177,21 @@ pub trait StripeClient: Send + Sync {
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>>;
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
) -> Result<StripeSubscription>;
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription>;
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,

View file

@ -2,6 +2,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use chrono::{Duration, Utc};
use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
@ -10,9 +11,10 @@ use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, UpdateSubscriptionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId,
StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
StripeSubscriptionItemId, UpdateSubscriptionParams,
};
#[derive(Debug, Clone)]
@ -85,6 +87,21 @@ impl StripeClient for FakeStripeClient {
Ok(customer)
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let subscriptions = self
.subscriptions
.lock()
.values()
.filter(|subscription| subscription.customer == *customer_id)
.cloned()
.collect();
Ok(subscriptions)
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
@ -96,6 +113,37 @@ impl StripeClient for FakeStripeClient {
.ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}"))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId(format!("sub_{}", Uuid::new_v4()).into()),
customer: params.customer,
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: params
.items
.into_iter()
.map(|item| StripeSubscriptionItem {
id: StripeSubscriptionItemId(format!("si_{}", Uuid::new_v4()).into()),
price: item
.price
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
})
.collect(),
};
self.subscriptions
.lock()
.insert(subscription.id.clone(), subscription.clone());
Ok(subscription)
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,

View file

@ -20,10 +20,11 @@ use crate::stripe_client::{
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
};
@ -69,6 +70,29 @@ impl StripeClient for RealStripeClient {
Ok(StripeCustomer::from(customer))
}
async fn list_subscriptions_for_customer(
&self,
customer_id: &StripeCustomerId,
) -> Result<Vec<StripeSubscription>> {
let customer_id = customer_id.try_into()?;
let subscriptions = stripe::Subscription::list(
&self.client,
&stripe::ListSubscriptions {
customer: Some(customer_id),
status: None,
..Default::default()
},
)
.await?;
Ok(subscriptions
.data
.into_iter()
.map(StripeSubscription::from)
.collect())
}
async fn get_subscription(
&self,
subscription_id: &StripeSubscriptionId,
@ -80,6 +104,30 @@ impl StripeClient for RealStripeClient {
Ok(StripeSubscription::from(subscription))
}
async fn create_subscription(
&self,
params: StripeCreateSubscriptionParams,
) -> Result<StripeSubscription> {
let customer_id = params.customer.try_into()?;
let mut create_subscription = stripe::CreateSubscription::new(customer_id);
create_subscription.items = Some(
params
.items
.into_iter()
.map(|item| stripe::CreateSubscriptionItems {
price: item.price.map(|price| price.to_string()),
quantity: item.quantity,
..Default::default()
})
.collect(),
);
let subscription = Subscription::create(&self.client, create_subscription).await?;
Ok(StripeSubscription::from(subscription))
}
async fn update_subscription(
&self,
subscription_id: &StripeSubscriptionId,
@ -220,6 +268,10 @@ impl From<Subscription> for StripeSubscription {
fn from(value: Subscription) -> Self {
Self {
id: value.id.into(),
customer: value.customer.id().into(),
status: value.status,
current_period_start: value.current_period_start,
current_period_end: value.current_period_end,
items: value.items.data.into_iter().map(Into::into).collect(),
}
}

View file

@ -1,5 +1,6 @@
use std::sync::Arc;
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
@ -163,8 +164,13 @@ async fn test_subscribe_to_price() {
.lock()
.insert(price.id.clone(), price.clone());
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![],
};
stripe_client
@ -194,8 +200,13 @@ async fn test_subscribe_to_price() {
// Subscribing to a price that is already on the subscription is a no-op.
{
let now = Utc::now();
let subscription = StripeSubscription {
id: StripeSubscriptionId("sub_test".into()),
customer: StripeCustomerId("cus_test".into()),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(price.clone()),
@ -215,6 +226,104 @@ async fn test_subscribe_to_price() {
}
}
#[gpui::test]
async fn test_subscribe_to_zed_free() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let zed_pro_price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(0),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_pro_price.id.clone(), zed_pro_price.clone());
let zed_free_price = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(zed_free_price.id.clone(), zed_free_price.clone());
stripe_billing.initialize().await.unwrap();
// Customer is subscribed to Zed Free when not already subscribed to a plan.
{
let customer_id = StripeCustomerId("cus_no_plan".into());
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription.items[0].price.as_ref(), Some(&zed_free_price));
}
// Customer is not subscribed to Zed Free when they already have an active subscription.
{
let customer_id = StripeCustomerId("cus_active_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_active".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Active,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(30)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
// Customer is not subscribed to Zed Free when they already have a trial subscription.
{
let customer_id = StripeCustomerId("cus_trial_subscription".into());
let now = Utc::now();
let existing_subscription = StripeSubscription {
id: StripeSubscriptionId("sub_existing_trial".into()),
customer: customer_id.clone(),
status: stripe::SubscriptionStatus::Trialing,
current_period_start: now.timestamp(),
current_period_end: (now + Duration::days(14)).timestamp(),
items: vec![StripeSubscriptionItem {
id: StripeSubscriptionItemId("si_test".into()),
price: Some(zed_pro_price.clone()),
}],
};
stripe_client.subscriptions.lock().insert(
existing_subscription.id.clone(),
existing_subscription.clone(),
);
let subscription = stripe_billing
.subscribe_to_zed_free(customer_id)
.await
.unwrap();
assert_eq!(subscription, existing_subscription);
}
}
#[gpui::test]
async fn test_bill_model_request_usage() {
let (stripe_billing, stripe_client) = make_stripe_billing();