diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 6ebeb0ced3..020aedbc57 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -20,6 +20,7 @@ test-support = ["sqlite"] [dependencies] anyhow.workspace = true async-stripe.workspace = true +async-trait.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } aws-sdk-s3 = { version = "1.15.0" } diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 83dcfde4f3..b6a559a538 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -344,6 +344,7 @@ async fn create_billing_subscription( stripe_billing .find_or_create_customer_by_email(user.email_address.as_deref()) .await? + .try_into()? }; let success_url = format!( diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 5f09a1e0aa..5819ad665c 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -9,6 +9,7 @@ pub mod migrations; pub mod rpc; pub mod seed; pub mod stripe_billing; +pub mod stripe_client; pub mod user_backfiller; #[cfg(test)] diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c35cf2e98b..5316304cb0 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -4039,7 +4039,8 @@ async fn get_llm_api_token( } else { let customer_id = stripe_billing .find_or_create_customer_by_email(user.email_address.as_deref()) - .await?; + .await? + .try_into()?; find_or_create_billing_customer( &session.app_state, diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 13a1c75877..83eb9ef903 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,19 +1,22 @@ use std::sync::Arc; -use crate::Result; -use crate::db::billing_subscription::SubscriptionKind; -use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use anyhow::{Context as _, anyhow}; use chrono::Utc; use collections::HashMap; use serde::{Deserialize, Serialize}; -use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus}; +use stripe::{PriceId, SubscriptionStatus}; use tokio::sync::RwLock; use uuid::Uuid; +use crate::Result; +use crate::db::billing_subscription::SubscriptionKind; +use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; +use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId}; + pub struct StripeBilling { state: RwLock, - client: Arc, + real_client: Arc, + client: Arc, } #[derive(Default)] @@ -26,6 +29,17 @@ struct StripeBillingState { impl StripeBilling { pub fn new(client: Arc) -> Self { Self { + client: Arc::new(RealStripeClient::new(client.clone())), + real_client: client, + state: RwLock::default(), + } + } + + #[cfg(test)] + pub fn test(client: Arc) -> Self { + Self { + // This is just temporary until we can remove all usages of the real Stripe client. + real_client: Arc::new(stripe::Client::new("sk_test")), client, state: RwLock::default(), } @@ -37,9 +51,9 @@ impl StripeBilling { let mut state = self.state.write().await; let (meters, prices) = futures::try_join!( - StripeMeter::list(&self.client), + StripeMeter::list(&self.real_client), stripe::Price::list( - &self.client, + &self.real_client, &stripe::ListPrices { limit: Some(100), ..Default::default() @@ -129,18 +143,11 @@ impl StripeBilling { pub async fn find_or_create_customer_by_email( &self, email_address: Option<&str>, - ) -> Result { + ) -> Result { let existing_customer = if let Some(email) = email_address { - let customers = Customer::list( - &self.client, - &stripe::ListCustomers { - email: Some(email), - ..Default::default() - }, - ) - .await?; + let customers = self.client.list_customers_by_email(email).await?; - customers.data.first().cloned() + customers.first().cloned() } else { None }; @@ -148,14 +155,12 @@ impl StripeBilling { let customer_id = if let Some(existing_customer) = existing_customer { existing_customer.id } else { - let customer = Customer::create( - &self.client, - CreateCustomer { + let customer = self + .client + .create_customer(crate::stripe_client::CreateCustomerParams { email: email_address, - ..Default::default() - }, - ) - .await?; + }) + .await?; customer.id }; @@ -169,7 +174,7 @@ impl StripeBilling { price: &stripe::Price, ) -> Result<()> { let subscription = - stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?; + stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?; if subscription_contains_price(&subscription, &price.id) { return Ok(()); @@ -181,7 +186,7 @@ impl StripeBilling { let _units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit; stripe::Subscription::update( - &self.client, + &self.real_client, subscription_id, stripe::UpdateSubscription { items: Some(vec![stripe::UpdateSubscriptionItems { @@ -211,7 +216,7 @@ impl StripeBilling { let idempotency_key = Uuid::new_v4(); StripeMeterEvent::create( - &self.client, + &self.real_client, StripeCreateMeterEventParams { identifier: &format!("model_requests/{}", idempotency_key), event_name, @@ -246,7 +251,7 @@ impl StripeBilling { }]); params.success_url = Some(success_url); - let session = stripe::CheckoutSession::create(&self.client, params).await?; + let session = stripe::CheckoutSession::create(&self.real_client, params).await?; Ok(session.url.context("no checkout session URL")?) } @@ -300,7 +305,7 @@ impl StripeBilling { }]); params.success_url = Some(success_url); - let session = stripe::CheckoutSession::create(&self.client, params).await?; + let session = stripe::CheckoutSession::create(&self.real_client, params).await?; Ok(session.url.context("no checkout session URL")?) } @@ -311,7 +316,7 @@ impl StripeBilling { let zed_free_price_id = self.zed_free_price_id().await?; let existing_subscriptions = stripe::Subscription::list( - &self.client, + &self.real_client, &stripe::ListSubscriptions { customer: Some(customer_id.clone()), status: None, @@ -339,7 +344,7 @@ impl StripeBilling { ..Default::default() }]); - let subscription = stripe::Subscription::create(&self.client, params).await?; + let subscription = stripe::Subscription::create(&self.real_client, params).await?; Ok(subscription) } @@ -365,7 +370,7 @@ impl StripeBilling { }]); params.success_url = Some(success_url); - let session = stripe::CheckoutSession::create(&self.client, params).await?; + let session = stripe::CheckoutSession::create(&self.real_client, params).await?; Ok(session.url.context("no checkout session URL")?) } } diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs new file mode 100644 index 0000000000..23d8fb34c1 --- /dev/null +++ b/crates/collab/src/stripe_client.rs @@ -0,0 +1,33 @@ +#[cfg(test)] +mod fake_stripe_client; +mod real_stripe_client; + +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; + +#[cfg(test)] +pub use fake_stripe_client::*; +pub use real_stripe_client::*; + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct StripeCustomerId(pub Arc); + +#[derive(Debug, Clone)] +pub struct StripeCustomer { + pub id: StripeCustomerId, + pub email: Option, +} + +#[derive(Debug)] +pub struct CreateCustomerParams<'a> { + pub email: Option<&'a str>, +} + +#[async_trait] +pub trait StripeClient: Send + Sync { + async fn list_customers_by_email(&self, email: &str) -> Result>; + + async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result; +} diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs new file mode 100644 index 0000000000..5f526082fc --- /dev/null +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; +use collections::HashMap; +use parking_lot::Mutex; +use uuid::Uuid; + +use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId}; + +pub struct FakeStripeClient { + pub customers: Arc>>, +} + +impl FakeStripeClient { + pub fn new() -> Self { + Self { + customers: Arc::new(Mutex::new(HashMap::default())), + } + } +} + +#[async_trait] +impl StripeClient for FakeStripeClient { + async fn list_customers_by_email(&self, email: &str) -> Result> { + Ok(self + .customers + .lock() + .values() + .filter(|customer| customer.email.as_deref() == Some(email)) + .cloned() + .collect()) + } + + async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { + let customer = StripeCustomer { + id: StripeCustomerId(format!("cus_{}", Uuid::new_v4()).into()), + email: params.email.map(|email| email.to_string()), + }; + + self.customers + .lock() + .insert(customer.id.clone(), customer.clone()); + + Ok(customer) + } +} diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs new file mode 100644 index 0000000000..a9480eda59 --- /dev/null +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -0,0 +1,74 @@ +use std::str::FromStr as _; +use std::sync::Arc; + +use anyhow::{Context as _, Result}; +use async_trait::async_trait; +use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers}; + +use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId}; + +pub struct RealStripeClient { + client: Arc, +} + +impl RealStripeClient { + pub fn new(client: Arc) -> Self { + Self { client } + } +} + +#[async_trait] +impl StripeClient for RealStripeClient { + async fn list_customers_by_email(&self, email: &str) -> Result> { + let response = Customer::list( + &self.client, + &ListCustomers { + email: Some(email), + ..Default::default() + }, + ) + .await?; + + Ok(response + .data + .into_iter() + .map(StripeCustomer::from) + .collect()) + } + + async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result { + let customer = Customer::create( + &self.client, + CreateCustomer { + email: params.email, + ..Default::default() + }, + ) + .await?; + + Ok(StripeCustomer::from(customer)) + } +} + +impl From for StripeCustomerId { + fn from(value: CustomerId) -> Self { + Self(value.as_str().into()) + } +} + +impl TryFrom for CustomerId { + type Error = anyhow::Error; + + fn try_from(value: StripeCustomerId) -> Result { + Self::from_str(value.0.as_ref()).context("failed to parse Stripe customer ID") + } +} + +impl From for StripeCustomer { + fn from(value: Customer) -> Self { + StripeCustomer { + id: value.id.into(), + email: value.email, + } + } +} diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 6ddb349700..19e410de5b 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -18,6 +18,7 @@ mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; mod remote_editing_collaboration_tests; +mod stripe_billing_tests; mod test_server; use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs new file mode 100644 index 0000000000..db8161b8b5 --- /dev/null +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use pretty_assertions::assert_eq; + +use crate::stripe_billing::StripeBilling; +use crate::stripe_client::FakeStripeClient; + +fn make_stripe_billing() -> (StripeBilling, Arc) { + let stripe_client = Arc::new(FakeStripeClient::new()); + let stripe_billing = StripeBilling::test(stripe_client.clone()); + + (stripe_billing, stripe_client) +} + +#[gpui::test] +async fn test_find_or_create_customer_by_email() { + let (stripe_billing, stripe_client) = make_stripe_billing(); + + // Create a customer with an email that doesn't yet correspond to a customer. + { + let email = "user@example.com"; + + let customer_id = stripe_billing + .find_or_create_customer_by_email(Some(email)) + .await + .unwrap(); + + let customer = stripe_client + .customers + .lock() + .get(&customer_id) + .unwrap() + .clone(); + assert_eq!(customer.email.as_deref(), Some(email)); + } + + // Create a customer with an email that corresponds to an existing customer. + { + let email = "user2@example.com"; + + let existing_customer_id = stripe_billing + .find_or_create_customer_by_email(Some(email)) + .await + .unwrap(); + + let customer_id = stripe_billing + .find_or_create_customer_by_email(Some(email)) + .await + .unwrap(); + assert_eq!(customer_id, existing_customer_id); + + let customer = stripe_client + .customers + .lock() + .get(&customer_id) + .unwrap() + .clone(); + assert_eq!(customer.email.as_deref(), Some(email)); + } +}