collab: Use StripeClient
when creating Stripe Checkout sessions (#31644)
This PR updates the `StripeBilling::checkout_with_zed_pro` and `StripeBilling::checkout_with_zed_pro_trial` methods to use the `StripeClient` trait instead of using `stripe::Client` directly. Release Notes: - N/A
This commit is contained in:
parent
97579662e6
commit
eb863f8fd6
6 changed files with 471 additions and 59 deletions
|
@ -338,13 +338,11 @@ async fn create_billing_subscription(
|
|||
}
|
||||
|
||||
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
|
||||
CustomerId::from_str(&existing_customer.stripe_customer_id)
|
||||
.context("failed to parse customer ID")?
|
||||
StripeCustomerId(existing_customer.stripe_customer_id.clone().into())
|
||||
} else {
|
||||
stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?
|
||||
.try_into()?
|
||||
};
|
||||
|
||||
let success_url = format!(
|
||||
|
@ -355,7 +353,7 @@ async fn create_billing_subscription(
|
|||
let checkout_session_url = match body.product {
|
||||
ProductCode::ZedPro => {
|
||||
stripe_billing
|
||||
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
|
||||
.checkout_with_zed_pro(&customer_id, &user.github_login, &success_url)
|
||||
.await?
|
||||
}
|
||||
ProductCode::ZedProTrial => {
|
||||
|
@ -372,7 +370,7 @@ async fn create_billing_subscription(
|
|||
|
||||
stripe_billing
|
||||
.checkout_with_zed_pro_trial(
|
||||
customer_id,
|
||||
&customer_id,
|
||||
&user.github_login,
|
||||
feature_flags,
|
||||
&success_url,
|
||||
|
|
|
@ -11,11 +11,14 @@ use crate::Result;
|
|||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||
use crate::stripe_client::{
|
||||
RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload,
|
||||
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
|
||||
StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
|
||||
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
|
||||
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
RealStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
|
||||
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
|
||||
StripeCreateMeterEventPayload, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
|
||||
StripeSubscription, StripeSubscriptionId, StripeSubscriptionTrialSettings,
|
||||
StripeSubscriptionTrialSettingsEndBehavior,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
|
||||
UpdateSubscriptionParams,
|
||||
};
|
||||
|
||||
pub struct StripeBilling {
|
||||
|
@ -190,9 +193,9 @@ impl StripeBilling {
|
|||
items: Some(vec![UpdateSubscriptionItems {
|
||||
price: Some(price.id.clone()),
|
||||
}]),
|
||||
trial_settings: Some(UpdateSubscriptionTrialSettings {
|
||||
end_behavior: UpdateSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
|
||||
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel
|
||||
},
|
||||
}),
|
||||
},
|
||||
|
@ -228,30 +231,29 @@ impl StripeBilling {
|
|||
|
||||
pub async fn checkout_with_zed_pro(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
customer_id: &StripeCustomerId,
|
||||
github_login: &str,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let zed_pro_price_id = self.zed_pro_price_id().await?;
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
let mut params = StripeCreateCheckoutSessionParams::default();
|
||||
params.mode = Some(StripeCheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
||||
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(zed_pro_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
|
||||
let session = self.client.create_checkout_session(params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn checkout_with_zed_pro_trial(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
customer_id: &StripeCustomerId,
|
||||
github_login: &str,
|
||||
feature_flags: Vec<String>,
|
||||
success_url: &str,
|
||||
|
@ -272,34 +274,33 @@ impl StripeBilling {
|
|||
);
|
||||
}
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
|
||||
let mut params = StripeCreateCheckoutSessionParams::default();
|
||||
params.subscription_data = Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||
trial_period_days: Some(trial_period_days),
|
||||
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
|
||||
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
|
||||
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
}
|
||||
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method:
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
},
|
||||
}),
|
||||
metadata: if !subscription_metadata.is_empty() {
|
||||
Some(subscription_metadata)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
..Default::default()
|
||||
});
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.mode = Some(StripeCheckoutSessionMode::Subscription);
|
||||
params.payment_method_collection =
|
||||
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
|
||||
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
||||
params.line_items = Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(zed_pro_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
|
||||
let session = self.client.create_checkout_session(params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
mod fake_stripe_client;
|
||||
mod real_stripe_client;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
|
@ -47,7 +48,7 @@ pub struct StripeSubscriptionItem {
|
|||
#[derive(Debug, Clone)]
|
||||
pub struct UpdateSubscriptionParams {
|
||||
pub items: Option<Vec<UpdateSubscriptionItems>>,
|
||||
pub trial_settings: Option<UpdateSubscriptionTrialSettings>,
|
||||
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
|
@ -55,18 +56,18 @@ pub struct UpdateSubscriptionItems {
|
|||
pub price: Option<StripePriceId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UpdateSubscriptionTrialSettings {
|
||||
pub end_behavior: UpdateSubscriptionTrialSettingsEndBehavior,
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeSubscriptionTrialSettings {
|
||||
pub end_behavior: StripeSubscriptionTrialSettingsEndBehavior,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UpdateSubscriptionTrialSettingsEndBehavior {
|
||||
pub missing_payment_method: UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeSubscriptionTrialSettingsEndBehavior {
|
||||
pub missing_payment_method: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
|
||||
pub enum StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod {
|
||||
Cancel,
|
||||
CreateInvoice,
|
||||
Pause,
|
||||
|
@ -111,6 +112,48 @@ pub struct StripeCreateMeterEventPayload<'a> {
|
|||
pub stripe_customer_id: &'a StripeCustomerId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct StripeCreateCheckoutSessionParams<'a> {
|
||||
pub customer: Option<&'a StripeCustomerId>,
|
||||
pub client_reference_id: Option<&'a str>,
|
||||
pub mode: Option<StripeCheckoutSessionMode>,
|
||||
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
|
||||
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
|
||||
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
|
||||
pub success_url: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCheckoutSessionMode {
|
||||
Payment,
|
||||
Setup,
|
||||
Subscription,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeCreateCheckoutSessionLineItems {
|
||||
pub price: Option<String>,
|
||||
pub quantity: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCheckoutSessionPaymentMethodCollection {
|
||||
Always,
|
||||
IfRequired,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct StripeCreateCheckoutSessionSubscriptionData {
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
pub trial_period_days: Option<u32>,
|
||||
pub trial_settings: Option<StripeSubscriptionTrialSettings>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StripeCheckoutSession {
|
||||
pub url: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait StripeClient: Send + Sync {
|
||||
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
|
||||
|
@ -133,4 +176,9 @@ pub trait StripeClient: Send + Sync {
|
|||
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
|
||||
|
||||
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
|
||||
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
params: StripeCreateCheckoutSessionParams<'_>,
|
||||
) -> Result<StripeCheckoutSession>;
|
||||
}
|
||||
|
|
|
@ -7,7 +7,10 @@ use parking_lot::Mutex;
|
|||
use uuid::Uuid;
|
||||
|
||||
use crate::stripe_client::{
|
||||
CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
|
||||
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
|
||||
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
|
||||
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
|
||||
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
|
||||
StripeSubscriptionId, UpdateSubscriptionParams,
|
||||
};
|
||||
|
@ -21,6 +24,17 @@ pub struct StripeCreateMeterEventCall {
|
|||
pub timestamp: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StripeCreateCheckoutSessionCall {
|
||||
pub customer: Option<StripeCustomerId>,
|
||||
pub client_reference_id: Option<String>,
|
||||
pub mode: Option<StripeCheckoutSessionMode>,
|
||||
pub line_items: Option<Vec<StripeCreateCheckoutSessionLineItems>>,
|
||||
pub payment_method_collection: Option<StripeCheckoutSessionPaymentMethodCollection>,
|
||||
pub subscription_data: Option<StripeCreateCheckoutSessionSubscriptionData>,
|
||||
pub success_url: Option<String>,
|
||||
}
|
||||
|
||||
pub struct FakeStripeClient {
|
||||
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
|
||||
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
|
||||
|
@ -29,6 +43,7 @@ pub struct FakeStripeClient {
|
|||
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
|
||||
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
|
||||
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
|
||||
pub create_checkout_session_calls: Arc<Mutex<Vec<StripeCreateCheckoutSessionCall>>>,
|
||||
}
|
||||
|
||||
impl FakeStripeClient {
|
||||
|
@ -40,6 +55,7 @@ impl FakeStripeClient {
|
|||
prices: Arc::new(Mutex::new(HashMap::default())),
|
||||
meters: Arc::new(Mutex::new(HashMap::default())),
|
||||
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
|
||||
create_checkout_session_calls: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -119,4 +135,25 @@ impl StripeClient for FakeStripeClient {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
params: StripeCreateCheckoutSessionParams<'_>,
|
||||
) -> Result<StripeCheckoutSession> {
|
||||
self.create_checkout_session_calls
|
||||
.lock()
|
||||
.push(StripeCreateCheckoutSessionCall {
|
||||
customer: params.customer.cloned(),
|
||||
client_reference_id: params.client_reference_id.map(|id| id.to_string()),
|
||||
mode: params.mode,
|
||||
line_items: params.line_items,
|
||||
payment_method_collection: params.payment_method_collection,
|
||||
subscription_data: params.subscription_data,
|
||||
success_url: params.success_url.map(|url| url.to_string()),
|
||||
});
|
||||
|
||||
Ok(StripeCheckoutSession {
|
||||
url: Some("https://checkout.stripe.com/c/pay/cs_test_1".to_string()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,11 @@ use anyhow::{Context as _, Result, anyhow};
|
|||
use async_trait::async_trait;
|
||||
use serde::Serialize;
|
||||
use stripe::{
|
||||
CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection,
|
||||
CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData,
|
||||
CreateCheckoutSessionSubscriptionDataTrialSettings,
|
||||
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
|
||||
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
|
||||
SubscriptionId, SubscriptionItem, SubscriptionItemId, UpdateSubscriptionItems,
|
||||
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
|
||||
|
@ -12,10 +17,14 @@ use stripe::{
|
|||
};
|
||||
|
||||
use crate::stripe_client::{
|
||||
CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
|
||||
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
|
||||
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
|
||||
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams, StripeCustomer,
|
||||
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
|
||||
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
|
||||
UpdateSubscriptionParams,
|
||||
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionParams,
|
||||
};
|
||||
|
||||
pub struct RealStripeClient {
|
||||
|
@ -150,6 +159,16 @@ impl StripeClient for RealStripeClient {
|
|||
Err(error) => Err(anyhow!(error)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
params: StripeCreateCheckoutSessionParams<'_>,
|
||||
) -> Result<StripeCheckoutSession> {
|
||||
let params = params.try_into()?;
|
||||
let session = CheckoutSession::create(&self.client, params).await?;
|
||||
|
||||
Ok(session.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CustomerId> for StripeCustomerId {
|
||||
|
@ -166,6 +185,14 @@ impl TryFrom<StripeCustomerId> for CustomerId {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&StripeCustomerId> for CustomerId {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: &StripeCustomerId) -> Result<Self, Self::Error> {
|
||||
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Customer> for StripeCustomer {
|
||||
fn from(value: Customer) -> Self {
|
||||
StripeCustomer {
|
||||
|
@ -213,38 +240,34 @@ impl From<SubscriptionItem> for StripeSubscriptionItem {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<crate::stripe_client::UpdateSubscriptionTrialSettings>
|
||||
for UpdateSubscriptionTrialSettings
|
||||
{
|
||||
fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettings) -> Self {
|
||||
impl From<StripeSubscriptionTrialSettings> for UpdateSubscriptionTrialSettings {
|
||||
fn from(value: StripeSubscriptionTrialSettings) -> Self {
|
||||
Self {
|
||||
end_behavior: value.end_behavior.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior>
|
||||
impl From<StripeSubscriptionTrialSettingsEndBehavior>
|
||||
for UpdateSubscriptionTrialSettingsEndBehavior
|
||||
{
|
||||
fn from(value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehavior) -> Self {
|
||||
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
|
||||
Self {
|
||||
missing_payment_method: value.missing_payment_method.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
|
||||
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
|
||||
for UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod
|
||||
{
|
||||
fn from(
|
||||
value: crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
) -> Self {
|
||||
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
|
||||
match value {
|
||||
crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
|
||||
crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
|
||||
Self::CreateInvoice
|
||||
}
|
||||
crate::stripe_client::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -279,3 +302,103 @@ impl From<Recurring> for StripePriceRecurring {
|
|||
Self { meter: value.meter }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<StripeCreateCheckoutSessionParams<'a>> for CreateCheckoutSession<'a> {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: StripeCreateCheckoutSessionParams<'a>) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
customer: value
|
||||
.customer
|
||||
.map(|customer_id| customer_id.try_into())
|
||||
.transpose()?,
|
||||
client_reference_id: value.client_reference_id,
|
||||
mode: value.mode.map(Into::into),
|
||||
line_items: value
|
||||
.line_items
|
||||
.map(|line_items| line_items.into_iter().map(Into::into).collect()),
|
||||
payment_method_collection: value.payment_method_collection.map(Into::into),
|
||||
subscription_data: value.subscription_data.map(Into::into),
|
||||
success_url: value.success_url,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCheckoutSessionMode> for CheckoutSessionMode {
|
||||
fn from(value: StripeCheckoutSessionMode) -> Self {
|
||||
match value {
|
||||
StripeCheckoutSessionMode::Payment => Self::Payment,
|
||||
StripeCheckoutSessionMode::Setup => Self::Setup,
|
||||
StripeCheckoutSessionMode::Subscription => Self::Subscription,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCreateCheckoutSessionLineItems> for CreateCheckoutSessionLineItems {
|
||||
fn from(value: StripeCreateCheckoutSessionLineItems) -> Self {
|
||||
Self {
|
||||
price: value.price,
|
||||
quantity: value.quantity,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCheckoutSessionPaymentMethodCollection> for CheckoutSessionPaymentMethodCollection {
|
||||
fn from(value: StripeCheckoutSessionPaymentMethodCollection) -> Self {
|
||||
match value {
|
||||
StripeCheckoutSessionPaymentMethodCollection::Always => Self::Always,
|
||||
StripeCheckoutSessionPaymentMethodCollection::IfRequired => Self::IfRequired,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeCreateCheckoutSessionSubscriptionData> for CreateCheckoutSessionSubscriptionData {
|
||||
fn from(value: StripeCreateCheckoutSessionSubscriptionData) -> Self {
|
||||
Self {
|
||||
trial_period_days: value.trial_period_days,
|
||||
trial_settings: value.trial_settings.map(Into::into),
|
||||
metadata: value.metadata,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettings> for CreateCheckoutSessionSubscriptionDataTrialSettings {
|
||||
fn from(value: StripeSubscriptionTrialSettings) -> Self {
|
||||
Self {
|
||||
end_behavior: value.end_behavior.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettingsEndBehavior>
|
||||
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior
|
||||
{
|
||||
fn from(value: StripeSubscriptionTrialSettingsEndBehavior) -> Self {
|
||||
Self {
|
||||
missing_payment_method: value.missing_payment_method.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod>
|
||||
for CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod
|
||||
{
|
||||
fn from(value: StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod) -> Self {
|
||||
match value {
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel => Self::Cancel,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::CreateInvoice => {
|
||||
Self::CreateInvoice
|
||||
}
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Pause => Self::Pause,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CheckoutSession> for StripeCheckoutSession {
|
||||
fn from(value: CheckoutSession) -> Self {
|
||||
Self { url: value.url }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,11 +2,15 @@ use std::sync::Arc;
|
|||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
use crate::stripe_client::{
|
||||
FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
|
||||
StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
|
||||
StripeSubscriptionItemId, UpdateSubscriptionItems,
|
||||
FakeStripeClient, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
|
||||
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionSubscriptionData,
|
||||
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
|
||||
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
|
||||
StripeSubscriptionTrialSettings, StripeSubscriptionTrialSettingsEndBehavior,
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, UpdateSubscriptionItems,
|
||||
};
|
||||
|
||||
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||
|
@ -241,3 +245,204 @@ async fn test_bill_model_request_usage() {
|
|||
);
|
||||
assert_eq!(create_meter_event_calls[0].value, 73);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_checkout_with_zed_pro() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
let customer_id = StripeCustomerId("cus_test".into());
|
||||
let github_login = "zeduser1";
|
||||
let success_url = "https://example.com/success";
|
||||
|
||||
// It returns an error when the Zed Pro price doesn't exist.
|
||||
{
|
||||
let result = stripe_billing
|
||||
.checkout_with_zed_pro(&customer_id, github_login, success_url)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(
|
||||
result.err().unwrap().to_string(),
|
||||
r#"no price ID found for "zed-pro""#
|
||||
);
|
||||
}
|
||||
|
||||
// Successful checkout.
|
||||
{
|
||||
let price = StripePrice {
|
||||
id: StripePriceId("price_1".into()),
|
||||
unit_amount: Some(2000),
|
||||
lookup_key: Some("zed-pro".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price.id.clone(), price.clone());
|
||||
|
||||
stripe_billing.initialize().await.unwrap();
|
||||
|
||||
let checkout_url = stripe_billing
|
||||
.checkout_with_zed_pro(&customer_id, github_login, success_url)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||
|
||||
let create_checkout_session_calls = stripe_client
|
||||
.create_checkout_session_calls
|
||||
.lock()
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||
assert_eq!(call.customer, Some(customer_id));
|
||||
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||
assert_eq!(
|
||||
call.line_items,
|
||||
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(price.id.to_string()),
|
||||
quantity: Some(1)
|
||||
}])
|
||||
);
|
||||
assert_eq!(call.payment_method_collection, None);
|
||||
assert_eq!(call.subscription_data, None);
|
||||
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_checkout_with_zed_pro_trial() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
let customer_id = StripeCustomerId("cus_test".into());
|
||||
let github_login = "zeduser1";
|
||||
let success_url = "https://example.com/success";
|
||||
|
||||
// It returns an error when the Zed Pro price doesn't exist.
|
||||
{
|
||||
let result = stripe_billing
|
||||
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(
|
||||
result.err().unwrap().to_string(),
|
||||
r#"no price ID found for "zed-pro""#
|
||||
);
|
||||
}
|
||||
|
||||
let price = StripePrice {
|
||||
id: StripePriceId("price_1".into()),
|
||||
unit_amount: Some(2000),
|
||||
lookup_key: Some("zed-pro".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price.id.clone(), price.clone());
|
||||
|
||||
stripe_billing.initialize().await.unwrap();
|
||||
|
||||
// Successful checkout.
|
||||
{
|
||||
let checkout_url = stripe_billing
|
||||
.checkout_with_zed_pro_trial(&customer_id, github_login, Vec::new(), success_url)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||
|
||||
let create_checkout_session_calls = stripe_client
|
||||
.create_checkout_session_calls
|
||||
.lock()
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||
assert_eq!(call.customer.as_ref(), Some(&customer_id));
|
||||
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||
assert_eq!(
|
||||
call.line_items,
|
||||
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(price.id.to_string()),
|
||||
quantity: Some(1)
|
||||
}])
|
||||
);
|
||||
assert_eq!(
|
||||
call.payment_method_collection,
|
||||
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
|
||||
);
|
||||
assert_eq!(
|
||||
call.subscription_data,
|
||||
Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||
trial_period_days: Some(14),
|
||||
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method:
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
},
|
||||
}),
|
||||
metadata: None,
|
||||
})
|
||||
);
|
||||
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||
}
|
||||
|
||||
// Successful checkout with extended trial.
|
||||
{
|
||||
let checkout_url = stripe_billing
|
||||
.checkout_with_zed_pro_trial(
|
||||
&customer_id,
|
||||
github_login,
|
||||
vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()],
|
||||
success_url,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(checkout_url.starts_with("https://checkout.stripe.com/c/pay"));
|
||||
|
||||
let create_checkout_session_calls = stripe_client
|
||||
.create_checkout_session_calls
|
||||
.lock()
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(create_checkout_session_calls.len(), 1);
|
||||
let call = create_checkout_session_calls.into_iter().next().unwrap();
|
||||
assert_eq!(call.customer, Some(customer_id));
|
||||
assert_eq!(call.client_reference_id.as_deref(), Some(github_login));
|
||||
assert_eq!(call.mode, Some(StripeCheckoutSessionMode::Subscription));
|
||||
assert_eq!(
|
||||
call.line_items,
|
||||
Some(vec![StripeCreateCheckoutSessionLineItems {
|
||||
price: Some(price.id.to_string()),
|
||||
quantity: Some(1)
|
||||
}])
|
||||
);
|
||||
assert_eq!(
|
||||
call.payment_method_collection,
|
||||
Some(StripeCheckoutSessionPaymentMethodCollection::IfRequired)
|
||||
);
|
||||
assert_eq!(
|
||||
call.subscription_data,
|
||||
Some(StripeCreateCheckoutSessionSubscriptionData {
|
||||
trial_period_days: Some(60),
|
||||
trial_settings: Some(StripeSubscriptionTrialSettings {
|
||||
end_behavior: StripeSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method:
|
||||
StripeSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
},
|
||||
}),
|
||||
metadata: Some(std::collections::HashMap::from_iter([(
|
||||
"promo_feature_flag".into(),
|
||||
AGENT_EXTENDED_TRIAL_FEATURE_FLAG.into()
|
||||
)])),
|
||||
})
|
||||
);
|
||||
assert_eq!(call.success_url.as_deref(), Some(success_url));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue