collab: Use StripeClient when creating Stripe Checkout sessions (#31644)

This PR updates the `StripeBilling::checkout_with_zed_pro` and
`StripeBilling::checkout_with_zed_pro_trial` methods to use the
`StripeClient` trait instead of using `stripe::Client` directly.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-28 20:57:04 -04:00 committed by GitHub
parent 97579662e6
commit eb863f8fd6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 471 additions and 59 deletions

View file

@ -338,13 +338,11 @@ async fn create_billing_subscription(
}
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
} else {
stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?
.try_into()?
};
let success_url = format!(
@ -355,7 +353,7 @@ async fn create_billing_subscription(
let checkout_session_url = match body.product {
ProductCode::ZedPro => {
stripe_billing
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
.checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
.await?
}
ProductCode::ZedProTrial => {
@ -372,7 +370,7 @@ async fn create_billing_subscription(
stripe_billing
.checkout_with_zed_pro_trial(
customer_id,
&customer_id,
&user.github_login,
feature_flags,
&success_url,

View file

@ -11,11 +11,14 @@ use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{
RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings,
StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
UpdateSubscriptionParams,
};
pub struct StripeBilling {
@ -190,9 +193,9 @@ impl StripeBilling {
items: Some(vec![UpdateSubscriptionItems {
price: Some(price.id.clone()),
}]),
trial_settings: Some(UpdateSubscriptionTrialSettings {
end_behavior: UpdateSubscriptionTrialSettingsEndBehavior {
missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
},
}),
},
@ -228,30 +231,29 @@ impl StripeBilling {
pub async fn checkout_with_zed_pro(
&self,
customer_id: stripe::CustomerId,
customer_id: &StripeCustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self.zed_pro_price_id().await?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
let mut params = StripeCreateCheckoutSessionParams::default();
params.mode = Some(StripeCheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
let session = self.client.create_checkout_session(params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_zed_pro_trial(
&self,
customer_id: stripe::CustomerId,
customer_id: &StripeCustomerId,
github_login: &str,
feature_flags: Vec<String>,
success_url: &str,
@ -272,34 +274,33 @@ impl StripeBilling {
);
}
let mut params = stripe::CreateCheckoutSession::new();
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
let mut params = StripeCreateCheckoutSessionParams::default();
params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(trial_period_days),
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
}
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
metadata: if !subscription_metadata.is_empty() {
Some(subscription_metadata)
} else {
None
},
..Default::default()
});
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.mode = Some(StripeCheckoutSessionMode::Subscription);
params.payment_method_collection =
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
let session = self.client.create_checkout_session(params).await?;
Ok(session.url.context("no checkout session URL")?)
}

View file

@ -2,6 +2,7 @@
mod fake_stripe_client;
mod real_stripe_client;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Result;
@ -47,7 +48,7 @@ pub struct StripeSubscriptionItem {
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionParams {
pub items: Option<Vec<UpdateSubscriptionItems>>,
pub trial_settings: Option<UpdateSubscriptionTrialSettings>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug, PartialEq, Clone)]
@ -55,18 +56,18 @@ pub struct UpdateSubscriptionItems {
pub price: Option<StripePriceId>,
}
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionTrialSettings {
pub end_behavior: UpdateSubscriptionTrialSettingsEndBehavior,
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettings {
pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
}
#[derive(Debug, Clone)]
pub struct UpdateSubscriptionTrialSettingsEndBehavior {
pub missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
#[derive(Debug, PartialEq, Clone)]
pub struct StripeSubscriptionTrialSettingsEndBehavior {
pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
Cancel,
CreateInvoice,
Pause,
@ -111,6 +112,48 @@ pub struct StripeCreateMeterEventPayload<'a> {
pub stripe_customer_id: &'a StripeCustomerId,
}
#[derive(Debug, Default)]
pub struct StripeCreateCheckoutSessionParams<'a> {
pub customer: Option<&'a StripeCustomerId>,
pub client_reference_id: Option<&'a str>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<&'a str>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionMode {
Payment,
Setup,
Subscription,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionLineItems {
pub price: Option<String>,
pub quantity: Option<u64>,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StripeCheckoutSessionPaymentMethodCollection {
Always,
IfRequired,
}
#[derive(Debug, PartialEq, Clone)]
pub struct StripeCreateCheckoutSessionSubscriptionData {
pub metadata: Option<HashMap<String, String>>,
pub trial_period_days: Option<u32>,
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
}
#[derive(Debug)]
pub struct StripeCheckoutSession {
pub url: Option<String>,
}
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
@ -133,4 +176,9 @@ pub trait StripeClient: Send + Sync {
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession>;
}

View file

@ -7,7 +7,10 @@ use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{
CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
StripeSubscriptionId, UpdateSubscriptionParams,
};
@ -21,6 +24,17 @@ pub struct StripeCreateMeterEventCall {
pub timestamp: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct StripeCreateCheckoutSessionCall {
pub customer: Option<StripeCustomerId>,
pub client_reference_id: Option<String>,
pub mode: Option<StripeCheckoutSessionMode>,
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
pub success_url: Option<String>,
}
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
@ -29,6 +43,7 @@ pub struct FakeStripeClient {
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
}
impl FakeStripeClient {
@ -40,6 +55,7 @@ impl FakeStripeClient {
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
}
}
}
@ -119,4 +135,25 @@ impl StripeClient for FakeStripeClient {
Ok(())
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
self.create_checkout_session_calls
.lock()
.push(StripeCreateCheckoutSessionCall {
customer: params.customer.cloned(),
client_reference_id: params.client_reference_id.map(|id| id.to_string()),
mode: params.mode,
line_items: params.line_items,
payment_method_collection: params.payment_method_collection,
subscription_data: params.subscription_data,
success_url: params.success_url.map(|url| url.to_string()),
});
Ok(StripeCheckoutSession {
url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
})
}
}

View file

@ -5,6 +5,11 @@ use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use serde::Serialize;
use stripe::{
CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection,
CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData,
CreateCheckoutSessionSubscriptionDataTrialSettings,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
@ -12,10 +17,14 @@ use stripe::{
};
use crate::stripe_client::{
CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
UpdateSubscriptionParams,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
};
pub struct RealStripeClient {
@ -150,6 +159,16 @@ impl StripeClient for RealStripeClient {
Err(error) => Err(anyhow!(error)),
}
}
async fn create_checkout_session(
&self,
params: StripeCreateCheckoutSessionParams<'_>,
) -> Result<StripeCheckoutSession> {
let params = params.try_into()?;
let session = CheckoutSession::create(&self.client, params).await?;
Ok(session.into())
}
}
impl From<CustomerId> for StripeCustomerId {
@ -166,6 +185,14 @@ impl TryFrom<StripeCustomerId> for CustomerId {
}
}
impl TryFrom<&StripeCustomerId> for CustomerId {
type Error = anyhow::Error;
fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
}
}
impl From<Customer> for StripeCustomer {
fn from(value: Customer) -> Self {
StripeCustomer {
@ -213,38 +240,34 @@ impl From<SubscriptionItem> for StripeSubscriptionItem {
}
}
impl From<crate::stripe_client::UpdateSubscriptionTrialSettings>
for UpdateSubscriptionTrialSettings
{
fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettings) -> Self {
impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior>
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for UpdateSubscriptionTrialSettingsEndBehavior
{
fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior) -> Self {
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(
value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
) -> Self {
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
@ -279,3 +302,103 @@ impl From<Recurring> for StripePriceRecurring {
Self { meter: value.meter }
}
}
impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
type Error = anyhow::Error;
fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
Ok(Self {
customer: value
.customer
.map(|customer_id| customer_id.try_into())
.transpose()?,
client_reference_id: value.client_reference_id,
mode: value.mode.map(Into::into),
line_items: value
.line_items
.map(|line_items| line_items.into_iter().map(Into::into).collect()),
payment_method_collection: value.payment_method_collection.map(Into::into),
subscription_data: value.subscription_data.map(Into::into),
success_url: value.success_url,
..Default::default()
})
}
}
impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
fn from(value: StripeCheckoutSessionMode) -> Self {
match value {
StripeCheckoutSessionMode::Payment => Self::Payment,
StripeCheckoutSessionMode::Setup => Self::Setup,
StripeCheckoutSessionMode::Subscription => Self::Subscription,
}
}
}
impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
Self {
price: value.price,
quantity: value.quantity,
..Default::default()
}
}
}
impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
match value {
StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
}
}
}
impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
Self {
trial_period_days: value.trial_period_days,
trial_settings: value.trial_settings.map(Into::into),
metadata: value.metadata,
..Default::default()
}
}
}
impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
fn from(value: StripeSubscriptionTrialSettings) -> Self {
Self {
end_behavior: value.end_behavior.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehavior>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
{
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
Self {
missing_payment_method: value.missing_payment_method.into(),
}
}
}
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
{
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
match value {
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
Self::CreateInvoice
}
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
}
}
}
impl From<CheckoutSession> for StripeCheckoutSession {
fn from(value: CheckoutSession) -> Self {
Self { url: value.url }
}
}

View file

@ -2,11 +2,15 @@ use std::sync::Arc;
use pretty_assertions::assert_eq;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
StripeSubscriptionItemId, UpdateSubscriptionItems,
FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
@ -241,3 +245,204 @@ async fn test_bill_model_request_usage() {
);
assert_eq!(create_meter_event_calls[0].value, 73);
}
#[gpui::test]
async fn test_checkout_with_zed_pro() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
let github_login = "zeduser1";
let success_url = "https://example.com/success";
// It returns an error when the Zed Pro price doesn't exist.
{
let result = stripe_billing
.checkout_with_zed_pro(&customer_id, github_login, success_url)
.await;
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
r#"no price ID found for "zed-pro""#
);
}
// Successful checkout.
{
let price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(2000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
stripe_billing.initialize().await.unwrap();
let checkout_url = stripe_billing
.checkout_with_zed_pro(&customer_id, github_login, success_url)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer, Some(customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(call.payment_method_collection, None);
assert_eq!(call.subscription_data, None);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
}
#[gpui::test]
async fn test_checkout_with_zed_pro_trial() {
let (stripe_billing, stripe_client) = make_stripe_billing();
let customer_id = StripeCustomerId("cus_test".into());
let github_login = "zeduser1";
let success_url = "https://example.com/success";
// It returns an error when the Zed Pro price doesn't exist.
{
let result = stripe_billing
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
.await;
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
r#"no price ID found for "zed-pro""#
);
}
let price = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(2000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
stripe_client
.prices
.lock()
.insert(price.id.clone(), price.clone());
stripe_billing.initialize().await.unwrap();
// Successful checkout.
{
let checkout_url = stripe_billing
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer.as_ref(), Some(&customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(
call.payment_method_collection,
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
);
assert_eq!(
call.subscription_data,
Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(14),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
metadata: None,
})
);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
// Successful checkout with extended trial.
{
let checkout_url = stripe_billing
.checkout_with_zed_pro_trial(
&customer_id,
github_login,
vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
success_url,
)
.await
.unwrap();
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
let create_checkout_session_calls = stripe_client
.create_checkout_session_calls
.lock()
.drain(..)
.collect::<Vec<_>>();
assert_eq!(create_checkout_session_calls.len(), 1);
let call = create_checkout_session_calls.into_iter().next().unwrap();
assert_eq!(call.customer, Some(customer_id));
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
assert_eq!(
call.line_items,
Some(vec![StripeCreateCheckoutSessionLineItems {
price: Some(price.id.to_string()),
quantity: Some(1)
}])
);
assert_eq!(
call.payment_method_collection,
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
);
assert_eq!(
call.subscription_data,
Some(StripeCreateCheckoutSessionSubscriptionData {
trial_period_days: Some(60),
trial_settings: Some(StripeSubscriptionTrialSettings {
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
missing_payment_method:
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
metadata: Some(std::collections::HashMap::from_iter([(
"promo_feature_flag".into(),
AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
)])),
})
);
assert_eq!(call.success_url.as_deref(), Some(success_url));
}
}