Prevent deadlock when create a new meter/price on Stripe (#19196)

This also puts the entire state of `StripeBilling` behind a `RwLock`.
When fetching the existing prices and meters, or when inserting new
ones, we acquire a write lock and hold it until the Stripe request
completes. This prevents two concurrent calls to `get_or_insert_price`
from inserting the same data twice.

Creating a new meter/price is unusual, so in practice we'll acquire a
read lock most of the time.

/cc @rtfeldman @maxdeviant 

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2024-10-14 17:31:51 +02:00 committed by GitHub
parent 6986f081d0
commit 6e2869a321
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,11 +8,16 @@ use serde::{Deserialize, Serialize};
use tokio::sync::RwLock; use tokio::sync::RwLock;
pub struct StripeBilling { pub struct StripeBilling {
meters_by_event_name: RwLock<HashMap<String, StripeMeter>>, state: RwLock<StripeBillingState>,
price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
client: Arc<stripe::Client>, client: Arc<stripe::Client>,
} }
#[derive(Default)]
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
}
pub struct StripeModel { pub struct StripeModel {
input_tokens_price: StripeBillingPrice, input_tokens_price: StripeBillingPrice,
input_cache_creation_tokens_price: StripeBillingPrice, input_cache_creation_tokens_price: StripeBillingPrice,
@ -29,36 +34,36 @@ impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self { pub fn new(client: Arc<stripe::Client>) -> Self {
Self { Self {
client, client,
meters_by_event_name: RwLock::new(HashMap::default()), state: RwLock::default(),
price_ids_by_meter_id: RwLock::new(HashMap::default()),
} }
} }
pub async fn initialize(&self) -> Result<()> { pub async fn initialize(&self) -> Result<()> {
log::info!("initializing StripeBilling"); log::info!("StripeBilling: initializing");
{ let mut state = self.state.write().await;
let meters = StripeMeter::list(&self.client).await?.data;
let mut meters_by_event_name = self.meters_by_event_name.write().await; let (meters, prices) = futures::try_join!(
for meter in meters { StripeMeter::list(&self.client),
meters_by_event_name.insert(meter.event_name.clone(), meter); stripe::Price::list(&self.client, &stripe::ListPrices::default())
} )?;
for meter in meters.data {
state
.meters_by_event_name
.insert(meter.event_name.clone(), meter);
} }
{ for price in prices.data {
let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default()) if let Some(recurring) = price.recurring {
.await? if let Some(meter) = recurring.meter {
.data; state.price_ids_by_meter_id.insert(meter, price.id);
let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
for price in prices {
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
price_ids_by_meter_id.insert(meter, price.id);
}
} }
} }
} }
log::info!("StripeBilling: initialized");
Ok(()) Ok(())
} }
@ -105,79 +110,89 @@ impl StripeBilling {
price_description: &str, price_description: &str,
price_per_million_tokens: Cents, price_per_million_tokens: Cents,
) -> Result<StripeBillingPrice> { ) -> Result<StripeBillingPrice> {
let meter = // Fast code path when the meter and the price already exist.
if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) { {
meter.clone() let state = self.state.read().await;
} else { if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
let meter = StripeMeter::create( if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
&self.client, return Ok(StripeBillingPrice {
StripeCreateMeterParams { id: price_id.clone(),
default_aggregation: DefaultAggregation { formula: "sum" }, meter_event_name: meter_event_name.to_string(),
display_name: price_description.to_string(), });
event_name: meter_event_name, }
}, }
) }
.await?;
self.meters_by_event_name
.write()
.await
.insert(meter_event_name.to_string(), meter.clone());
meter
};
let price_id = let mut state = self.state.write().await;
if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) { let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
price_id.clone() meter.clone()
} else { } else {
let price = stripe::Price::create( let meter = StripeMeter::create(
&self.client, &self.client,
stripe::CreatePrice { StripeCreateMeterParams {
default_aggregation: DefaultAggregation { formula: "sum" },
display_name: price_description.to_string(),
event_name: meter_event_name,
},
)
.await?;
state
.meters_by_event_name
.insert(meter_event_name.to_string(), meter.clone());
meter
};
let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
price_id.clone()
} else {
let price = stripe::Price::create(
&self.client,
stripe::CreatePrice {
active: Some(true),
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None,
nickname: None,
product: None,
product_data: Some(stripe::CreatePriceProductData {
id: None,
active: Some(true), active: Some(true),
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None, metadata: None,
nickname: None, name: price_description.to_string(),
product: None, statement_descriptor: None,
product_data: Some(stripe::CreatePriceProductData { tax_code: None,
id: None, unit_label: None,
active: Some(true), }),
metadata: None, recurring: Some(stripe::CreatePriceRecurring {
name: price_description.to_string(), aggregate_usage: None,
statement_descriptor: None, interval: stripe::CreatePriceRecurringInterval::Month,
tax_code: None, interval_count: None,
unit_label: None, trial_period_days: None,
}), usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
recurring: Some(stripe::CreatePriceRecurring { meter: Some(meter.id.clone()),
aggregate_usage: None, }),
interval: stripe::CreatePriceRecurringInterval::Month, tax_behavior: None,
interval_count: None, tiers: None,
trial_period_days: None, tiers_mode: None,
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), transfer_lookup_key: None,
meter: Some(meter.id.clone()), transform_quantity: None,
}), unit_amount: None,
tax_behavior: None, unit_amount_decimal: Some(&format!(
tiers: None, "{:.12}",
tiers_mode: None, price_per_million_tokens.0 as f64 / 1_000_000f64
transfer_lookup_key: None, )),
transform_quantity: None, },
unit_amount: None, )
unit_amount_decimal: Some(&format!( .await?;
"{:.12}", state
price_per_million_tokens.0 as f64 / 1_000_000f64 .price_ids_by_meter_id
)), .insert(meter.id, price.id.clone());
}, price.id
) };
.await?;
self.price_ids_by_meter_id
.write()
.await
.insert(meter.id, price.id.clone());
price.id
};
Ok(StripeBillingPrice { Ok(StripeBillingPrice {
id: price_id, id: price_id,