collab: Make the StripeBilling object long-lived (#19090)

This PR makes the `StripeBilling` object long-lived so that we can make
better use of the cached data on it.

We now hold it on the `AppState` and spawn a background task to
initialize the cache on startup.

Release Notes:

- N/A

Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Marshall Bowers 2024-10-11 15:15:08 -04:00 committed by GitHub
parent 550064f80f
commit f280b29859
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 129 additions and 92 deletions

View file

@ -210,6 +210,13 @@ async fn create_billing_subscription(
"not supported".into(), "not supported".into(),
))? ))?
}; };
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::error!("failed to retrieve Stripe billing object");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(llm_db) = app.llm_db.clone() else { let Some(llm_db) = app.llm_db.clone() else {
log::error!("failed to retrieve LLM database"); log::error!("failed to retrieve LLM database");
Err(Error::http( Err(Error::http(
@ -236,7 +243,6 @@ async fn create_billing_subscription(
}; };
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?; let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
let stripe_model = stripe_billing.register_model(default_model).await?; let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!("{}/account", app.config.zed_dot_dev_url()); let success_url = format!("{}/account", app.config.zed_dot_dev_url());
let checkout_session_url = stripe_billing let checkout_session_url = stripe_billing
@ -716,8 +722,8 @@ async fn find_or_create_billing_customer(
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) { pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_client) = app.stripe_client.clone() else { let Some(stripe_billing) = app.stripe_billing.clone() else {
log::warn!("failed to retrieve Stripe client"); log::warn!("failed to retrieve Stripe billing object");
return; return;
}; };
let Some(llm_db) = app.llm_db.clone() else { let Some(llm_db) = app.llm_db.clone() else {
@ -730,7 +736,7 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let executor = executor.clone(); let executor = executor.clone();
async move { async move {
loop { loop {
sync_with_stripe(&app, &llm_db, &stripe_client) sync_with_stripe(&app, &llm_db, &stripe_billing)
.await .await
.trace_err(); .trace_err();
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await; executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
@ -742,10 +748,8 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
async fn sync_with_stripe( async fn sync_with_stripe(
app: &Arc<AppState>, app: &Arc<AppState>,
llm_db: &Arc<LlmDatabase>, llm_db: &Arc<LlmDatabase>,
stripe_client: &Arc<stripe::Client>, stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
let events = llm_db.get_billing_events().await?; let events = llm_db.get_billing_events().await?;
let user_ids = events let user_ids = events
.iter() .iter()

View file

@ -31,6 +31,8 @@ use serde::Deserialize;
use std::{path::PathBuf, sync::Arc}; use std::{path::PathBuf, sync::Arc};
use util::ResultExt; use util::ResultExt;
use crate::stripe_billing::StripeBilling;
pub type Result<T, E = Error> = std::result::Result<T, E>; pub type Result<T, E = Error> = std::result::Result<T, E>;
pub enum Error { pub enum Error {
@ -274,6 +276,7 @@ pub struct AppState {
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>, pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>, pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>, pub stripe_client: Option<Arc<stripe::Client>>,
pub stripe_billing: Option<Arc<StripeBilling>>,
pub rate_limiter: Arc<RateLimiter>, pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor, pub executor: Executor,
pub clickhouse_client: Option<::clickhouse::Client>, pub clickhouse_client: Option<::clickhouse::Client>,
@ -317,12 +320,16 @@ impl AppState {
}; };
let db = Arc::new(db); let db = Arc::new(db);
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
let this = Self { let this = Self {
db: db.clone(), db: db.clone(),
llm_db, llm_db,
live_kit_client, live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(), blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_client: build_stripe_client(&config).map(Arc::new).log_err(), stripe_billing: stripe_client
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)), rate_limiter: Arc::new(RateLimiter::new(db)),
executor, executor,
clickhouse_client: config clickhouse_client: config

View file

@ -111,6 +111,13 @@ async fn main() -> Result<()> {
let state = AppState::new(config, Executor::Production).await?; let state = AppState::new(config, Executor::Production).await?;
if let Some(stripe_billing) = state.stripe_billing.clone() {
let executor = state.executor.clone();
executor.spawn_detached(async move {
stripe_billing.initialize().await.trace_err();
});
}
if mode.is_collab() { if mode.is_collab() {
state.db.purge_old_embeddings().await.trace_err(); state.db.purge_old_embeddings().await.trace_err();
RateLimiter::save_periodically( RateLimiter::save_periodically(

View file

@ -5,10 +5,11 @@ use anyhow::Context;
use chrono::Utc; use chrono::Utc;
use collections::HashMap; use collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
pub struct StripeBilling { pub struct StripeBilling {
meters_by_event_name: HashMap<String, StripeMeter>, meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>, price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
client: Arc<stripe::Client>, client: Arc<stripe::Client>,
} }
@ -25,32 +26,43 @@ struct StripeBillingPrice {
} }
impl StripeBilling { impl StripeBilling {
pub async fn new(client: Arc<stripe::Client>) -> Result<Self> { pub fn new(client: Arc<stripe::Client>) -> Self {
let mut meters_by_event_name = HashMap::default(); Self {
for meter in StripeMeter::list(&client).await?.data { client,
meters_by_event_name.insert(meter.event_name.clone(), meter); meters_by_event_name: RwLock::new(HashMap::default()),
price_ids_by_meter_id: RwLock::new(HashMap::default()),
}
}
pub async fn initialize(&self) -> Result<()> {
log::info!("initializing StripeBilling");
{
let meters = StripeMeter::list(&self.client).await?.data;
let mut meters_by_event_name = self.meters_by_event_name.write().await;
for meter in meters {
meters_by_event_name.insert(meter.event_name.clone(), meter);
}
} }
let mut price_ids_by_meter_id = HashMap::default();
for price in stripe::Price::list(&client, &stripe::ListPrices::default())
.await?
.data
{ {
if let Some(recurring) = price.recurring { let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
if let Some(meter) = recurring.meter { .await?
price_ids_by_meter_id.insert(meter, price.id); .data;
let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
for price in prices {
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
price_ids_by_meter_id.insert(meter, price.id);
}
} }
} }
} }
Ok(Self { Ok(())
meters_by_event_name,
price_ids_by_meter_id,
client,
})
} }
pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> { pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
let input_tokens_price = self let input_tokens_price = self
.get_or_insert_price( .get_or_insert_price(
&format!("model_{}/input_tokens", model.id), &format!("model_{}/input_tokens", model.id),
@ -88,78 +100,84 @@ impl StripeBilling {
} }
async fn get_or_insert_price( async fn get_or_insert_price(
&mut self, &self,
meter_event_name: &str, meter_event_name: &str,
price_description: &str, price_description: &str,
price_per_million_tokens: Cents, price_per_million_tokens: Cents,
) -> Result<StripeBillingPrice> { ) -> Result<StripeBillingPrice> {
let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) { let meter =
meter.clone() if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
} else { meter.clone()
let meter = StripeMeter::create( } else {
&self.client, let meter = StripeMeter::create(
StripeCreateMeterParams { &self.client,
default_aggregation: DefaultAggregation { formula: "sum" }, StripeCreateMeterParams {
display_name: price_description.to_string(), default_aggregation: DefaultAggregation { formula: "sum" },
event_name: meter_event_name, display_name: price_description.to_string(),
}, event_name: meter_event_name,
) },
.await?; )
self.meters_by_event_name .await?;
.insert(meter_event_name.to_string(), meter.clone()); self.meters_by_event_name
meter .write()
}; .await
.insert(meter_event_name.to_string(), meter.clone());
meter
};
let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) { let price_id =
price_id.clone() if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
} else { price_id.clone()
let price = stripe::Price::create( } else {
&self.client, let price = stripe::Price::create(
stripe::CreatePrice { &self.client,
active: Some(true), stripe::CreatePrice {
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None,
nickname: None,
product: None,
product_data: Some(stripe::CreatePriceProductData {
id: None,
active: Some(true), active: Some(true),
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None, metadata: None,
name: price_description.to_string(), nickname: None,
statement_descriptor: None, product: None,
tax_code: None, product_data: Some(stripe::CreatePriceProductData {
unit_label: None, id: None,
}), active: Some(true),
recurring: Some(stripe::CreatePriceRecurring { metadata: None,
aggregate_usage: None, name: price_description.to_string(),
interval: stripe::CreatePriceRecurringInterval::Month, statement_descriptor: None,
interval_count: None, tax_code: None,
trial_period_days: None, unit_label: None,
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), }),
meter: Some(meter.id.clone()), recurring: Some(stripe::CreatePriceRecurring {
}), aggregate_usage: None,
tax_behavior: None, interval: stripe::CreatePriceRecurringInterval::Month,
tiers: None, interval_count: None,
tiers_mode: None, trial_period_days: None,
transfer_lookup_key: None, usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
transform_quantity: None, meter: Some(meter.id.clone()),
unit_amount: None, }),
unit_amount_decimal: Some(&format!( tax_behavior: None,
"{:.12}", tiers: None,
price_per_million_tokens.0 as f64 / 1_000_000f64 tiers_mode: None,
)), transfer_lookup_key: None,
}, transform_quantity: None,
) unit_amount: None,
.await?; unit_amount_decimal: Some(&format!(
self.price_ids_by_meter_id "{:.12}",
.insert(meter.id, price.id.clone()); price_per_million_tokens.0 as f64 / 1_000_000f64
price.id )),
}; },
)
.await?;
self.price_ids_by_meter_id
.write()
.await
.insert(meter.id, price.id.clone());
price.id
};
Ok(StripeBillingPrice { Ok(StripeBillingPrice {
id: price_id, id: price_id,

View file

@ -639,6 +639,7 @@ impl TestServer {
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())), live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
blob_store_client: None, blob_store_client: None,
stripe_client: None, stripe_client: None,
stripe_billing: None,
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())), rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
executor, executor,
clickhouse_client: None, clickhouse_client: None,