diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index b6a559a538..0edfa2f1bf 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -499,8 +499,10 @@ async fn manage_billing_subscription( let flow = match body.intent { ManageSubscriptionIntent::ManageSubscription => None, ManageSubscriptionIntent::UpgradeToPro => { - let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?; - let zed_free_price_id = stripe_billing.zed_free_price_id().await?; + let zed_pro_price_id: stripe::PriceId = + stripe_billing.zed_pro_price_id().await?.try_into()?; + let zed_free_price_id: stripe::PriceId = + stripe_billing.zed_free_price_id().await?.try_into()?; let stripe_subscription = Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?; diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 83eb9ef903..e8a716d206 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -4,14 +4,16 @@ use anyhow::{Context as _, anyhow}; use chrono::Utc; use collections::HashMap; use serde::{Deserialize, Serialize}; -use stripe::{PriceId, SubscriptionStatus}; +use stripe::SubscriptionStatus; use tokio::sync::RwLock; use uuid::Uuid; use crate::Result; use crate::db::billing_subscription::SubscriptionKind; use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; -use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId}; +use crate::stripe_client::{ + RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, +}; pub struct StripeBilling { state: RwLock, @@ -22,8 +24,8 @@ pub struct StripeBilling { #[derive(Default)] struct StripeBillingState { meters_by_event_name: HashMap, - price_ids_by_meter_id: HashMap, - prices_by_lookup_key: HashMap, + price_ids_by_meter_id: HashMap, + prices_by_lookup_key: HashMap, } impl StripeBilling { @@ -50,24 +52,16 @@ impl StripeBilling { let mut state = self.state.write().await; - let (meters, prices) = futures::try_join!( - StripeMeter::list(&self.real_client), - stripe::Price::list( - &self.real_client, - &stripe::ListPrices { - limit: Some(100), - ..Default::default() - } - ) - )?; + let (meters, prices) = + futures::try_join!(self.client.list_meters(), self.client.list_prices())?; - for meter in meters.data { + for meter in meters { state .meters_by_event_name .insert(meter.event_name.clone(), meter); } - for price in prices.data { + for price in prices { if let Some(lookup_key) = price.lookup_key.clone() { state.prices_by_lookup_key.insert(lookup_key, price.clone()); } @@ -84,15 +78,15 @@ impl StripeBilling { Ok(()) } - pub async fn zed_pro_price_id(&self) -> Result { + pub async fn zed_pro_price_id(&self) -> Result { self.find_price_id_by_lookup_key("zed-pro").await } - pub async fn zed_free_price_id(&self) -> Result { + pub async fn zed_free_price_id(&self) -> Result { self.find_price_id_by_lookup_key("zed-free").await } - pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result { + pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result { self.state .read() .await @@ -102,7 +96,7 @@ impl StripeBilling { .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}"))) } - pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result { + pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result { self.state .read() .await @@ -116,8 +110,10 @@ impl StripeBilling { &self, subscription: &stripe::Subscription, ) -> Option { - let zed_pro_price_id = self.zed_pro_price_id().await.ok()?; - let zed_free_price_id = self.zed_free_price_id().await.ok()?; + let zed_pro_price_id: stripe::PriceId = + self.zed_pro_price_id().await.ok()?.try_into().ok()?; + let zed_free_price_id: stripe::PriceId = + self.zed_free_price_id().await.ok()?.try_into().ok()?; subscription.items.data.iter().find_map(|item| { let price = item.price.as_ref()?; @@ -171,12 +167,13 @@ impl StripeBilling { pub async fn subscribe_to_price( &self, subscription_id: &stripe::SubscriptionId, - price: &stripe::Price, + price: &StripePrice, ) -> Result<()> { let subscription = stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?; - if subscription_contains_price(&subscription, &price.id) { + let price_id = price.id.clone().try_into()?; + if subscription_contains_price(&subscription, &price_id) { return Ok(()); } @@ -375,24 +372,6 @@ impl StripeBilling { } } -#[derive(Clone, Deserialize)] -struct StripeMeter { - id: String, - event_name: String, -} - -impl StripeMeter { - pub fn list(client: &stripe::Client) -> stripe::Response> { - #[derive(Serialize)] - struct Params { - #[serde(skip_serializing_if = "Option::is_none")] - limit: Option, - } - - client.get_query("/billing/meters", Params { limit: Some(100) }) - } -} - #[derive(Deserialize)] struct StripeMeterEvent { identifier: String, diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs index 23d8fb34c1..5fcb139a7e 100644 --- a/crates/collab/src/stripe_client.rs +++ b/crates/collab/src/stripe_client.rs @@ -10,8 +10,9 @@ use async_trait::async_trait; #[cfg(test)] pub use fake_stripe_client::*; pub use real_stripe_client::*; +use serde::Deserialize; -#[derive(Debug, PartialEq, Eq, Hash, Clone)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] pub struct StripeCustomerId(pub Arc); #[derive(Debug, Clone)] @@ -25,9 +26,38 @@ pub struct CreateCustomerParams<'a> { pub email: Option<&'a str>, } +#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] +pub struct StripePriceId(pub Arc); + +#[derive(Debug, Clone)] +pub struct StripePrice { + pub id: StripePriceId, + pub unit_amount: Option, + pub lookup_key: Option, + pub recurring: Option, +} + +#[derive(Debug, Clone)] +pub struct StripePriceRecurring { + pub meter: Option, +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)] +pub struct StripeMeterId(pub Arc); + +#[derive(Debug, Clone, Deserialize)] +pub struct StripeMeter { + pub id: StripeMeterId, + pub event_name: String, +} + #[async_trait] pub trait StripeClient: Send + Sync { async fn list_customers_by_email(&self, email: &str) -> Result>; async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result; + + async fn list_prices(&self) -> Result>; + + async fn list_meters(&self) -> Result>; } diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs index 5f526082fc..9c1d407215 100644 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -6,16 +6,23 @@ use collections::HashMap; use parking_lot::Mutex; use uuid::Uuid; -use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId}; +use crate::stripe_client::{ + CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, + StripeMeterId, StripePrice, StripePriceId, +}; pub struct FakeStripeClient { pub customers: Arc>>, + pub prices: Arc>>, + pub meters: Arc>>, } impl FakeStripeClient { pub fn new() -> Self { Self { customers: Arc::new(Mutex::new(HashMap::default())), + prices: Arc::new(Mutex::new(HashMap::default())), + meters: Arc::new(Mutex::new(HashMap::default())), } } } @@ -44,4 +51,16 @@ impl StripeClient for FakeStripeClient { Ok(customer) } + + async fn list_prices(&self) -> Result> { + let prices = self.prices.lock().values().cloned().collect(); + + Ok(prices) + } + + async fn list_meters(&self) -> Result> { + let meters = self.meters.lock().values().cloned().collect(); + + Ok(meters) + } } diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs index a9480eda59..9ea07a2979 100644 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -3,9 +3,13 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use async_trait::async_trait; -use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers}; +use serde::Serialize; +use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring}; -use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId}; +use crate::stripe_client::{ + CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice, + StripePriceId, StripePriceRecurring, +}; pub struct RealStripeClient { client: Arc, @@ -48,6 +52,37 @@ impl StripeClient for RealStripeClient { Ok(StripeCustomer::from(customer)) } + + async fn list_prices(&self) -> Result> { + let response = stripe::Price::list( + &self.client, + &stripe::ListPrices { + limit: Some(100), + ..Default::default() + }, + ) + .await?; + + Ok(response.data.into_iter().map(StripePrice::from).collect()) + } + + async fn list_meters(&self) -> Result> { + #[derive(Serialize)] + struct Params { + #[serde(skip_serializing_if = "Option::is_none")] + limit: Option, + } + + let response = self + .client + .get_query::, _>( + "/billing/meters", + Params { limit: Some(100) }, + ) + .await?; + + Ok(response.data) + } } impl From for StripeCustomerId { @@ -72,3 +107,34 @@ impl From for StripeCustomer { } } } + +impl From for StripePriceId { + fn from(value: PriceId) -> Self { + Self(value.as_str().into()) + } +} + +impl TryFrom for PriceId { + type Error = anyhow::Error; + + fn try_from(value: StripePriceId) -> Result { + Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID") + } +} + +impl From for StripePrice { + fn from(value: Price) -> Self { + Self { + id: value.id.into(), + unit_amount: value.unit_amount, + lookup_key: value.lookup_key, + recurring: value.recurring.map(StripePriceRecurring::from), + } + } +} + +impl From for StripePriceRecurring { + fn from(value: Recurring) -> Self { + Self { meter: value.meter } + } +} diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index db8161b8b5..cce84186ae 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -3,7 +3,9 @@ use std::sync::Arc; use pretty_assertions::assert_eq; use crate::stripe_billing::StripeBilling; -use crate::stripe_client::FakeStripeClient; +use crate::stripe_client::{ + FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring, +}; fn make_stripe_billing() -> (StripeBilling, Arc) { let stripe_client = Arc::new(FakeStripeClient::new()); @@ -12,6 +14,87 @@ fn make_stripe_billing() -> (StripeBilling, Arc) { (stripe_billing, stripe_client) } +#[gpui::test] +async fn test_initialize() { + let (stripe_billing, stripe_client) = make_stripe_billing(); + + // Add test meters + let meter1 = StripeMeter { + id: StripeMeterId("meter_1".into()), + event_name: "event_1".to_string(), + }; + let meter2 = StripeMeter { + id: StripeMeterId("meter_2".into()), + event_name: "event_2".to_string(), + }; + stripe_client + .meters + .lock() + .insert(meter1.id.clone(), meter1); + stripe_client + .meters + .lock() + .insert(meter2.id.clone(), meter2); + + // Add test prices + let price1 = StripePrice { + id: StripePriceId("price_1".into()), + unit_amount: Some(1_000), + lookup_key: Some("zed-pro".to_string()), + recurring: None, + }; + let price2 = StripePrice { + id: StripePriceId("price_2".into()), + unit_amount: Some(0), + lookup_key: Some("zed-free".to_string()), + recurring: None, + }; + let price3 = StripePrice { + id: StripePriceId("price_3".into()), + unit_amount: Some(500), + lookup_key: None, + recurring: Some(StripePriceRecurring { + meter: Some("meter_1".to_string()), + }), + }; + stripe_client + .prices + .lock() + .insert(price1.id.clone(), price1); + stripe_client + .prices + .lock() + .insert(price2.id.clone(), price2); + stripe_client + .prices + .lock() + .insert(price3.id.clone(), price3); + + // Initialize the billing system + stripe_billing.initialize().await.unwrap(); + + // Verify that prices can be found by lookup key + let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap(); + assert_eq!(zed_pro_price_id.to_string(), "price_1"); + + let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap(); + assert_eq!(zed_free_price_id.to_string(), "price_2"); + + // Verify that a price can be found by lookup key + let zed_pro_price = stripe_billing + .find_price_by_lookup_key("zed-pro") + .await + .unwrap(); + assert_eq!(zed_pro_price.id.to_string(), "price_1"); + assert_eq!(zed_pro_price.unit_amount, Some(1_000)); + + // Verify that finding a non-existent lookup key returns an error + let result = stripe_billing + .find_price_by_lookup_key("non-existent") + .await; + assert!(result.is_err()); +} + #[gpui::test] async fn test_find_or_create_customer_by_email() { let (stripe_billing, stripe_client) = make_stripe_billing();