collab: Add billing thresholds to request overage subscription items (#29738)

This PR adds billing thresholds of the unit equivalent of $20 for model
request overages.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-01 12:10:06 -04:00 committed by GitHub
parent 5bf1b4f0a8
commit 57610c9935
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 8 deletions

View file

@ -1403,13 +1403,13 @@ async fn sync_model_request_usage_with_stripe(
.await?; .await?;
let claude_3_5_sonnet = stripe_billing let claude_3_5_sonnet = stripe_billing
.find_price_id_by_lookup_key("claude-3-5-sonnet-requests") .find_price_by_lookup_key("claude-3-5-sonnet-requests")
.await?; .await?;
let claude_3_7_sonnet = stripe_billing let claude_3_7_sonnet = stripe_billing
.find_price_id_by_lookup_key("claude-3-7-sonnet-requests") .find_price_by_lookup_key("claude-3-7-sonnet-requests")
.await?; .await?;
let claude_3_7_sonnet_max = stripe_billing let claude_3_7_sonnet_max = stripe_billing
.find_price_id_by_lookup_key("claude-3-7-sonnet-requests-max") .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?; .await?;
for (usage_meter, usage) in usage_meters { for (usage_meter, usage) in usage_meters {
@ -1434,7 +1434,7 @@ async fn sync_model_request_usage_with_stripe(
let model = llm_db.model_by_id(usage_meter.model_id)?; let model = llm_db.model_by_id(usage_meter.model_id)?;
let (price_id, meter_event_name) = match model.name.as_str() { let (price, meter_event_name) = match model.name.as_str() {
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
"claude-3-7-sonnet" => match usage_meter.mode { "claude-3-7-sonnet" => match usage_meter.mode {
CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"), CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
@ -1448,7 +1448,7 @@ async fn sync_model_request_usage_with_stripe(
}; };
stripe_billing stripe_billing
.subscribe_to_price(&stripe_subscription_id, price_id) .subscribe_to_price(&stripe_subscription_id, price)
.await?; .await?;
stripe_billing stripe_billing
.bill_model_request_usage( .bill_model_request_usage(

View file

@ -99,6 +99,16 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}"))) .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
} }
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.cloned()
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
}
pub async fn register_model_for_token_based_usage( pub async fn register_model_for_token_based_usage(
&self, &self,
model: &llm::db::model::Model, model: &llm::db::model::Model,
@ -238,21 +248,29 @@ impl StripeBilling {
pub async fn subscribe_to_price( pub async fn subscribe_to_price(
&self, &self,
subscription_id: &stripe::SubscriptionId, subscription_id: &stripe::SubscriptionId,
price_id: &stripe::PriceId, price: &stripe::Price,
) -> Result<()> { ) -> Result<()> {
let subscription = let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?; stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, price_id) { if subscription_contains_price(&subscription, &price.id) {
return Ok(()); return Ok(());
} }
const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
let price_per_unit = price.unit_amount.unwrap_or_default();
let units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
stripe::Subscription::update( stripe::Subscription::update(
&self.client, &self.client,
subscription_id, subscription_id,
stripe::UpdateSubscription { stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems { items: Some(vec![stripe::UpdateSubscriptionItems {
price: Some(price_id.to_string()), price: Some(price.id.to_string()),
billing_thresholds: Some(stripe::SubscriptionItemBillingThresholds {
usage_gte: Some(units_for_billing_threshold),
}),
..Default::default() ..Default::default()
}]), }]),
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings { trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {