collab: Use StripeClient
in sync_subscription
(#31761)
This PR updates the `sync_subscription` function to use the `StripeClient` trait instead of using `stripe::Client` directly. Release Notes: - N/A --------- Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
parent
07436b4284
commit
f725b5e248
10 changed files with 177 additions and 75 deletions
|
@ -17,8 +17,8 @@ use stripe::{
|
|||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
||||
CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
|
||||
Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
CreateBillingPortalSessionFlowDataType, CustomerId, EventObject, EventType, ListEvents,
|
||||
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::{ResultExt, maybe};
|
||||
|
||||
|
@ -29,7 +29,10 @@ use crate::db::billing_subscription::{
|
|||
use crate::llm::db::subscription_usage_meter::CompletionMode;
|
||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId};
|
||||
use crate::stripe_client::{
|
||||
StripeCancellationDetailsReason, StripeClient, StripeCustomerId, StripeSubscription,
|
||||
StripeSubscriptionId,
|
||||
};
|
||||
use crate::{AppState, Error, Result};
|
||||
use crate::{db::UserId, llm::db::LlmDatabase};
|
||||
use crate::{
|
||||
|
@ -426,7 +429,7 @@ async fn manage_billing_subscription(
|
|||
.await?
|
||||
.context("user not found")?;
|
||||
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
let Some(stripe_client) = app.real_stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
|
@ -644,7 +647,7 @@ async fn migrate_to_new_billing(
|
|||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(body): extract::Json<MigrateToNewBillingBody>,
|
||||
) -> Result<Json<MigrateToNewBillingResponse>> {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
let Some(stripe_client) = app.real_stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
|
@ -723,6 +726,13 @@ async fn sync_billing_subscription(
|
|||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(body): extract::Json<SyncBillingSubscriptionBody>,
|
||||
) -> Result<Json<SyncBillingSubscriptionResponse>> {
|
||||
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
|
@ -748,7 +758,7 @@ async fn sync_billing_subscription(
|
|||
.context("failed to parse Stripe customer ID from database")?;
|
||||
|
||||
let subscriptions = Subscription::list(
|
||||
&stripe_client,
|
||||
&real_stripe_client,
|
||||
&stripe::ListSubscriptions {
|
||||
customer: Some(stripe_customer_id),
|
||||
// Sync all non-canceled subscriptions.
|
||||
|
@ -761,7 +771,7 @@ async fn sync_billing_subscription(
|
|||
for subscription in subscriptions.data {
|
||||
let subscription_id = subscription.id.clone();
|
||||
|
||||
sync_subscription(&app, &stripe_client, subscription)
|
||||
sync_subscription(&app, &stripe_client, subscription.into())
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
|
@ -806,6 +816,10 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
|
|||
/// Polls the Stripe events API periodically to reconcile the records in our
|
||||
/// database with the data in Stripe.
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
|
||||
let Some(real_stripe_client) = app.real_stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
};
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
|
@ -816,7 +830,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
|
|||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
poll_stripe_events(&app, &rpc_server, &stripe_client)
|
||||
poll_stripe_events(&app, &rpc_server, &stripe_client, &real_stripe_client)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
|
@ -829,7 +843,8 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Serve
|
|||
async fn poll_stripe_events(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &stripe::Client,
|
||||
stripe_client: &Arc<dyn StripeClient>,
|
||||
real_stripe_client: &stripe::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
fn event_type_to_string(event_type: EventType) -> String {
|
||||
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
|
||||
|
@ -861,7 +876,7 @@ async fn poll_stripe_events(
|
|||
params.types = Some(event_types.clone());
|
||||
params.limit = Some(EVENTS_LIMIT_PER_PAGE);
|
||||
|
||||
let mut event_pages = stripe::Event::list(&stripe_client, ¶ms)
|
||||
let mut event_pages = stripe::Event::list(&real_stripe_client, ¶ms)
|
||||
.await?
|
||||
.paginate(params);
|
||||
|
||||
|
@ -905,7 +920,7 @@ async fn poll_stripe_events(
|
|||
break;
|
||||
} else {
|
||||
log::info!("Stripe events: retrieving next page");
|
||||
event_pages = event_pages.next(&stripe_client).await?;
|
||||
event_pages = event_pages.next(&real_stripe_client).await?;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
|
@ -945,7 +960,7 @@ async fn poll_stripe_events(
|
|||
|
||||
let process_result = match event.type_ {
|
||||
EventType::CustomerCreated | EventType::CustomerUpdated => {
|
||||
handle_customer_event(app, stripe_client, event).await
|
||||
handle_customer_event(app, real_stripe_client, event).await
|
||||
}
|
||||
EventType::CustomerSubscriptionCreated
|
||||
| EventType::CustomerSubscriptionUpdated
|
||||
|
@ -1020,8 +1035,8 @@ async fn handle_customer_event(
|
|||
|
||||
async fn sync_subscription(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &stripe::Client,
|
||||
subscription: stripe::Subscription,
|
||||
stripe_client: &Arc<dyn StripeClient>,
|
||||
subscription: StripeSubscription,
|
||||
) -> anyhow::Result<billing_customer::Model> {
|
||||
let subscription_kind = if let Some(stripe_billing) = &app.stripe_billing {
|
||||
stripe_billing
|
||||
|
@ -1032,7 +1047,7 @@ async fn sync_subscription(
|
|||
};
|
||||
|
||||
let billing_customer =
|
||||
find_or_create_billing_customer(app, stripe_client, subscription.customer)
|
||||
find_or_create_billing_customer(app, stripe_client.as_ref(), &subscription.customer)
|
||||
.await?
|
||||
.context("billing customer not found")?;
|
||||
|
||||
|
@ -1060,7 +1075,7 @@ async fn sync_subscription(
|
|||
.as_ref()
|
||||
.and_then(|details| details.reason)
|
||||
.map_or(false, |reason| {
|
||||
reason == CancellationDetailsReason::PaymentFailed
|
||||
reason == StripeCancellationDetailsReason::PaymentFailed
|
||||
});
|
||||
|
||||
if was_canceled_due_to_payment_failure {
|
||||
|
@ -1077,7 +1092,7 @@ async fn sync_subscription(
|
|||
|
||||
if let Some(existing_subscription) = app
|
||||
.db
|
||||
.get_billing_subscription_by_stripe_subscription_id(&subscription.id)
|
||||
.get_billing_subscription_by_stripe_subscription_id(subscription.id.0.as_ref())
|
||||
.await?
|
||||
{
|
||||
app.db
|
||||
|
@ -1118,20 +1133,13 @@ async fn sync_subscription(
|
|||
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
|
||||
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
|
||||
{
|
||||
let stripe_subscription_id = existing_subscription
|
||||
.stripe_subscription_id
|
||||
.parse::<stripe::SubscriptionId>()
|
||||
.context("failed to parse Stripe subscription ID from database")?;
|
||||
let stripe_subscription_id = StripeSubscriptionId(
|
||||
existing_subscription.stripe_subscription_id.clone().into(),
|
||||
);
|
||||
|
||||
Subscription::cancel(
|
||||
&stripe_client,
|
||||
&stripe_subscription_id,
|
||||
stripe::CancelSubscription {
|
||||
invoice_now: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
stripe_client
|
||||
.cancel_subscription(&stripe_subscription_id)
|
||||
.await?;
|
||||
} else {
|
||||
// If the user already has an active billing subscription, ignore the
|
||||
// event and return an `Ok` to signal that it was processed
|
||||
|
@ -1198,7 +1206,7 @@ async fn sync_subscription(
|
|||
async fn handle_customer_subscription_event(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &stripe::Client,
|
||||
stripe_client: &Arc<dyn StripeClient>,
|
||||
event: stripe::Event,
|
||||
) -> anyhow::Result<()> {
|
||||
let EventObject::Subscription(subscription) = event.data.object else {
|
||||
|
@ -1207,7 +1215,7 @@ async fn handle_customer_subscription_event(
|
|||
|
||||
log::info!("handling Stripe {} event: {}", event.type_, event.id);
|
||||
|
||||
let billing_customer = sync_subscription(app, stripe_client, subscription).await?;
|
||||
let billing_customer = sync_subscription(app, stripe_client, subscription.into()).await?;
|
||||
|
||||
// When the user's subscription changes, push down any changes to their plan.
|
||||
rpc_server
|
||||
|
@ -1403,30 +1411,20 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
|
|||
/// Finds or creates a billing customer using the provided customer.
|
||||
pub async fn find_or_create_billing_customer(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &stripe::Client,
|
||||
customer_or_id: Expandable<Customer>,
|
||||
stripe_client: &dyn StripeClient,
|
||||
customer_id: &StripeCustomerId,
|
||||
) -> anyhow::Result<Option<billing_customer::Model>> {
|
||||
let customer_id = match &customer_or_id {
|
||||
Expandable::Id(id) => id,
|
||||
Expandable::Object(customer) => customer.id.as_ref(),
|
||||
};
|
||||
|
||||
// If we already have a billing customer record associated with the Stripe customer,
|
||||
// there's nothing more we need to do.
|
||||
if let Some(billing_customer) = app
|
||||
.db
|
||||
.get_billing_customer_by_stripe_customer_id(customer_id)
|
||||
.get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
|
||||
.await?
|
||||
{
|
||||
return Ok(Some(billing_customer));
|
||||
}
|
||||
|
||||
// If all we have is a customer ID, resolve it to a full customer record by
|
||||
// hitting the Stripe API.
|
||||
let customer = match customer_or_id {
|
||||
Expandable::Id(id) => Customer::retrieve(stripe_client, &id, &[]).await?,
|
||||
Expandable::Object(customer) => *customer,
|
||||
};
|
||||
let customer = stripe_client.get_customer(customer_id).await?;
|
||||
|
||||
let Some(email) = customer.email else {
|
||||
return Ok(None);
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::db::{BillingCustomerId, BillingSubscriptionId};
|
||||
use crate::stripe_client;
|
||||
use chrono::{Datelike as _, NaiveDate, Utc};
|
||||
use sea_orm::entity::prelude::*;
|
||||
use serde::Serialize;
|
||||
|
@ -159,3 +160,17 @@ pub enum StripeCancellationReason {
|
|||
#[sea_orm(string_value = "payment_failed")]
|
||||
PaymentFailed,
|
||||
}
|
||||
|
||||
impl From<stripe_client::StripeCancellationDetailsReason> for StripeCancellationReason {
|
||||
fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self {
|
||||
match value {
|
||||
stripe_client::StripeCancellationDetailsReason::CancellationRequested => {
|
||||
Self::CancellationRequested
|
||||
}
|
||||
stripe_client::StripeCancellationDetailsReason::PaymentDisputed => {
|
||||
Self::PaymentDisputed
|
||||
}
|
||||
stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ use std::{path::PathBuf, sync::Arc};
|
|||
use util::ResultExt;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
use crate::stripe_client::{RealStripeClient, StripeClient};
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
|
@ -270,7 +271,10 @@ pub struct AppState {
|
|||
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
||||
/// This is a real instance of the Stripe client; we're working to replace references to this with the
|
||||
/// [`StripeClient`] trait.
|
||||
pub real_stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_client: Option<Arc<dyn StripeClient>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub executor: Executor,
|
||||
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
|
||||
|
@ -323,7 +327,9 @@ impl AppState {
|
|||
stripe_billing: stripe_client
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
stripe_client,
|
||||
real_stripe_client: stripe_client.clone(),
|
||||
stripe_client: stripe_client
|
||||
.map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _),
|
||||
executor,
|
||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||
build_kinesis_client(&config).await.log_err()
|
||||
|
|
|
@ -4034,23 +4034,19 @@ async fn get_llm_api_token(
|
|||
.as_ref()
|
||||
.context("failed to retrieve Stripe billing object")?;
|
||||
|
||||
let billing_customer =
|
||||
if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
|
||||
billing_customer
|
||||
} else {
|
||||
let customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?
|
||||
.try_into()?;
|
||||
let billing_customer = if let Some(billing_customer) =
|
||||
db.get_billing_customer_by_user_id(user.id).await?
|
||||
{
|
||||
billing_customer
|
||||
} else {
|
||||
let customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?;
|
||||
|
||||
find_or_create_billing_customer(
|
||||
&session.app_state,
|
||||
&stripe_client,
|
||||
stripe::Expandable::Id(customer_id),
|
||||
)
|
||||
find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
|
||||
.await?
|
||||
.context("billing customer not found")?
|
||||
};
|
||||
};
|
||||
|
||||
let billing_subscription =
|
||||
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
||||
|
|
|
@ -111,14 +111,12 @@ impl StripeBilling {
|
|||
|
||||
pub async fn determine_subscription_kind(
|
||||
&self,
|
||||
subscription: &stripe::Subscription,
|
||||
subscription: &StripeSubscription,
|
||||
) -> Option<SubscriptionKind> {
|
||||
let zed_pro_price_id: stripe::PriceId =
|
||||
self.zed_pro_price_id().await.ok()?.try_into().ok()?;
|
||||
let zed_free_price_id: stripe::PriceId =
|
||||
self.zed_free_price_id().await.ok()?.try_into().ok()?;
|
||||
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
|
||||
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
|
||||
|
||||
subscription.items.data.iter().find_map(|item| {
|
||||
subscription.items.iter().find_map(|item| {
|
||||
let price = item.price.as_ref()?;
|
||||
|
||||
if price.id == zed_pro_price_id {
|
||||
|
|
|
@ -39,6 +39,8 @@ pub struct StripeSubscription {
|
|||
pub current_period_end: i64,
|
||||
pub current_period_start: i64,
|
||||
pub items: Vec<StripeSubscriptionItem>,
|
||||
pub cancel_at: Option<i64>,
|
||||
pub cancellation_details: Option<StripeCancellationDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||
|
@ -50,6 +52,18 @@ pub struct StripeSubscriptionItem {
|
|||
pub price: Option<StripePrice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct StripeCancellationDetails {
|
||||
pub reason: Option<StripeCancellationDetailsReason>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum StripeCancellationDetailsReason {
|
||||
CancellationRequested,
|
||||
PaymentDisputed,
|
||||
PaymentFailed,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StripeCreateSubscriptionParams {
|
||||
pub customer: StripeCustomerId,
|
||||
|
@ -175,6 +189,8 @@ pub struct StripeCheckoutSession {
|
|||
pub trait StripeClient: Send + Sync {
|
||||
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
|
||||
|
||||
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer>;
|
||||
|
||||
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
|
||||
|
||||
async fn list_subscriptions_for_customer(
|
||||
|
@ -198,6 +214,8 @@ pub trait StripeClient: Send + Sync {
|
|||
params: UpdateSubscriptionParams,
|
||||
) -> Result<()>;
|
||||
|
||||
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()>;
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
|
||||
|
||||
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
|
||||
|
|
|
@ -74,6 +74,14 @@ impl StripeClient for FakeStripeClient {
|
|||
.collect())
|
||||
}
|
||||
|
||||
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
|
||||
self.customers
|
||||
.lock()
|
||||
.get(customer_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("no customer found for {customer_id:?}"))
|
||||
}
|
||||
|
||||
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
|
||||
let customer = StripeCustomer {
|
||||
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
|
||||
|
@ -135,6 +143,8 @@ impl StripeClient for FakeStripeClient {
|
|||
.and_then(|price_id| self.prices.lock().get(&price_id).cloned()),
|
||||
})
|
||||
.collect(),
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
|
||||
self.subscriptions
|
||||
|
@ -158,6 +168,13 @@ impl StripeClient for FakeStripeClient {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
|
||||
// TODO: Implement fake subscription cancellation.
|
||||
let _ = subscription_id;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
|
||||
let prices = self.prices.lock().values().cloned().collect();
|
||||
|
||||
|
|
|
@ -5,9 +5,9 @@ use anyhow::{Context as _, Result, anyhow};
|
|||
use async_trait::async_trait;
|
||||
use serde::Serialize;
|
||||
use stripe::{
|
||||
CheckoutSession, CheckoutSessionMode, CheckoutSessionPaymentMethodCollection,
|
||||
CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateCheckoutSessionSubscriptionData,
|
||||
CreateCheckoutSessionSubscriptionDataTrialSettings,
|
||||
CancellationDetails, CancellationDetailsReason, CheckoutSession, CheckoutSessionMode,
|
||||
CheckoutSessionPaymentMethodCollection, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
||||
CreateCheckoutSessionSubscriptionData, CreateCheckoutSessionSubscriptionDataTrialSettings,
|
||||
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior,
|
||||
CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||
CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring, Subscription,
|
||||
|
@ -17,9 +17,9 @@ use stripe::{
|
|||
};
|
||||
|
||||
use crate::stripe_client::{
|
||||
CreateCustomerParams, StripeCheckoutSession, StripeCheckoutSessionMode,
|
||||
StripeCheckoutSessionPaymentMethodCollection, StripeClient,
|
||||
StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||
CreateCustomerParams, StripeCancellationDetails, StripeCancellationDetailsReason,
|
||||
StripeCheckoutSession, StripeCheckoutSessionMode, StripeCheckoutSessionPaymentMethodCollection,
|
||||
StripeClient, StripeCreateCheckoutSessionLineItems, StripeCreateCheckoutSessionParams,
|
||||
StripeCreateCheckoutSessionSubscriptionData, StripeCreateMeterEventParams,
|
||||
StripeCreateSubscriptionParams, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
|
||||
StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
|
||||
|
@ -57,6 +57,14 @@ impl StripeClient for RealStripeClient {
|
|||
.collect())
|
||||
}
|
||||
|
||||
async fn get_customer(&self, customer_id: &StripeCustomerId) -> Result<StripeCustomer> {
|
||||
let customer_id = customer_id.try_into()?;
|
||||
|
||||
let customer = Customer::retrieve(&self.client, &customer_id, &[]).await?;
|
||||
|
||||
Ok(StripeCustomer::from(customer))
|
||||
}
|
||||
|
||||
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
|
||||
let customer = Customer::create(
|
||||
&self.client,
|
||||
|
@ -157,6 +165,22 @@ impl StripeClient for RealStripeClient {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_subscription(&self, subscription_id: &StripeSubscriptionId) -> Result<()> {
|
||||
let subscription_id = subscription_id.try_into()?;
|
||||
|
||||
Subscription::cancel(
|
||||
&self.client,
|
||||
&subscription_id,
|
||||
stripe::CancelSubscription {
|
||||
invoice_now: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
|
||||
let response = stripe::Price::list(
|
||||
&self.client,
|
||||
|
@ -273,6 +297,26 @@ impl From<Subscription> for StripeSubscription {
|
|||
current_period_start: value.current_period_start,
|
||||
current_period_end: value.current_period_end,
|
||||
items: value.items.data.into_iter().map(Into::into).collect(),
|
||||
cancel_at: value.cancel_at,
|
||||
cancellation_details: value.cancellation_details.map(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CancellationDetails> for StripeCancellationDetails {
|
||||
fn from(value: CancellationDetails) -> Self {
|
||||
Self {
|
||||
reason: value.reason.map(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CancellationDetailsReason> for StripeCancellationDetailsReason {
|
||||
fn from(value: CancellationDetailsReason) -> Self {
|
||||
match value {
|
||||
CancellationDetailsReason::CancellationRequested => Self::CancellationRequested,
|
||||
CancellationDetailsReason::PaymentDisputed => Self::PaymentDisputed,
|
||||
CancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -172,6 +172,8 @@ async fn test_subscribe_to_price() {
|
|||
current_period_start: now.timestamp(),
|
||||
current_period_end: (now + Duration::days(30)).timestamp(),
|
||||
items: vec![],
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
stripe_client
|
||||
.subscriptions
|
||||
|
@ -211,6 +213,8 @@ async fn test_subscribe_to_price() {
|
|||
id: StripeSubscriptionItemId("si_test".into()),
|
||||
price: Some(price.clone()),
|
||||
}],
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
stripe_client
|
||||
.subscriptions
|
||||
|
@ -280,6 +284,8 @@ async fn test_subscribe_to_zed_free() {
|
|||
id: StripeSubscriptionItemId("si_test".into()),
|
||||
price: Some(zed_pro_price.clone()),
|
||||
}],
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
stripe_client.subscriptions.lock().insert(
|
||||
existing_subscription.id.clone(),
|
||||
|
@ -309,6 +315,8 @@ async fn test_subscribe_to_zed_free() {
|
|||
id: StripeSubscriptionItemId("si_test".into()),
|
||||
price: Some(zed_pro_price.clone()),
|
||||
}],
|
||||
cancel_at: None,
|
||||
cancellation_details: None,
|
||||
};
|
||||
stripe_client.subscriptions.lock().insert(
|
||||
existing_subscription.id.clone(),
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::stripe_client::FakeStripeClient;
|
||||
use crate::{
|
||||
AppState, Config,
|
||||
db::{NewUserParams, UserId, tests::TestDb},
|
||||
|
@ -522,7 +523,8 @@ impl TestServer {
|
|||
llm_db: None,
|
||||
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
stripe_client: None,
|
||||
real_stripe_client: None,
|
||||
stripe_client: Some(Arc::new(FakeStripeClient::new())),
|
||||
stripe_billing: None,
|
||||
executor,
|
||||
kinesis_client: None,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue