collab: Make the StripeBilling object long-lived (#19090)

This PR makes the `StripeBilling` object long-lived so that we can make
better use of the cached data on it.

We now hold it on the `AppState` and spawn a background task to
initialize the cache on startup.

Release Notes:

- N/A

Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
Marshall Bowers 2024-10-11 15:15:08 -04:00 committed by GitHub
parent 550064f80f
commit f280b29859
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 129 additions and 92 deletions

View file

@ -5,10 +5,11 @@ use anyhow::Context;
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
pub struct StripeBilling {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
client: Arc<stripe::Client>,
}
@ -25,32 +26,43 @@ struct StripeBillingPrice {
}
impl StripeBilling {
pub async fn new(client: Arc<stripe::Client>) -> Result<Self> {
let mut meters_by_event_name = HashMap::default();
for meter in StripeMeter::list(&client).await?.data {
meters_by_event_name.insert(meter.event_name.clone(), meter);
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
client,
meters_by_event_name: RwLock::new(HashMap::default()),
price_ids_by_meter_id: RwLock::new(HashMap::default()),
}
}
pub async fn initialize(&self) -> Result<()> {
log::info!("initializing StripeBilling");
{
let meters = StripeMeter::list(&self.client).await?.data;
let mut meters_by_event_name = self.meters_by_event_name.write().await;
for meter in meters {
meters_by_event_name.insert(meter.event_name.clone(), meter);
}
}
let mut price_ids_by_meter_id = HashMap::default();
for price in stripe::Price::list(&client, &stripe::ListPrices::default())
.await?
.data
{
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
price_ids_by_meter_id.insert(meter, price.id);
let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
.await?
.data;
let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
for price in prices {
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
price_ids_by_meter_id.insert(meter, price.id);
}
}
}
}
Ok(Self {
meters_by_event_name,
price_ids_by_meter_id,
client,
})
Ok(())
}
pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> {
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
let input_tokens_price = self
.get_or_insert_price(
&format!("model_{}/input_tokens", model.id),
@ -88,78 +100,84 @@ impl StripeBilling {
}
async fn get_or_insert_price(
&mut self,
&self,
meter_event_name: &str,
price_description: &str,
price_per_million_tokens: Cents,
) -> Result<StripeBillingPrice> {
let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) {
meter.clone()
} else {
let meter = StripeMeter::create(
&self.client,
StripeCreateMeterParams {
default_aggregation: DefaultAggregation { formula: "sum" },
display_name: price_description.to_string(),
event_name: meter_event_name,
},
)
.await?;
self.meters_by_event_name
.insert(meter_event_name.to_string(), meter.clone());
meter
};
let meter =
if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
meter.clone()
} else {
let meter = StripeMeter::create(
&self.client,
StripeCreateMeterParams {
default_aggregation: DefaultAggregation { formula: "sum" },
display_name: price_description.to_string(),
event_name: meter_event_name,
},
)
.await?;
self.meters_by_event_name
.write()
.await
.insert(meter_event_name.to_string(), meter.clone());
meter
};
let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) {
price_id.clone()
} else {
let price = stripe::Price::create(
&self.client,
stripe::CreatePrice {
active: Some(true),
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None,
nickname: None,
product: None,
product_data: Some(stripe::CreatePriceProductData {
id: None,
let price_id =
if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
price_id.clone()
} else {
let price = stripe::Price::create(
&self.client,
stripe::CreatePrice {
active: Some(true),
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None,
name: price_description.to_string(),
statement_descriptor: None,
tax_code: None,
unit_label: None,
}),
recurring: Some(stripe::CreatePriceRecurring {
aggregate_usage: None,
interval: stripe::CreatePriceRecurringInterval::Month,
interval_count: None,
trial_period_days: None,
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
meter: Some(meter.id.clone()),
}),
tax_behavior: None,
tiers: None,
tiers_mode: None,
transfer_lookup_key: None,
transform_quantity: None,
unit_amount: None,
unit_amount_decimal: Some(&format!(
"{:.12}",
price_per_million_tokens.0 as f64 / 1_000_000f64
)),
},
)
.await?;
self.price_ids_by_meter_id
.insert(meter.id, price.id.clone());
price.id
};
nickname: None,
product: None,
product_data: Some(stripe::CreatePriceProductData {
id: None,
active: Some(true),
metadata: None,
name: price_description.to_string(),
statement_descriptor: None,
tax_code: None,
unit_label: None,
}),
recurring: Some(stripe::CreatePriceRecurring {
aggregate_usage: None,
interval: stripe::CreatePriceRecurringInterval::Month,
interval_count: None,
trial_period_days: None,
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
meter: Some(meter.id.clone()),
}),
tax_behavior: None,
tiers: None,
tiers_mode: None,
transfer_lookup_key: None,
transform_quantity: None,
unit_amount: None,
unit_amount_decimal: Some(&format!(
"{:.12}",
price_per_million_tokens.0 as f64 / 1_000_000f64
)),
},
)
.await?;
self.price_ids_by_meter_id
.write()
.await
.insert(meter.id, price.id.clone());
price.id
};
Ok(StripeBillingPrice {
id: price_id,