diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index ee7d93979d..c121d8e8fd 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -8,11 +8,16 @@ use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; pub struct StripeBilling { - meters_by_event_name: RwLock>, - price_ids_by_meter_id: RwLock>, + state: RwLock, client: Arc, } +#[derive(Default)] +struct StripeBillingState { + meters_by_event_name: HashMap, + price_ids_by_meter_id: HashMap, +} + pub struct StripeModel { input_tokens_price: StripeBillingPrice, input_cache_creation_tokens_price: StripeBillingPrice, @@ -29,36 +34,36 @@ impl StripeBilling { pub fn new(client: Arc) -> Self { Self { client, - meters_by_event_name: RwLock::new(HashMap::default()), - price_ids_by_meter_id: RwLock::new(HashMap::default()), + state: RwLock::default(), } } pub async fn initialize(&self) -> Result<()> { - log::info!("initializing StripeBilling"); + log::info!("StripeBilling: initializing"); - { - let meters = StripeMeter::list(&self.client).await?.data; - let mut meters_by_event_name = self.meters_by_event_name.write().await; - for meter in meters { - meters_by_event_name.insert(meter.event_name.clone(), meter); - } + let mut state = self.state.write().await; + + let (meters, prices) = futures::try_join!( + StripeMeter::list(&self.client), + stripe::Price::list(&self.client, &stripe::ListPrices::default()) + )?; + + for meter in meters.data { + state + .meters_by_event_name + .insert(meter.event_name.clone(), meter); } - { - let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default()) - .await? - .data; - let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await; - for price in prices { - if let Some(recurring) = price.recurring { - if let Some(meter) = recurring.meter { - price_ids_by_meter_id.insert(meter, price.id); - } + for price in prices.data { + if let Some(recurring) = price.recurring { + if let Some(meter) = recurring.meter { + state.price_ids_by_meter_id.insert(meter, price.id); } } } + log::info!("StripeBilling: initialized"); + Ok(()) } @@ -105,79 +110,89 @@ impl StripeBilling { price_description: &str, price_per_million_tokens: Cents, ) -> Result { - let meter = - if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) { - meter.clone() - } else { - let meter = StripeMeter::create( - &self.client, - StripeCreateMeterParams { - default_aggregation: DefaultAggregation { formula: "sum" }, - display_name: price_description.to_string(), - event_name: meter_event_name, - }, - ) - .await?; - self.meters_by_event_name - .write() - .await - .insert(meter_event_name.to_string(), meter.clone()); - meter - }; + // Fast code path when the meter and the price already exist. + { + let state = self.state.read().await; + if let Some(meter) = state.meters_by_event_name.get(meter_event_name) { + if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) { + return Ok(StripeBillingPrice { + id: price_id.clone(), + meter_event_name: meter_event_name.to_string(), + }); + } + } + } - let price_id = - if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) { - price_id.clone() - } else { - let price = stripe::Price::create( - &self.client, - stripe::CreatePrice { + let mut state = self.state.write().await; + let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) { + meter.clone() + } else { + let meter = StripeMeter::create( + &self.client, + StripeCreateMeterParams { + default_aggregation: DefaultAggregation { formula: "sum" }, + display_name: price_description.to_string(), + event_name: meter_event_name, + }, + ) + .await?; + state + .meters_by_event_name + .insert(meter_event_name.to_string(), meter.clone()); + meter + }; + + let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) { + price_id.clone() + } else { + let price = stripe::Price::create( + &self.client, + stripe::CreatePrice { + active: Some(true), + billing_scheme: Some(stripe::PriceBillingScheme::PerUnit), + currency: stripe::Currency::USD, + currency_options: None, + custom_unit_amount: None, + expand: &[], + lookup_key: None, + metadata: None, + nickname: None, + product: None, + product_data: Some(stripe::CreatePriceProductData { + id: None, active: Some(true), - billing_scheme: Some(stripe::PriceBillingScheme::PerUnit), - currency: stripe::Currency::USD, - currency_options: None, - custom_unit_amount: None, - expand: &[], - lookup_key: None, metadata: None, - nickname: None, - product: None, - product_data: Some(stripe::CreatePriceProductData { - id: None, - active: Some(true), - metadata: None, - name: price_description.to_string(), - statement_descriptor: None, - tax_code: None, - unit_label: None, - }), - recurring: Some(stripe::CreatePriceRecurring { - aggregate_usage: None, - interval: stripe::CreatePriceRecurringInterval::Month, - interval_count: None, - trial_period_days: None, - usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), - meter: Some(meter.id.clone()), - }), - tax_behavior: None, - tiers: None, - tiers_mode: None, - transfer_lookup_key: None, - transform_quantity: None, - unit_amount: None, - unit_amount_decimal: Some(&format!( - "{:.12}", - price_per_million_tokens.0 as f64 / 1_000_000f64 - )), - }, - ) - .await?; - self.price_ids_by_meter_id - .write() - .await - .insert(meter.id, price.id.clone()); - price.id - }; + name: price_description.to_string(), + statement_descriptor: None, + tax_code: None, + unit_label: None, + }), + recurring: Some(stripe::CreatePriceRecurring { + aggregate_usage: None, + interval: stripe::CreatePriceRecurringInterval::Month, + interval_count: None, + trial_period_days: None, + usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), + meter: Some(meter.id.clone()), + }), + tax_behavior: None, + tiers: None, + tiers_mode: None, + transfer_lookup_key: None, + transform_quantity: None, + unit_amount: None, + unit_amount_decimal: Some(&format!( + "{:.12}", + price_per_million_tokens.0 as f64 / 1_000_000f64 + )), + }, + ) + .await?; + state + .price_ids_by_meter_id + .insert(meter.id, price.id.clone()); + price.id + }; Ok(StripeBillingPrice { id: price_id,