diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 68f8fa5042..28eaf4de08 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -11,8 +11,9 @@ use crate::Result; use crate::db::billing_subscription::SubscriptionKind; use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_client::{ - RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, - StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, + RealStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode, + StripeCheckoutSessionPaymentMethodCollection, StripeClient, + StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCreateMeterEventPayload, StripeCreateSubscriptionItems, StripeCreateSubscriptionParams, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription, @@ -245,6 +246,7 @@ impl StripeBilling { quantity: Some(1), }]); params.success_url = Some(success_url); + params.billing_address_collection = Some(StripeBillingAddressCollection::Required); let session = self.client.create_checkout_session(params).await?; Ok(session.url.context("no checkout session URL")?) @@ -298,6 +300,7 @@ impl StripeBilling { quantity: Some(1), }]); params.success_url = Some(success_url); + params.billing_address_collection = Some(StripeBillingAddressCollection::Required); let session = self.client.create_checkout_session(params).await?; Ok(session.url.context("no checkout session URL")?) diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs index 3511fb447e..48158e7cd9 100644 --- a/crates/collab/src/stripe_client.rs +++ b/crates/collab/src/stripe_client.rs @@ -148,6 +148,12 @@ pub struct StripeCreateMeterEventPayload<'a> { pub stripe_customer_id: &'a StripeCustomerId, } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum StripeBillingAddressCollection { + Auto, + Required, +} + #[derive(Debug, Default)] pub struct StripeCreateCheckoutSessionParams<'a> { pub customer: Option<&'a StripeCustomerId>, @@ -157,6 +163,7 @@ pub struct StripeCreateCheckoutSessionParams<'a> { pub payment_method_collection: Option, pub subscription_data: Option, pub success_url: Option<&'a str>, + pub billing_address_collection: Option, } #[derive(Debug, PartialEq, Eq, Clone, Copy)] diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs index f679987f8b..96596aa414 100644 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -8,8 +8,8 @@ use parking_lot::Mutex; use uuid::Uuid; use crate::stripe_client::{ - CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode, - StripeCheckoutSessionPaymentMethodCollection, StripeClient, + CreateCustomerParams, StripeBillingAddressCollection, StripeCheckoutSession, + StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripeMeterId, @@ -35,6 +35,7 @@ pub struct StripeCreateCheckoutSessionCall { pub payment_method_collection: Option, pub subscription_data: Option, pub success_url: Option, + pub billing_address_collection: Option, } pub struct FakeStripeClient { @@ -231,6 +232,7 @@ impl StripeClient for FakeStripeClient { payment_method_collection: params.payment_method_collection, subscription_data: params.subscription_data, success_url: params.success_url.map(|url| url.to_string()), + billing_address_collection: params.billing_address_collection, }); Ok(StripeCheckoutSession { diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs index 56ddc8d7ac..917e23cac3 100644 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -17,9 +17,10 @@ use stripe::{ }; use crate::stripe_client::{ - CreateCustomerParams, StripeCancellationDetails, StripeCancellationDetailsReason, - StripeCheckoutSession, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, - StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, + CreateCustomerParams, StripeBillingAddressCollection, StripeCancellationDetails, + StripeCancellationDetailsReason, StripeCheckoutSession, StripeCheckoutSessionMode, + StripeCheckoutSessionPaymentMethodCollection, StripeClient, + StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams, StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, @@ -444,6 +445,7 @@ impl<'a> TryFrom> for CreateCheckoutSessio payment_method_collection: value.payment_method_collection.map(Into::into), subscription_data: value.subscription_data.map(Into::into), success_url: value.success_url, + billing_address_collection: value.billing_address_collection.map(Into::into), ..Default::default() }) } @@ -526,3 +528,16 @@ impl From for StripeCheckoutSession { Self { url: value.url } } } + +impl From for stripe::CheckoutSessionBillingAddressCollection { + fn from(value: StripeBillingAddressCollection) -> Self { + match value { + StripeBillingAddressCollection::Auto => { + stripe::CheckoutSessionBillingAddressCollection::Auto + } + StripeBillingAddressCollection::Required => { + stripe::CheckoutSessionBillingAddressCollection::Required + } + } + } +} diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index 9c0dbad543..941669362d 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -6,11 +6,12 @@ use pretty_assertions::assert_eq; use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_billing::StripeBilling; use crate::stripe_client::{ - FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection, - StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData, - StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring, - StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, - StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior, + FakeStripeClient, StripeBillingAddressCollection, StripeCheckoutSessionMode, + StripeCheckoutSessionPaymentMethodCollection, StripeCreateCheckoutSessionLineItems, + StripeCreateCheckoutSessionSubscriptionData, StripeCustomerId, StripeMeter, StripeMeterId, + StripePrice, StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, + StripeSubscriptionItem, StripeSubscriptionItemId, StripeSubscriptionTrialSettings, + StripeSubscriptionTrialSettingsEndBehavior, StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems, }; @@ -426,6 +427,10 @@ async fn test_checkout_with_zed_pro() { assert_eq!(call.payment_method_collection, None); assert_eq!(call.subscription_data, None); assert_eq!(call.success_url.as_deref(), Some(success_url)); + assert_eq!( + call.billing_address_collection, + Some(StripeBillingAddressCollection::Required) + ); } } @@ -507,6 +512,10 @@ async fn test_checkout_with_zed_pro_trial() { }) ); assert_eq!(call.success_url.as_deref(), Some(success_url)); + assert_eq!( + call.billing_address_collection, + Some(StripeBillingAddressCollection::Required) + ); } // Successful checkout with extended trial. @@ -561,5 +570,9 @@ async fn test_checkout_with_zed_pro_trial() { }) ); assert_eq!(call.success_url.as_deref(), Some(success_url)); + assert_eq!( + call.billing_address_collection, + Some(StripeBillingAddressCollection::Required) + ); } }