use std::sync::Arc; use crate::{Cents, Result, llm}; use anyhow::Context as _; use chrono::{Datelike, Utc}; use collections::HashMap; use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; pub struct StripeBilling { state: RwLock, client: Arc, } #[derive(Default)] struct StripeBillingState { meters_by_event_name: HashMap, price_ids_by_meter_id: HashMap, } pub struct StripeModel { input_tokens_price: StripeBillingPrice, input_cache_creation_tokens_price: StripeBillingPrice, input_cache_read_tokens_price: StripeBillingPrice, output_tokens_price: StripeBillingPrice, } struct StripeBillingPrice { id: stripe::PriceId, meter_event_name: String, } impl StripeBilling { pub fn new(client: Arc) -> Self { Self { client, state: RwLock::default(), } } pub async fn initialize(&self) -> Result<()> { log::info!("StripeBilling: initializing"); let mut state = self.state.write().await; let (meters, prices) = futures::try_join!( StripeMeter::list(&self.client), stripe::Price::list( &self.client, &stripe::ListPrices { limit: Some(100), ..Default::default() } ) )?; for meter in meters.data { state .meters_by_event_name .insert(meter.event_name.clone(), meter); } for price in prices.data { if let Some(recurring) = price.recurring { if let Some(meter) = recurring.meter { state.price_ids_by_meter_id.insert(meter, price.id); } } } log::info!("StripeBilling: initialized"); Ok(()) } pub async fn register_model(&self, model: &llm::db::model::Model) -> Result { let input_tokens_price = self .get_or_insert_price( &format!("model_{}/input_tokens", model.id), &format!("{} (Input Tokens)", model.name), Cents::new(model.price_per_million_input_tokens as u32), ) .await?; let input_cache_creation_tokens_price = self .get_or_insert_price( &format!("model_{}/input_cache_creation_tokens", model.id), &format!("{} (Input Cache Creation Tokens)", model.name), Cents::new(model.price_per_million_cache_creation_input_tokens as u32), ) .await?; let input_cache_read_tokens_price = self .get_or_insert_price( &format!("model_{}/input_cache_read_tokens", model.id), &format!("{} (Input Cache Read Tokens)", model.name), Cents::new(model.price_per_million_cache_read_input_tokens as u32), ) .await?; let output_tokens_price = self .get_or_insert_price( &format!("model_{}/output_tokens", model.id), &format!("{} (Output Tokens)", model.name), Cents::new(model.price_per_million_output_tokens as u32), ) .await?; Ok(StripeModel { input_tokens_price, input_cache_creation_tokens_price, input_cache_read_tokens_price, output_tokens_price, }) } async fn get_or_insert_price( &self, meter_event_name: &str, price_description: &str, price_per_million_tokens: Cents, ) -> Result { // Fast code path when the meter and the price already exist. { let state = self.state.read().await; if let Some(meter) = state.meters_by_event_name.get(meter_event_name) { if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) { return Ok(StripeBillingPrice { id: price_id.clone(), meter_event_name: meter_event_name.to_string(), }); } } } let mut state = self.state.write().await; let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) { meter.clone() } else { let meter = StripeMeter::create( &self.client, StripeCreateMeterParams { default_aggregation: DefaultAggregation { formula: "sum" }, display_name: price_description.to_string(), event_name: meter_event_name, }, ) .await?; state .meters_by_event_name .insert(meter_event_name.to_string(), meter.clone()); meter }; let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) { price_id.clone() } else { let price = stripe::Price::create( &self.client, stripe::CreatePrice { active: Some(true), billing_scheme: Some(stripe::PriceBillingScheme::PerUnit), currency: stripe::Currency::USD, currency_options: None, custom_unit_amount: None, expand: &[], lookup_key: None, metadata: None, nickname: None, product: None, product_data: Some(stripe::CreatePriceProductData { id: None, active: Some(true), metadata: None, name: price_description.to_string(), statement_descriptor: None, tax_code: None, unit_label: None, }), recurring: Some(stripe::CreatePriceRecurring { aggregate_usage: None, interval: stripe::CreatePriceRecurringInterval::Month, interval_count: None, trial_period_days: None, usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), meter: Some(meter.id.clone()), }), tax_behavior: None, tiers: None, tiers_mode: None, transfer_lookup_key: None, transform_quantity: None, unit_amount: None, unit_amount_decimal: Some(&format!( "{:.12}", price_per_million_tokens.0 as f64 / 1_000_000f64 )), }, ) .await?; state .price_ids_by_meter_id .insert(meter.id, price.id.clone()); price.id }; Ok(StripeBillingPrice { id: price_id, meter_event_name: meter_event_name.to_string(), }) } pub async fn subscribe_to_model( &self, subscription_id: &stripe::SubscriptionId, model: &StripeModel, ) -> Result<()> { let subscription = stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?; let mut items = Vec::new(); if !subscription_contains_price(&subscription, &model.input_tokens_price.id) { items.push(stripe::UpdateSubscriptionItems { price: Some(model.input_tokens_price.id.to_string()), ..Default::default() }); } if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id) { items.push(stripe::UpdateSubscriptionItems { price: Some(model.input_cache_creation_tokens_price.id.to_string()), ..Default::default() }); } if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) { items.push(stripe::UpdateSubscriptionItems { price: Some(model.input_cache_read_tokens_price.id.to_string()), ..Default::default() }); } if !subscription_contains_price(&subscription, &model.output_tokens_price.id) { items.push(stripe::UpdateSubscriptionItems { price: Some(model.output_tokens_price.id.to_string()), ..Default::default() }); } if !items.is_empty() { items.extend(subscription.items.data.iter().map(|item| { stripe::UpdateSubscriptionItems { id: Some(item.id.to_string()), ..Default::default() } })); stripe::Subscription::update( &self.client, subscription_id, stripe::UpdateSubscription { items: Some(items), ..Default::default() }, ) .await?; } Ok(()) } pub async fn bill_model_usage( &self, customer_id: &stripe::CustomerId, model: &StripeModel, event: &llm::db::billing_event::Model, ) -> Result<()> { let timestamp = Utc::now().timestamp(); if event.input_tokens > 0 { StripeMeterEvent::create( &self.client, StripeCreateMeterEventParams { identifier: &format!("input_tokens/{}", event.idempotency_key), event_name: &model.input_tokens_price.meter_event_name, payload: StripeCreateMeterEventPayload { value: event.input_tokens as u64, stripe_customer_id: customer_id, }, timestamp: Some(timestamp), }, ) .await?; } if event.input_cache_creation_tokens > 0 { StripeMeterEvent::create( &self.client, StripeCreateMeterEventParams { identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key), event_name: &model.input_cache_creation_tokens_price.meter_event_name, payload: StripeCreateMeterEventPayload { value: event.input_cache_creation_tokens as u64, stripe_customer_id: customer_id, }, timestamp: Some(timestamp), }, ) .await?; } if event.input_cache_read_tokens > 0 { StripeMeterEvent::create( &self.client, StripeCreateMeterEventParams { identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key), event_name: &model.input_cache_read_tokens_price.meter_event_name, payload: StripeCreateMeterEventPayload { value: event.input_cache_read_tokens as u64, stripe_customer_id: customer_id, }, timestamp: Some(timestamp), }, ) .await?; } if event.output_tokens > 0 { StripeMeterEvent::create( &self.client, StripeCreateMeterEventParams { identifier: &format!("output_tokens/{}", event.idempotency_key), event_name: &model.output_tokens_price.meter_event_name, payload: StripeCreateMeterEventPayload { value: event.output_tokens as u64, stripe_customer_id: customer_id, }, timestamp: Some(timestamp), }, ) .await?; } Ok(()) } pub async fn checkout( &self, customer_id: stripe::CustomerId, github_login: &str, model: &StripeModel, success_url: &str, ) -> Result { let first_of_next_month = Utc::now() .checked_add_months(chrono::Months::new(1)) .unwrap() .with_day(1) .unwrap(); let mut params = stripe::CreateCheckoutSession::new(); params.mode = Some(stripe::CheckoutSessionMode::Subscription); params.customer = Some(customer_id); params.client_reference_id = Some(github_login); params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData { billing_cycle_anchor: Some(first_of_next_month.timestamp()), ..Default::default() }); params.line_items = Some( [ &model.input_tokens_price.id, &model.input_cache_creation_tokens_price.id, &model.input_cache_read_tokens_price.id, &model.output_tokens_price.id, ] .into_iter() .map(|price_id| stripe::CreateCheckoutSessionLineItems { price: Some(price_id.to_string()), ..Default::default() }) .collect(), ); params.success_url = Some(success_url); let session = stripe::CheckoutSession::create(&self.client, params).await?; Ok(session.url.context("no checkout session URL")?) } } #[derive(Serialize)] struct DefaultAggregation { formula: &'static str, } #[derive(Serialize)] struct StripeCreateMeterParams<'a> { default_aggregation: DefaultAggregation, display_name: String, event_name: &'a str, } #[derive(Clone, Deserialize)] struct StripeMeter { id: String, event_name: String, } impl StripeMeter { pub fn create( client: &stripe::Client, params: StripeCreateMeterParams, ) -> stripe::Response { client.post_form("/billing/meters", params) } pub fn list(client: &stripe::Client) -> stripe::Response> { #[derive(Serialize)] struct Params { #[serde(skip_serializing_if = "Option::is_none")] limit: Option, } client.get_query("/billing/meters", Params { limit: Some(100) }) } } #[derive(Deserialize)] struct StripeMeterEvent { identifier: String, } impl StripeMeterEvent { pub async fn create( client: &stripe::Client, params: StripeCreateMeterEventParams<'_>, ) -> Result { let identifier = params.identifier; match client.post_form("/billing/meter_events", params).await { Ok(event) => Ok(event), Err(stripe::StripeError::Stripe(error)) => { if error.http_status == 400 && error .message .as_ref() .map_or(false, |message| message.contains(identifier)) { Ok(Self { identifier: identifier.to_string(), }) } else { Err(stripe::StripeError::Stripe(error)) } } Err(error) => Err(error), } } } #[derive(Serialize)] struct StripeCreateMeterEventParams<'a> { identifier: &'a str, event_name: &'a str, payload: StripeCreateMeterEventPayload<'a>, timestamp: Option, } #[derive(Serialize)] struct StripeCreateMeterEventPayload<'a> { value: u64, stripe_customer_id: &'a stripe::CustomerId, } fn subscription_contains_price( subscription: &stripe::Subscription, price_id: &stripe::PriceId, ) -> bool { subscription.items.data.iter().any(|item| { item.price .as_ref() .map_or(false, |price| price.id == *price_id) }) }