collab: Introduce StripeClient
trait to abstract over Stripe interactions (#31615)
This PR introduces a new `StripeClient` trait to abstract over interacting with the Stripe API. This will allow us to more easily test our billing code. This initial cut is small and focuses just on making `StripeBilling::find_or_create_customer_by_email` testable. I'll follow up with using the `StripeClient` in more places. Release Notes: - N/A
This commit is contained in:
parent
68724ea99e
commit
361ceee72b
10 changed files with 257 additions and 33 deletions
|
@ -1,19 +1,22 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::Result;
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
|
||||
use stripe::{PriceId, SubscriptionStatus};
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::Result;
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||
use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
|
||||
|
||||
pub struct StripeBilling {
|
||||
state: RwLock<StripeBillingState>,
|
||||
client: Arc<stripe::Client>,
|
||||
real_client: Arc<stripe::Client>,
|
||||
client: Arc<dyn StripeClient>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
@ -26,6 +29,17 @@ struct StripeBillingState {
|
|||
impl StripeBilling {
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self {
|
||||
client: Arc::new(RealStripeClient::new(client.clone())),
|
||||
real_client: client,
|
||||
state: RwLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
|
||||
Self {
|
||||
// This is just temporary until we can remove all usages of the real Stripe client.
|
||||
real_client: Arc::new(stripe::Client::new("sk_test")),
|
||||
client,
|
||||
state: RwLock::default(),
|
||||
}
|
||||
|
@ -37,9 +51,9 @@ impl StripeBilling {
|
|||
let mut state = self.state.write().await;
|
||||
|
||||
let (meters, prices) = futures::try_join!(
|
||||
StripeMeter::list(&self.client),
|
||||
StripeMeter::list(&self.real_client),
|
||||
stripe::Price::list(
|
||||
&self.client,
|
||||
&self.real_client,
|
||||
&stripe::ListPrices {
|
||||
limit: Some(100),
|
||||
..Default::default()
|
||||
|
@ -129,18 +143,11 @@ impl StripeBilling {
|
|||
pub async fn find_or_create_customer_by_email(
|
||||
&self,
|
||||
email_address: Option<&str>,
|
||||
) -> Result<CustomerId> {
|
||||
) -> Result<StripeCustomerId> {
|
||||
let existing_customer = if let Some(email) = email_address {
|
||||
let customers = Customer::list(
|
||||
&self.client,
|
||||
&stripe::ListCustomers {
|
||||
email: Some(email),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let customers = self.client.list_customers_by_email(email).await?;
|
||||
|
||||
customers.data.first().cloned()
|
||||
customers.first().cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
@ -148,14 +155,12 @@ impl StripeBilling {
|
|||
let customer_id = if let Some(existing_customer) = existing_customer {
|
||||
existing_customer.id
|
||||
} else {
|
||||
let customer = Customer::create(
|
||||
&self.client,
|
||||
CreateCustomer {
|
||||
let customer = self
|
||||
.client
|
||||
.create_customer(crate::stripe_client::CreateCustomerParams {
|
||||
email: email_address,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
})
|
||||
.await?;
|
||||
|
||||
customer.id
|
||||
};
|
||||
|
@ -169,7 +174,7 @@ impl StripeBilling {
|
|||
price: &stripe::Price,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
|
||||
|
||||
if subscription_contains_price(&subscription, &price.id) {
|
||||
return Ok(());
|
||||
|
@ -181,7 +186,7 @@ impl StripeBilling {
|
|||
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
|
||||
|
||||
stripe::Subscription::update(
|
||||
&self.client,
|
||||
&self.real_client,
|
||||
subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(vec![stripe::UpdateSubscriptionItems {
|
||||
|
@ -211,7 +216,7 @@ impl StripeBilling {
|
|||
let idempotency_key = Uuid::new_v4();
|
||||
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
&self.real_client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("model_requests/{}", idempotency_key),
|
||||
event_name,
|
||||
|
@ -246,7 +251,7 @@ impl StripeBilling {
|
|||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
|
@ -300,7 +305,7 @@ impl StripeBilling {
|
|||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
|
@ -311,7 +316,7 @@ impl StripeBilling {
|
|||
let zed_free_price_id = self.zed_free_price_id().await?;
|
||||
|
||||
let existing_subscriptions = stripe::Subscription::list(
|
||||
&self.client,
|
||||
&self.real_client,
|
||||
&stripe::ListSubscriptions {
|
||||
customer: Some(customer_id.clone()),
|
||||
status: None,
|
||||
|
@ -339,7 +344,7 @@ impl StripeBilling {
|
|||
..Default::default()
|
||||
}]);
|
||||
|
||||
let subscription = stripe::Subscription::create(&self.client, params).await?;
|
||||
let subscription = stripe::Subscription::create(&self.real_client, params).await?;
|
||||
|
||||
Ok(subscription)
|
||||
}
|
||||
|
@ -365,7 +370,7 @@ impl StripeBilling {
|
|||
}]);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue