From f280b298594a3c20cfa0a3d473d6d499cf59a60c Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 11 Oct 2024 15:15:08 -0400 Subject: [PATCH] collab: Make the `StripeBilling` object long-lived (#19090) This PR makes the `StripeBilling` object long-lived so that we can make better use of the cached data on it. We now hold it on the `AppState` and spawn a background task to initialize the cache on startup. Release Notes: - N/A Co-authored-by: Richard --- crates/collab/src/api/billing.rs | 18 ++- crates/collab/src/lib.rs | 9 +- crates/collab/src/main.rs | 7 + crates/collab/src/stripe_billing.rs | 186 ++++++++++++++----------- crates/collab/src/tests/test_server.rs | 1 + 5 files changed, 129 insertions(+), 92 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index c0fc33a643..0e46e8453d 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -210,6 +210,13 @@ async fn create_billing_subscription( "not supported".into(), ))? }; + let Some(stripe_billing) = app.stripe_billing.clone() else { + log::error!("failed to retrieve Stripe billing object"); + Err(Error::http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; let Some(llm_db) = app.llm_db.clone() else { log::error!("failed to retrieve LLM database"); Err(Error::http( @@ -236,7 +243,6 @@ async fn create_billing_subscription( }; let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?; - let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?; let stripe_model = stripe_billing.register_model(default_model).await?; let success_url = format!("{}/account", app.config.zed_dot_dev_url()); let checkout_session_url = stripe_billing @@ -716,8 +722,8 @@ async fn find_or_create_billing_customer( const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); pub fn sync_llm_usage_with_stripe_periodically(app: Arc) { - let Some(stripe_client) = app.stripe_client.clone() else { - log::warn!("failed to retrieve Stripe client"); + let Some(stripe_billing) = app.stripe_billing.clone() else { + log::warn!("failed to retrieve Stripe billing object"); return; }; let Some(llm_db) = app.llm_db.clone() else { @@ -730,7 +736,7 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc) { let executor = executor.clone(); async move { loop { - sync_with_stripe(&app, &llm_db, &stripe_client) + sync_with_stripe(&app, &llm_db, &stripe_billing) .await .trace_err(); executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await; @@ -742,10 +748,8 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc) { async fn sync_with_stripe( app: &Arc, llm_db: &Arc, - stripe_client: &Arc, + stripe_billing: &Arc, ) -> anyhow::Result<()> { - let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?; - let events = llm_db.get_billing_events().await?; let user_ids = events .iter() diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 0cc50f68f3..c82968efec 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -31,6 +31,8 @@ use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; use util::ResultExt; +use crate::stripe_billing::StripeBilling; + pub type Result = std::result::Result; pub enum Error { @@ -274,6 +276,7 @@ pub struct AppState { pub live_kit_client: Option>, pub blob_store_client: Option, pub stripe_client: Option>, + pub stripe_billing: Option>, pub rate_limiter: Arc, pub executor: Executor, pub clickhouse_client: Option<::clickhouse::Client>, @@ -317,12 +320,16 @@ impl AppState { }; let db = Arc::new(db); + let stripe_client = build_stripe_client(&config).map(Arc::new).log_err(); let this = Self { db: db.clone(), llm_db, live_kit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), - stripe_client: build_stripe_client(&config).map(Arc::new).log_err(), + stripe_billing: stripe_client + .clone() + .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))), + stripe_client, rate_limiter: Arc::new(RateLimiter::new(db)), executor, clickhouse_client: config diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 02c0baf9de..de77be21eb 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -111,6 +111,13 @@ async fn main() -> Result<()> { let state = AppState::new(config, Executor::Production).await?; + if let Some(stripe_billing) = state.stripe_billing.clone() { + let executor = state.executor.clone(); + executor.spawn_detached(async move { + stripe_billing.initialize().await.trace_err(); + }); + } + if mode.is_collab() { state.db.purge_old_embeddings().await.trace_err(); RateLimiter::save_periodically( diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 878d160666..ee7d93979d 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -5,10 +5,11 @@ use anyhow::Context; use chrono::Utc; use collections::HashMap; use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; pub struct StripeBilling { - meters_by_event_name: HashMap, - price_ids_by_meter_id: HashMap, + meters_by_event_name: RwLock>, + price_ids_by_meter_id: RwLock>, client: Arc, } @@ -25,32 +26,43 @@ struct StripeBillingPrice { } impl StripeBilling { - pub async fn new(client: Arc) -> Result { - let mut meters_by_event_name = HashMap::default(); - for meter in StripeMeter::list(&client).await?.data { - meters_by_event_name.insert(meter.event_name.clone(), meter); + pub fn new(client: Arc) -> Self { + Self { + client, + meters_by_event_name: RwLock::new(HashMap::default()), + price_ids_by_meter_id: RwLock::new(HashMap::default()), + } + } + + pub async fn initialize(&self) -> Result<()> { + log::info!("initializing StripeBilling"); + + { + let meters = StripeMeter::list(&self.client).await?.data; + let mut meters_by_event_name = self.meters_by_event_name.write().await; + for meter in meters { + meters_by_event_name.insert(meter.event_name.clone(), meter); + } } - let mut price_ids_by_meter_id = HashMap::default(); - for price in stripe::Price::list(&client, &stripe::ListPrices::default()) - .await? - .data { - if let Some(recurring) = price.recurring { - if let Some(meter) = recurring.meter { - price_ids_by_meter_id.insert(meter, price.id); + let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default()) + .await? + .data; + let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await; + for price in prices { + if let Some(recurring) = price.recurring { + if let Some(meter) = recurring.meter { + price_ids_by_meter_id.insert(meter, price.id); + } } } } - Ok(Self { - meters_by_event_name, - price_ids_by_meter_id, - client, - }) + Ok(()) } - pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result { + pub async fn register_model(&self, model: &llm::db::model::Model) -> Result { let input_tokens_price = self .get_or_insert_price( &format!("model_{}/input_tokens", model.id), @@ -88,78 +100,84 @@ impl StripeBilling { } async fn get_or_insert_price( - &mut self, + &self, meter_event_name: &str, price_description: &str, price_per_million_tokens: Cents, ) -> Result { - let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) { - meter.clone() - } else { - let meter = StripeMeter::create( - &self.client, - StripeCreateMeterParams { - default_aggregation: DefaultAggregation { formula: "sum" }, - display_name: price_description.to_string(), - event_name: meter_event_name, - }, - ) - .await?; - self.meters_by_event_name - .insert(meter_event_name.to_string(), meter.clone()); - meter - }; + let meter = + if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) { + meter.clone() + } else { + let meter = StripeMeter::create( + &self.client, + StripeCreateMeterParams { + default_aggregation: DefaultAggregation { formula: "sum" }, + display_name: price_description.to_string(), + event_name: meter_event_name, + }, + ) + .await?; + self.meters_by_event_name + .write() + .await + .insert(meter_event_name.to_string(), meter.clone()); + meter + }; - let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) { - price_id.clone() - } else { - let price = stripe::Price::create( - &self.client, - stripe::CreatePrice { - active: Some(true), - billing_scheme: Some(stripe::PriceBillingScheme::PerUnit), - currency: stripe::Currency::USD, - currency_options: None, - custom_unit_amount: None, - expand: &[], - lookup_key: None, - metadata: None, - nickname: None, - product: None, - product_data: Some(stripe::CreatePriceProductData { - id: None, + let price_id = + if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) { + price_id.clone() + } else { + let price = stripe::Price::create( + &self.client, + stripe::CreatePrice { active: Some(true), + billing_scheme: Some(stripe::PriceBillingScheme::PerUnit), + currency: stripe::Currency::USD, + currency_options: None, + custom_unit_amount: None, + expand: &[], + lookup_key: None, metadata: None, - name: price_description.to_string(), - statement_descriptor: None, - tax_code: None, - unit_label: None, - }), - recurring: Some(stripe::CreatePriceRecurring { - aggregate_usage: None, - interval: stripe::CreatePriceRecurringInterval::Month, - interval_count: None, - trial_period_days: None, - usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), - meter: Some(meter.id.clone()), - }), - tax_behavior: None, - tiers: None, - tiers_mode: None, - transfer_lookup_key: None, - transform_quantity: None, - unit_amount: None, - unit_amount_decimal: Some(&format!( - "{:.12}", - price_per_million_tokens.0 as f64 / 1_000_000f64 - )), - }, - ) - .await?; - self.price_ids_by_meter_id - .insert(meter.id, price.id.clone()); - price.id - }; + nickname: None, + product: None, + product_data: Some(stripe::CreatePriceProductData { + id: None, + active: Some(true), + metadata: None, + name: price_description.to_string(), + statement_descriptor: None, + tax_code: None, + unit_label: None, + }), + recurring: Some(stripe::CreatePriceRecurring { + aggregate_usage: None, + interval: stripe::CreatePriceRecurringInterval::Month, + interval_count: None, + trial_period_days: None, + usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), + meter: Some(meter.id.clone()), + }), + tax_behavior: None, + tiers: None, + tiers_mode: None, + transfer_lookup_key: None, + transform_quantity: None, + unit_amount: None, + unit_amount_decimal: Some(&format!( + "{:.12}", + price_per_million_tokens.0 as f64 / 1_000_000f64 + )), + }, + ) + .await?; + self.price_ids_by_meter_id + .write() + .await + .insert(meter.id, price.id.clone()); + price.id + }; Ok(StripeBillingPrice { id: price_id, diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 484940c527..210a049e0b 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -639,6 +639,7 @@ impl TestServer { live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())), blob_store_client: None, stripe_client: None, + stripe_billing: None, rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())), executor, clickhouse_client: None,