Prevent deadlock when create a new meter/price on Stripe (#19196)
This also puts the entire state of `StripeBilling` behind a `RwLock`. When fetching the existing prices and meters, or when inserting new ones, we acquire a write lock and hold it until the Stripe request completes. This prevents two concurrent calls to `get_or_insert_price` from inserting the same data twice. Creating a new meter/price is unusual, so in practice we'll acquire a read lock most of the time. /cc @rtfeldman @maxdeviant Release Notes: - N/A
This commit is contained in:
parent
6986f081d0
commit
6e2869a321
1 changed files with 106 additions and 91 deletions
|
@ -8,11 +8,16 @@ use serde::{Deserialize, Serialize};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
pub struct StripeBilling {
|
pub struct StripeBilling {
|
||||||
meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
|
state: RwLock<StripeBillingState>,
|
||||||
price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
|
|
||||||
client: Arc<stripe::Client>,
|
client: Arc<stripe::Client>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct StripeBillingState {
|
||||||
|
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||||
|
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct StripeModel {
|
pub struct StripeModel {
|
||||||
input_tokens_price: StripeBillingPrice,
|
input_tokens_price: StripeBillingPrice,
|
||||||
input_cache_creation_tokens_price: StripeBillingPrice,
|
input_cache_creation_tokens_price: StripeBillingPrice,
|
||||||
|
@ -29,36 +34,36 @@ impl StripeBilling {
|
||||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
meters_by_event_name: RwLock::new(HashMap::default()),
|
state: RwLock::default(),
|
||||||
price_ids_by_meter_id: RwLock::new(HashMap::default()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn initialize(&self) -> Result<()> {
|
pub async fn initialize(&self) -> Result<()> {
|
||||||
log::info!("initializing StripeBilling");
|
log::info!("StripeBilling: initializing");
|
||||||
|
|
||||||
{
|
let mut state = self.state.write().await;
|
||||||
let meters = StripeMeter::list(&self.client).await?.data;
|
|
||||||
let mut meters_by_event_name = self.meters_by_event_name.write().await;
|
let (meters, prices) = futures::try_join!(
|
||||||
for meter in meters {
|
StripeMeter::list(&self.client),
|
||||||
meters_by_event_name.insert(meter.event_name.clone(), meter);
|
stripe::Price::list(&self.client, &stripe::ListPrices::default())
|
||||||
}
|
)?;
|
||||||
|
|
||||||
|
for meter in meters.data {
|
||||||
|
state
|
||||||
|
.meters_by_event_name
|
||||||
|
.insert(meter.event_name.clone(), meter);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
for price in prices.data {
|
||||||
let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
|
if let Some(recurring) = price.recurring {
|
||||||
.await?
|
if let Some(meter) = recurring.meter {
|
||||||
.data;
|
state.price_ids_by_meter_id.insert(meter, price.id);
|
||||||
let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
|
|
||||||
for price in prices {
|
|
||||||
if let Some(recurring) = price.recurring {
|
|
||||||
if let Some(meter) = recurring.meter {
|
|
||||||
price_ids_by_meter_id.insert(meter, price.id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log::info!("StripeBilling: initialized");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,79 +110,89 @@ impl StripeBilling {
|
||||||
price_description: &str,
|
price_description: &str,
|
||||||
price_per_million_tokens: Cents,
|
price_per_million_tokens: Cents,
|
||||||
) -> Result<StripeBillingPrice> {
|
) -> Result<StripeBillingPrice> {
|
||||||
let meter =
|
// Fast code path when the meter and the price already exist.
|
||||||
if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
|
{
|
||||||
meter.clone()
|
let state = self.state.read().await;
|
||||||
} else {
|
if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||||
let meter = StripeMeter::create(
|
if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||||
&self.client,
|
return Ok(StripeBillingPrice {
|
||||||
StripeCreateMeterParams {
|
id: price_id.clone(),
|
||||||
default_aggregation: DefaultAggregation { formula: "sum" },
|
meter_event_name: meter_event_name.to_string(),
|
||||||
display_name: price_description.to_string(),
|
});
|
||||||
event_name: meter_event_name,
|
}
|
||||||
},
|
}
|
||||||
)
|
}
|
||||||
.await?;
|
|
||||||
self.meters_by_event_name
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.insert(meter_event_name.to_string(), meter.clone());
|
|
||||||
meter
|
|
||||||
};
|
|
||||||
|
|
||||||
let price_id =
|
let mut state = self.state.write().await;
|
||||||
if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
|
let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||||
price_id.clone()
|
meter.clone()
|
||||||
} else {
|
} else {
|
||||||
let price = stripe::Price::create(
|
let meter = StripeMeter::create(
|
||||||
&self.client,
|
&self.client,
|
||||||
stripe::CreatePrice {
|
StripeCreateMeterParams {
|
||||||
|
default_aggregation: DefaultAggregation { formula: "sum" },
|
||||||
|
display_name: price_description.to_string(),
|
||||||
|
event_name: meter_event_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
state
|
||||||
|
.meters_by_event_name
|
||||||
|
.insert(meter_event_name.to_string(), meter.clone());
|
||||||
|
meter
|
||||||
|
};
|
||||||
|
|
||||||
|
let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||||
|
price_id.clone()
|
||||||
|
} else {
|
||||||
|
let price = stripe::Price::create(
|
||||||
|
&self.client,
|
||||||
|
stripe::CreatePrice {
|
||||||
|
active: Some(true),
|
||||||
|
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
||||||
|
currency: stripe::Currency::USD,
|
||||||
|
currency_options: None,
|
||||||
|
custom_unit_amount: None,
|
||||||
|
expand: &[],
|
||||||
|
lookup_key: None,
|
||||||
|
metadata: None,
|
||||||
|
nickname: None,
|
||||||
|
product: None,
|
||||||
|
product_data: Some(stripe::CreatePriceProductData {
|
||||||
|
id: None,
|
||||||
active: Some(true),
|
active: Some(true),
|
||||||
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
|
||||||
currency: stripe::Currency::USD,
|
|
||||||
currency_options: None,
|
|
||||||
custom_unit_amount: None,
|
|
||||||
expand: &[],
|
|
||||||
lookup_key: None,
|
|
||||||
metadata: None,
|
metadata: None,
|
||||||
nickname: None,
|
name: price_description.to_string(),
|
||||||
product: None,
|
statement_descriptor: None,
|
||||||
product_data: Some(stripe::CreatePriceProductData {
|
tax_code: None,
|
||||||
id: None,
|
unit_label: None,
|
||||||
active: Some(true),
|
}),
|
||||||
metadata: None,
|
recurring: Some(stripe::CreatePriceRecurring {
|
||||||
name: price_description.to_string(),
|
aggregate_usage: None,
|
||||||
statement_descriptor: None,
|
interval: stripe::CreatePriceRecurringInterval::Month,
|
||||||
tax_code: None,
|
interval_count: None,
|
||||||
unit_label: None,
|
trial_period_days: None,
|
||||||
}),
|
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
||||||
recurring: Some(stripe::CreatePriceRecurring {
|
meter: Some(meter.id.clone()),
|
||||||
aggregate_usage: None,
|
}),
|
||||||
interval: stripe::CreatePriceRecurringInterval::Month,
|
tax_behavior: None,
|
||||||
interval_count: None,
|
tiers: None,
|
||||||
trial_period_days: None,
|
tiers_mode: None,
|
||||||
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
transfer_lookup_key: None,
|
||||||
meter: Some(meter.id.clone()),
|
transform_quantity: None,
|
||||||
}),
|
unit_amount: None,
|
||||||
tax_behavior: None,
|
unit_amount_decimal: Some(&format!(
|
||||||
tiers: None,
|
"{:.12}",
|
||||||
tiers_mode: None,
|
price_per_million_tokens.0 as f64 / 1_000_000f64
|
||||||
transfer_lookup_key: None,
|
)),
|
||||||
transform_quantity: None,
|
},
|
||||||
unit_amount: None,
|
)
|
||||||
unit_amount_decimal: Some(&format!(
|
.await?;
|
||||||
"{:.12}",
|
state
|
||||||
price_per_million_tokens.0 as f64 / 1_000_000f64
|
.price_ids_by_meter_id
|
||||||
)),
|
.insert(meter.id, price.id.clone());
|
||||||
},
|
price.id
|
||||||
)
|
};
|
||||||
.await?;
|
|
||||||
self.price_ids_by_meter_id
|
|
||||||
.write()
|
|
||||||
.await
|
|
||||||
.insert(meter.id, price.id.clone());
|
|
||||||
price.id
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(StripeBillingPrice {
|
Ok(StripeBillingPrice {
|
||||||
id: price_id,
|
id: price_id,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue