collab: Sync model request overages to Stripe (#29583)

This PR adds syncing of model request overages to Stripe.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-28 23:06:30 -04:00 committed by GitHub
parent 3a212e72a4
commit 5092f0f18b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 318 additions and 16 deletions

View file

@ -1,12 +1,13 @@
use std::sync::Arc;
use crate::{Cents, Result, llm};
use anyhow::Context as _;
use anyhow::{Context as _, anyhow};
use chrono::{Datelike, Utc};
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::PriceId;
use tokio::sync::RwLock;
use uuid::Uuid;
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
@ -17,9 +18,10 @@ pub struct StripeBilling {
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
prices_by_lookup_key: HashMap<String, stripe::Price>,
}
pub struct StripeModel {
pub struct StripeModelTokenPrices {
input_tokens_price: StripeBillingPrice,
input_cache_creation_tokens_price: StripeBillingPrice,
input_cache_read_tokens_price: StripeBillingPrice,
@ -62,6 +64,10 @@ impl StripeBilling {
}
for price in prices.data {
if let Some(lookup_key) = price.lookup_key.clone() {
state.prices_by_lookup_key.insert(lookup_key, price.clone());
}
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
state.price_ids_by_meter_id.insert(meter, price.id);
@ -74,36 +80,49 @@ impl StripeBilling {
Ok(())
}
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.cloned()
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn register_model_for_token_based_usage(
&self,
model: &llm::db::model::Model,
) -> Result<StripeModelTokenPrices> {
let input_tokens_price = self
.get_or_insert_price(
.get_or_insert_token_price(
&format!("model_{}/input_tokens", model.id),
&format!("{} (Input Tokens)", model.name),
Cents::new(model.price_per_million_input_tokens as u32),
)
.await?;
let input_cache_creation_tokens_price = self
.get_or_insert_price(
.get_or_insert_token_price(
&format!("model_{}/input_cache_creation_tokens", model.id),
&format!("{} (Input Cache Creation Tokens)", model.name),
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
)
.await?;
let input_cache_read_tokens_price = self
.get_or_insert_price(
.get_or_insert_token_price(
&format!("model_{}/input_cache_read_tokens", model.id),
&format!("{} (Input Cache Read Tokens)", model.name),
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
)
.await?;
let output_tokens_price = self
.get_or_insert_price(
.get_or_insert_token_price(
&format!("model_{}/output_tokens", model.id),
&format!("{} (Output Tokens)", model.name),
Cents::new(model.price_per_million_output_tokens as u32),
)
.await?;
Ok(StripeModel {
Ok(StripeModelTokenPrices {
input_tokens_price,
input_cache_creation_tokens_price,
input_cache_read_tokens_price,
@ -111,7 +130,7 @@ impl StripeBilling {
})
}
async fn get_or_insert_price(
async fn get_or_insert_token_price(
&self,
meter_event_name: &str,
price_description: &str,
@ -207,10 +226,43 @@ impl StripeBilling {
})
}
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
price_id: &stripe::PriceId,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, price_id) {
return Ok(());
}
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
price: Some(price_id.to_string()),
..Default::default()
}]),
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
..Default::default()
},
)
.await?;
Ok(())
}
pub async fn subscribe_to_model(
&self,
subscription_id: &stripe::SubscriptionId,
model: &StripeModel,
model: &StripeModelTokenPrices,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
@ -271,7 +323,7 @@ impl StripeBilling {
pub async fn bill_model_token_usage(
&self,
customer_id: &stripe::CustomerId,
model: &StripeModel,
model: &StripeModelTokenPrices,
event: &llm::db::billing_event::Model,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
@ -343,11 +395,37 @@ impl StripeBilling {
Ok(())
}
pub async fn bill_model_request_usage(
&self,
customer_id: &stripe::CustomerId,
event_name: &str,
requests: i32,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
let idempotency_key = Uuid::new_v4();
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
payload: StripeCreateMeterEventPayload {
value: requests as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
Ok(())
}
pub async fn checkout(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
model: &StripeModel,
model: &StripeModelTokenPrices,
success_url: &str,
) -> Result<String> {
let first_of_next_month = Utc::now()