From 00bdebc89dbc3db717d7ad4e2a740775b2fe96ed Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 28 May 2025 17:17:11 -0400 Subject: [PATCH] collab: Use `StripeClient` in `StripeBilling::subscribe_to_price` (#31631) This PR updates the `StripeBilling::subscribe_to_price` method to use the `StripeClient` trait. Release Notes: - N/A --- crates/collab/src/api/billing.rs | 2 +- crates/collab/src/stripe_billing.rs | 49 ++++--- crates/collab/src/stripe_client.rs | 57 ++++++++ .../src/stripe_client/fake_stripe_client.rs | 35 ++++- .../src/stripe_client/real_stripe_client.rs | 124 +++++++++++++++++- .../collab/src/tests/stripe_billing_tests.rs | 69 ++++++++++ 6 files changed, 306 insertions(+), 30 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 0edfa2f1bf..607576bb04 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1578,7 +1578,7 @@ async fn sync_model_request_usage_with_stripe( }; stripe_billing - .subscribe_to_price(&stripe_subscription_id, price) + .subscribe_to_price(&stripe_subscription_id.into(), price) .await?; stripe_billing .bill_model_request_usage( diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index e8a716d206..d3f062042b 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -13,6 +13,9 @@ use crate::db::billing_subscription::SubscriptionKind; use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_client::{ RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, + StripeSubscription, StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams, + UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior, + UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, }; pub struct StripeBilling { @@ -166,14 +169,12 @@ impl StripeBilling { pub async fn subscribe_to_price( &self, - subscription_id: &stripe::SubscriptionId, + subscription_id: &StripeSubscriptionId, price: &StripePrice, ) -> Result<()> { - let subscription = - stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?; + let subscription = self.client.get_subscription(subscription_id).await?; - let price_id = price.id.clone().try_into()?; - if subscription_contains_price(&subscription, &price_id) { + if subscription_contains_price(&subscription, &price.id) { return Ok(()); } @@ -182,23 +183,21 @@ impl StripeBilling { let price_per_unit = price.unit_amount.unwrap_or_default(); let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit; - stripe::Subscription::update( - &self.real_client, - subscription_id, - stripe::UpdateSubscription { - items: Some(vec![stripe::UpdateSubscriptionItems { - price: Some(price.id.to_string()), - ..Default::default() - }]), - trial_settings: Some(stripe::UpdateSubscriptionTrialSettings { - end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior { - missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, - }, - }), - ..Default::default() - }, - ) - .await?; + self.client + .update_subscription( + subscription_id, + UpdateSubscriptionParams { + items: Some(vec![UpdateSubscriptionItems { + price: Some(price.id.clone()), + }]), + trial_settings: Some(UpdateSubscriptionTrialSettings { + end_behavior: UpdateSubscriptionTrialSettingsEndBehavior { + missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel + }, + }), + }, + ) + .await?; Ok(()) } @@ -419,10 +418,10 @@ struct StripeCreateMeterEventPayload<'a> { } fn subscription_contains_price( - subscription: &stripe::Subscription, - price_id: &stripe::PriceId, + subscription: &StripeSubscription, + price_id: &StripePriceId, ) -> bool { - subscription.items.data.iter().any(|item| { + subscription.items.iter().any(|item| { item.price .as_ref() .map_or(false, |price| price.id == *price_id) diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs index 5fcb139a7e..8ecf0b2fe5 100644 --- a/crates/collab/src/stripe_client.rs +++ b/crates/collab/src/stripe_client.rs @@ -26,6 +26,52 @@ pub struct CreateCustomerParams<'a> { pub email: Option<&'a str>, } +#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] +pub struct StripeSubscriptionId(pub Arc); + +#[derive(Debug, Clone)] +pub struct StripeSubscription { + pub id: StripeSubscriptionId, + pub items: Vec, +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] +pub struct StripeSubscriptionItemId(pub Arc); + +#[derive(Debug, Clone)] +pub struct StripeSubscriptionItem { + pub id: StripeSubscriptionItemId, + pub price: Option, +} + +#[derive(Debug, Clone)] +pub struct UpdateSubscriptionParams { + pub items: Option>, + pub trial_settings: Option, +} + +#[derive(Debug, PartialEq, Clone)] +pub struct UpdateSubscriptionItems { + pub price: Option, +} + +#[derive(Debug, Clone)] +pub struct UpdateSubscriptionTrialSettings { + pub end_behavior: UpdateSubscriptionTrialSettingsEndBehavior, +} + +#[derive(Debug, Clone)] +pub struct UpdateSubscriptionTrialSettingsEndBehavior { + pub missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod { + Cancel, + CreateInvoice, + Pause, +} + #[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] pub struct StripePriceId(pub Arc); @@ -57,6 +103,17 @@ pub trait StripeClient: Send + Sync { async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result; + async fn get_subscription( + &self, + subscription_id: &StripeSubscriptionId, + ) -> Result; + + async fn update_subscription( + &self, + subscription_id: &StripeSubscriptionId, + params: UpdateSubscriptionParams, + ) -> Result<()>; + async fn list_prices(&self) -> Result>; async fn list_meters(&self) -> Result>; diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs index 9c1d407215..3c3be84da1 100644 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use anyhow::Result; +use anyhow::{Result, anyhow}; use async_trait::async_trait; use collections::HashMap; use parking_lot::Mutex; @@ -8,11 +8,15 @@ use uuid::Uuid; use crate::stripe_client::{ CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, - StripeMeterId, StripePrice, StripePriceId, + StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, + UpdateSubscriptionParams, }; pub struct FakeStripeClient { pub customers: Arc>>, + pub subscriptions: Arc>>, + pub update_subscription_calls: + Arc>>, pub prices: Arc>>, pub meters: Arc>>, } @@ -21,6 +25,8 @@ impl FakeStripeClient { pub fn new() -> Self { Self { customers: Arc::new(Mutex::new(HashMap::default())), + subscriptions: Arc::new(Mutex::new(HashMap::default())), + update_subscription_calls: Arc::new(Mutex::new(Vec::new())), prices: Arc::new(Mutex::new(HashMap::default())), meters: Arc::new(Mutex::new(HashMap::default())), } @@ -52,6 +58,31 @@ impl StripeClient for FakeStripeClient { Ok(customer) } + async fn get_subscription( + &self, + subscription_id: &StripeSubscriptionId, + ) -> Result { + self.subscriptions + .lock() + .get(subscription_id) + .cloned() + .ok_or_else(|| anyhow!("no subscription found for {subscription_id:?}")) + } + + async fn update_subscription( + &self, + subscription_id: &StripeSubscriptionId, + params: UpdateSubscriptionParams, + ) -> Result<()> { + let subscription = self.get_subscription(subscription_id).await?; + + self.update_subscription_calls + .lock() + .push((subscription.id, params)); + + Ok(()) + } + async fn list_prices(&self) -> Result> { let prices = self.prices.lock().values().cloned().collect(); diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs index 9ea07a2979..62f436d617 100644 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -4,11 +4,17 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use async_trait::async_trait; use serde::Serialize; -use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring}; +use stripe::{ + CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription, + SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems, + UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior, + UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, +}; use crate::stripe_client::{ CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice, - StripePriceId, StripePriceRecurring, + StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, + StripeSubscriptionItem, StripeSubscriptionItemId, UpdateSubscriptionParams, }; pub struct RealStripeClient { @@ -53,6 +59,46 @@ impl StripeClient for RealStripeClient { Ok(StripeCustomer::from(customer)) } + async fn get_subscription( + &self, + subscription_id: &StripeSubscriptionId, + ) -> Result { + let subscription_id = subscription_id.try_into()?; + + let subscription = Subscription::retrieve(&self.client, &subscription_id, &[]).await?; + + Ok(StripeSubscription::from(subscription)) + } + + async fn update_subscription( + &self, + subscription_id: &StripeSubscriptionId, + params: UpdateSubscriptionParams, + ) -> Result<()> { + let subscription_id = subscription_id.try_into()?; + + stripe::Subscription::update( + &self.client, + &subscription_id, + stripe::UpdateSubscription { + items: params.items.map(|items| { + items + .into_iter() + .map(|item| UpdateSubscriptionItems { + price: item.price.map(|price| price.to_string()), + ..Default::default() + }) + .collect() + }), + trial_settings: params.trial_settings.map(Into::into), + ..Default::default() + }, + ) + .await?; + + Ok(()) + } + async fn list_prices(&self) -> Result> { let response = stripe::Price::list( &self.client, @@ -108,6 +154,80 @@ impl From for StripeCustomer { } } +impl From for StripeSubscriptionId { + fn from(value: SubscriptionId) -> Self { + Self(value.as_str().into()) + } +} + +impl TryFrom<&StripeSubscriptionId> for SubscriptionId { + type Error = anyhow::Error; + + fn try_from(value: &StripeSubscriptionId) -> Result { + Self::from_str(value.0.as_ref()).context("failed to parse Stripe subscription ID") + } +} + +impl From for StripeSubscription { + fn from(value: Subscription) -> Self { + Self { + id: value.id.into(), + items: value.items.data.into_iter().map(Into::into).collect(), + } + } +} + +impl From for StripeSubscriptionItemId { + fn from(value: SubscriptionItemId) -> Self { + Self(value.as_str().into()) + } +} + +impl From for StripeSubscriptionItem { + fn from(value: SubscriptionItem) -> Self { + Self { + id: value.id.into(), + price: value.price.map(Into::into), + } + } +} + +impl From + for UpdateSubscriptionTrialSettings +{ + fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettings) -> Self { + Self { + end_behavior: value.end_behavior.into(), + } + } +} + +impl From + for UpdateSubscriptionTrialSettingsEndBehavior +{ + fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior) -> Self { + Self { + missing_payment_method: value.missing_payment_method.into(), + } + } +} + +impl From + for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod +{ + fn from( + value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, + ) -> Self { + match value { + crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel, + crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => { + Self::CreateInvoice + } + crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause, + } + } +} + impl From for StripePriceId { fn from(value: PriceId) -> Self { Self(value.as_str().into()) diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index cce84186ae..b12fa722f3 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -5,6 +5,8 @@ use pretty_assertions::assert_eq; use crate::stripe_billing::StripeBilling; use crate::stripe_client::{ FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring, + StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, + UpdateSubscriptionItems, }; fn make_stripe_billing() -> (StripeBilling, Arc) { @@ -141,3 +143,70 @@ async fn test_find_or_create_customer_by_email() { assert_eq!(customer.email.as_deref(), Some(email)); } } + +#[gpui::test] +async fn test_subscribe_to_price() { + let (stripe_billing, stripe_client) = make_stripe_billing(); + + let price = StripePrice { + id: StripePriceId("price_test".into()), + unit_amount: Some(2000), + lookup_key: Some("test-price".to_string()), + recurring: None, + }; + stripe_client + .prices + .lock() + .insert(price.id.clone(), price.clone()); + + let subscription = StripeSubscription { + id: StripeSubscriptionId("sub_test".into()), + items: vec![], + }; + stripe_client + .subscriptions + .lock() + .insert(subscription.id.clone(), subscription.clone()); + + stripe_billing + .subscribe_to_price(&subscription.id, &price) + .await + .unwrap(); + + let update_subscription_calls = stripe_client + .update_subscription_calls + .lock() + .iter() + .map(|(id, params)| (id.clone(), params.clone())) + .collect::>(); + assert_eq!(update_subscription_calls.len(), 1); + assert_eq!(update_subscription_calls[0].0, subscription.id); + assert_eq!( + update_subscription_calls[0].1.items, + Some(vec![UpdateSubscriptionItems { + price: Some(price.id.clone()) + }]) + ); + + // Subscribing to a price that is already on the subscription is a no-op. + { + let subscription = StripeSubscription { + id: StripeSubscriptionId("sub_test".into()), + items: vec![StripeSubscriptionItem { + id: StripeSubscriptionItemId("si_test".into()), + price: Some(price.clone()), + }], + }; + stripe_client + .subscriptions + .lock() + .insert(subscription.id.clone(), subscription.clone()); + + stripe_billing + .subscribe_to_price(&subscription.id, &price) + .await + .unwrap(); + + assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1); + } +}