collab: Introduce StripeClient
trait to abstract over Stripe interactions (#31615)
This PR introduces a new `StripeClient` trait to abstract over interacting with the Stripe API. This will allow us to more easily test our billing code. This initial cut is small and focuses just on making `StripeBilling::find_or_create_customer_by_email` testable. I'll follow up with using the `StripeClient` in more places. Release Notes: - N/A
This commit is contained in:
parent
68724ea99e
commit
361ceee72b
10 changed files with 257 additions and 33 deletions
|
@ -20,6 +20,7 @@ test-support = ["sqlite"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
async-stripe.workspace = true
|
async-stripe.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
async-tungstenite.workspace = true
|
async-tungstenite.workspace = true
|
||||||
aws-config = { version = "1.1.5" }
|
aws-config = { version = "1.1.5" }
|
||||||
aws-sdk-s3 = { version = "1.15.0" }
|
aws-sdk-s3 = { version = "1.15.0" }
|
||||||
|
|
|
@ -344,6 +344,7 @@ async fn create_billing_subscription(
|
||||||
stripe_billing
|
stripe_billing
|
||||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||||
.await?
|
.await?
|
||||||
|
.try_into()?
|
||||||
};
|
};
|
||||||
|
|
||||||
let success_url = format!(
|
let success_url = format!(
|
||||||
|
|
|
@ -9,6 +9,7 @@ pub mod migrations;
|
||||||
pub mod rpc;
|
pub mod rpc;
|
||||||
pub mod seed;
|
pub mod seed;
|
||||||
pub mod stripe_billing;
|
pub mod stripe_billing;
|
||||||
|
pub mod stripe_client;
|
||||||
pub mod user_backfiller;
|
pub mod user_backfiller;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -4039,7 +4039,8 @@ async fn get_llm_api_token(
|
||||||
} else {
|
} else {
|
||||||
let customer_id = stripe_billing
|
let customer_id = stripe_billing
|
||||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||||
.await?;
|
.await?
|
||||||
|
.try_into()?;
|
||||||
|
|
||||||
find_or_create_billing_customer(
|
find_or_create_billing_customer(
|
||||||
&session.app_state,
|
&session.app_state,
|
||||||
|
|
|
@ -1,19 +1,22 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::Result;
|
|
||||||
use crate::db::billing_subscription::SubscriptionKind;
|
|
||||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
|
||||||
use anyhow::{Context as _, anyhow};
|
use anyhow::{Context as _, anyhow};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
|
use stripe::{PriceId, SubscriptionStatus};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
use crate::db::billing_subscription::SubscriptionKind;
|
||||||
|
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||||
|
use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
|
||||||
|
|
||||||
pub struct StripeBilling {
|
pub struct StripeBilling {
|
||||||
state: RwLock<StripeBillingState>,
|
state: RwLock<StripeBillingState>,
|
||||||
client: Arc<stripe::Client>,
|
real_client: Arc<stripe::Client>,
|
||||||
|
client: Arc<dyn StripeClient>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
@ -26,6 +29,17 @@ struct StripeBillingState {
|
||||||
impl StripeBilling {
|
impl StripeBilling {
|
||||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
client: Arc::new(RealStripeClient::new(client.clone())),
|
||||||
|
real_client: client,
|
||||||
|
state: RwLock::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn test(client: Arc<crate::stripe_client::FakeStripeClient>) -> Self {
|
||||||
|
Self {
|
||||||
|
// This is just temporary until we can remove all usages of the real Stripe client.
|
||||||
|
real_client: Arc::new(stripe::Client::new("sk_test")),
|
||||||
client,
|
client,
|
||||||
state: RwLock::default(),
|
state: RwLock::default(),
|
||||||
}
|
}
|
||||||
|
@ -37,9 +51,9 @@ impl StripeBilling {
|
||||||
let mut state = self.state.write().await;
|
let mut state = self.state.write().await;
|
||||||
|
|
||||||
let (meters, prices) = futures::try_join!(
|
let (meters, prices) = futures::try_join!(
|
||||||
StripeMeter::list(&self.client),
|
StripeMeter::list(&self.real_client),
|
||||||
stripe::Price::list(
|
stripe::Price::list(
|
||||||
&self.client,
|
&self.real_client,
|
||||||
&stripe::ListPrices {
|
&stripe::ListPrices {
|
||||||
limit: Some(100),
|
limit: Some(100),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
@ -129,18 +143,11 @@ impl StripeBilling {
|
||||||
pub async fn find_or_create_customer_by_email(
|
pub async fn find_or_create_customer_by_email(
|
||||||
&self,
|
&self,
|
||||||
email_address: Option<&str>,
|
email_address: Option<&str>,
|
||||||
) -> Result<CustomerId> {
|
) -> Result<StripeCustomerId> {
|
||||||
let existing_customer = if let Some(email) = email_address {
|
let existing_customer = if let Some(email) = email_address {
|
||||||
let customers = Customer::list(
|
let customers = self.client.list_customers_by_email(email).await?;
|
||||||
&self.client,
|
|
||||||
&stripe::ListCustomers {
|
|
||||||
email: Some(email),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
customers.data.first().cloned()
|
customers.first().cloned()
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
@ -148,14 +155,12 @@ impl StripeBilling {
|
||||||
let customer_id = if let Some(existing_customer) = existing_customer {
|
let customer_id = if let Some(existing_customer) = existing_customer {
|
||||||
existing_customer.id
|
existing_customer.id
|
||||||
} else {
|
} else {
|
||||||
let customer = Customer::create(
|
let customer = self
|
||||||
&self.client,
|
.client
|
||||||
CreateCustomer {
|
.create_customer(crate::stripe_client::CreateCustomerParams {
|
||||||
email: email_address,
|
email: email_address,
|
||||||
..Default::default()
|
})
|
||||||
},
|
.await?;
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
customer.id
|
customer.id
|
||||||
};
|
};
|
||||||
|
@ -169,7 +174,7 @@ impl StripeBilling {
|
||||||
price: &stripe::Price,
|
price: &stripe::Price,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let subscription =
|
let subscription =
|
||||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
|
||||||
|
|
||||||
if subscription_contains_price(&subscription, &price.id) {
|
if subscription_contains_price(&subscription, &price.id) {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
@ -181,7 +186,7 @@ impl StripeBilling {
|
||||||
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
|
let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
|
||||||
|
|
||||||
stripe::Subscription::update(
|
stripe::Subscription::update(
|
||||||
&self.client,
|
&self.real_client,
|
||||||
subscription_id,
|
subscription_id,
|
||||||
stripe::UpdateSubscription {
|
stripe::UpdateSubscription {
|
||||||
items: Some(vec![stripe::UpdateSubscriptionItems {
|
items: Some(vec![stripe::UpdateSubscriptionItems {
|
||||||
|
@ -211,7 +216,7 @@ impl StripeBilling {
|
||||||
let idempotency_key = Uuid::new_v4();
|
let idempotency_key = Uuid::new_v4();
|
||||||
|
|
||||||
StripeMeterEvent::create(
|
StripeMeterEvent::create(
|
||||||
&self.client,
|
&self.real_client,
|
||||||
StripeCreateMeterEventParams {
|
StripeCreateMeterEventParams {
|
||||||
identifier: &format!("model_requests/{}", idempotency_key),
|
identifier: &format!("model_requests/{}", idempotency_key),
|
||||||
event_name,
|
event_name,
|
||||||
|
@ -246,7 +251,7 @@ impl StripeBilling {
|
||||||
}]);
|
}]);
|
||||||
params.success_url = Some(success_url);
|
params.success_url = Some(success_url);
|
||||||
|
|
||||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
|
||||||
Ok(session.url.context("no checkout session URL")?)
|
Ok(session.url.context("no checkout session URL")?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -300,7 +305,7 @@ impl StripeBilling {
|
||||||
}]);
|
}]);
|
||||||
params.success_url = Some(success_url);
|
params.success_url = Some(success_url);
|
||||||
|
|
||||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
|
||||||
Ok(session.url.context("no checkout session URL")?)
|
Ok(session.url.context("no checkout session URL")?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -311,7 +316,7 @@ impl StripeBilling {
|
||||||
let zed_free_price_id = self.zed_free_price_id().await?;
|
let zed_free_price_id = self.zed_free_price_id().await?;
|
||||||
|
|
||||||
let existing_subscriptions = stripe::Subscription::list(
|
let existing_subscriptions = stripe::Subscription::list(
|
||||||
&self.client,
|
&self.real_client,
|
||||||
&stripe::ListSubscriptions {
|
&stripe::ListSubscriptions {
|
||||||
customer: Some(customer_id.clone()),
|
customer: Some(customer_id.clone()),
|
||||||
status: None,
|
status: None,
|
||||||
|
@ -339,7 +344,7 @@ impl StripeBilling {
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}]);
|
}]);
|
||||||
|
|
||||||
let subscription = stripe::Subscription::create(&self.client, params).await?;
|
let subscription = stripe::Subscription::create(&self.real_client, params).await?;
|
||||||
|
|
||||||
Ok(subscription)
|
Ok(subscription)
|
||||||
}
|
}
|
||||||
|
@ -365,7 +370,7 @@ impl StripeBilling {
|
||||||
}]);
|
}]);
|
||||||
params.success_url = Some(success_url);
|
params.success_url = Some(success_url);
|
||||||
|
|
||||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
let session = stripe::CheckoutSession::create(&self.real_client, params).await?;
|
||||||
Ok(session.url.context("no checkout session URL")?)
|
Ok(session.url.context("no checkout session URL")?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
33
crates/collab/src/stripe_client.rs
Normal file
33
crates/collab/src/stripe_client.rs
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
#[cfg(test)]
|
||||||
|
mod fake_stripe_client;
|
||||||
|
mod real_stripe_client;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub use fake_stripe_client::*;
|
||||||
|
pub use real_stripe_client::*;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
|
||||||
|
pub struct StripeCustomerId(pub Arc<str>);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StripeCustomer {
|
||||||
|
pub id: StripeCustomerId,
|
||||||
|
pub email: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct CreateCustomerParams<'a> {
|
||||||
|
pub email: Option<&'a str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait StripeClient: Send + Sync {
|
||||||
|
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
|
||||||
|
|
||||||
|
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
|
||||||
|
}
|
47
crates/collab/src/stripe_client/fake_stripe_client.rs
Normal file
47
crates/collab/src/stripe_client/fake_stripe_client.rs
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use collections::HashMap;
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
|
||||||
|
|
||||||
|
pub struct FakeStripeClient {
|
||||||
|
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FakeStripeClient {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
customers: Arc::new(Mutex::new(HashMap::default())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StripeClient for FakeStripeClient {
|
||||||
|
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
|
||||||
|
Ok(self
|
||||||
|
.customers
|
||||||
|
.lock()
|
||||||
|
.values()
|
||||||
|
.filter(|customer| customer.email.as_deref() == Some(email))
|
||||||
|
.cloned()
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
|
||||||
|
let customer = StripeCustomer {
|
||||||
|
id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()),
|
||||||
|
email: params.email.map(|email| email.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.customers
|
||||||
|
.lock()
|
||||||
|
.insert(customer.id.clone(), customer.clone());
|
||||||
|
|
||||||
|
Ok(customer)
|
||||||
|
}
|
||||||
|
}
|
74
crates/collab/src/stripe_client/real_stripe_client.rs
Normal file
74
crates/collab/src/stripe_client/real_stripe_client.rs
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
use std::str::FromStr as _;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use anyhow::{Context as _, Result};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers};
|
||||||
|
|
||||||
|
use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
|
||||||
|
|
||||||
|
pub struct RealStripeClient {
|
||||||
|
client: Arc<stripe::Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RealStripeClient {
|
||||||
|
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||||
|
Self { client }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StripeClient for RealStripeClient {
|
||||||
|
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>> {
|
||||||
|
let response = Customer::list(
|
||||||
|
&self.client,
|
||||||
|
&ListCustomers {
|
||||||
|
email: Some(email),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(response
|
||||||
|
.data
|
||||||
|
.into_iter()
|
||||||
|
.map(StripeCustomer::from)
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer> {
|
||||||
|
let customer = Customer::create(
|
||||||
|
&self.client,
|
||||||
|
CreateCustomer {
|
||||||
|
email: params.email,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(StripeCustomer::from(customer))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<CustomerId> for StripeCustomerId {
|
||||||
|
fn from(value: CustomerId) -> Self {
|
||||||
|
Self(value.as_str().into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<StripeCustomerId> for CustomerId {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: StripeCustomerId) -> Result<Self, Self::Error> {
|
||||||
|
Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Customer> for StripeCustomer {
|
||||||
|
fn from(value: Customer) -> Self {
|
||||||
|
StripeCustomer {
|
||||||
|
id: value.id.into(),
|
||||||
|
email: value.email,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,6 +18,7 @@ mod random_channel_buffer_tests;
|
||||||
mod random_project_collaboration_tests;
|
mod random_project_collaboration_tests;
|
||||||
mod randomized_test_helpers;
|
mod randomized_test_helpers;
|
||||||
mod remote_editing_collaboration_tests;
|
mod remote_editing_collaboration_tests;
|
||||||
|
mod stripe_billing_tests;
|
||||||
mod test_server;
|
mod test_server;
|
||||||
|
|
||||||
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
|
||||||
|
|
60
crates/collab/src/tests/stripe_billing_tests.rs
Normal file
60
crates/collab/src/tests/stripe_billing_tests.rs
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
|
use crate::stripe_billing::StripeBilling;
|
||||||
|
use crate::stripe_client::FakeStripeClient;
|
||||||
|
|
||||||
|
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||||
|
let stripe_client = Arc::new(FakeStripeClient::new());
|
||||||
|
let stripe_billing = StripeBilling::test(stripe_client.clone());
|
||||||
|
|
||||||
|
(stripe_billing, stripe_client)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_find_or_create_customer_by_email() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
// Create a customer with an email that doesn't yet correspond to a customer.
|
||||||
|
{
|
||||||
|
let email = "user@example.com";
|
||||||
|
|
||||||
|
let customer_id = stripe_billing
|
||||||
|
.find_or_create_customer_by_email(Some(email))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let customer = stripe_client
|
||||||
|
.customers
|
||||||
|
.lock()
|
||||||
|
.get(&customer_id)
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
assert_eq!(customer.email.as_deref(), Some(email));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a customer with an email that corresponds to an existing customer.
|
||||||
|
{
|
||||||
|
let email = "user2@example.com";
|
||||||
|
|
||||||
|
let existing_customer_id = stripe_billing
|
||||||
|
.find_or_create_customer_by_email(Some(email))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let customer_id = stripe_billing
|
||||||
|
.find_or_create_customer_by_email(Some(email))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(customer_id, existing_customer_id);
|
||||||
|
|
||||||
|
let customer = stripe_client
|
||||||
|
.customers
|
||||||
|
.lock()
|
||||||
|
.get(&customer_id)
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
assert_eq!(customer.email.as_deref(), Some(email));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue