collab: Sync model request overages to Stripe (#29583)

This PR adds syncing of model request overages to Stripe.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-28 23:06:30 -04:00 committed by GitHub
parent 3a212e72a4
commit 5092f0f18b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 318 additions and 16 deletions

View file

@ -393,7 +393,9 @@ async fn create_billing_subscription(
zed_llm_client::LanguageModelProvider::Anthropic,
"claude-3-7-sonnet",
)?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let stripe_model = stripe_billing
.register_model_for_token_based_usage(default_model)
.await?;
stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?
@ -1303,7 +1305,9 @@ async fn sync_token_usage_with_stripe(
.parse()
.context("failed to parse stripe customer id from db")?;
let stripe_model = stripe_billing.register_model(&model).await?;
let stripe_model = stripe_billing
.register_model_for_token_based_usage(&model)
.await?;
stripe_billing
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
.await?;
@ -1315,3 +1319,106 @@ async fn sync_token_usage_with_stripe(
Ok(())
}
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::warn!("failed to retrieve Stripe billing object");
return;
};
let Some(llm_db) = app.llm_db.clone() else {
log::warn!("failed to retrieve LLM database");
return;
};
let executor = app.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
loop {
sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
.await
.context("failed to sync LLM request usage to Stripe")
.trace_err();
executor
.sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
.await;
}
}
});
}
async fn sync_model_request_usage_with_stripe(
app: &Arc<AppState>,
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
let usage_meters = llm_db
.get_current_subscription_usage_meters(Utc::now())
.await?;
let user_ids = usage_meters
.iter()
.map(|(_, usage)| usage.user_id)
.collect::<HashSet<UserId>>();
let billing_subscriptions = app
.db
.get_active_zed_pro_billing_subscriptions(user_ids)
.await?;
let claude_3_5_sonnet = stripe_billing
.find_price_by_lookup_key("claude-3-5-sonnet-requests")
.await?;
let claude_3_7_sonnet = stripe_billing
.find_price_by_lookup_key("claude-3-7-sonnet-requests")
.await?;
for (usage_meter, usage) in usage_meters {
maybe!(async {
let Some((billing_customer, billing_subscription)) =
billing_subscriptions.get(&usage.user_id)
else {
bail!(
"Attempted to sync usage meter for user who is not a Stripe customer: {}",
usage.user_id
);
};
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_subscription_id = billing_subscription
.stripe_subscription_id
.parse::<stripe::SubscriptionId>()
.context("failed to parse Stripe subscription ID from database")?;
let model = llm_db.model_by_id(usage_meter.model_id)?;
let (price_id, meter_event_name) = match model.name.as_str() {
"claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
"claude-3-7-sonnet" => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
model_name => {
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
}
};
stripe_billing
.subscribe_to_price(&stripe_subscription_id, price_id)
.await?;
stripe_billing
.bill_model_request_usage(
&stripe_customer_id,
meter_event_name,
usage_meter.requests,
)
.await?;
Ok(())
})
.await
.log_err();
}
Ok(())
}

View file

@ -191,6 +191,38 @@ impl Database {
.await
}
pub async fn get_active_zed_pro_billing_subscriptions(
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.filter(billing_customer::Column::UserId.is_in(user_ids))
.filter(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;
let mut subscriptions = HashMap::default();
while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
subscriptions.insert(customer.user_id, (customer, subscription));
}
}
Ok(subscriptions)
}
})
.await
}
/// Returns whether the user has an active billing subscription.
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)

View file

@ -2,5 +2,6 @@ use super::*;
pub mod billing_events;
pub mod providers;
pub mod subscription_usage_meters;
pub mod subscription_usages;
pub mod usages;

View file

@ -0,0 +1,37 @@
use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
use super::*;
impl LlmDatabase {
/// Returns all current subscription usage meters as of the given timestamp.
pub async fn get_current_subscription_usage_meters(
&self,
now: DateTimeUtc,
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
let now = convert_chrono_to_time(now)?;
self.transaction(|tx| async move {
let result = subscription_usage_meter::Entity::find()
.inner_join(subscription_usage::Entity)
.filter(
subscription_usage::Column::PeriodStartAt
.lte(now)
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
)
.select_also(subscription_usage::Entity)
.all(&*tx)
.await?;
let result = result
.into_iter()
.filter_map(|(meter, usage)| {
let usage = usage?;
Some((meter, usage))
})
.collect();
Ok(result)
})
.await
}
}

View file

@ -6,7 +6,7 @@ use crate::db::{UserId, billing_subscription};
use super::*;
fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
use chrono::{Datelike as _, Timelike as _};
let date = time::Date::from_calendar_date(

View file

@ -3,5 +3,6 @@ pub mod model;
pub mod monthly_usage;
pub mod provider;
pub mod subscription_usage;
pub mod subscription_usage_meter;
pub mod usage;
pub mod usage_measure;

View file

@ -0,0 +1,43 @@
use sea_orm::entity::prelude::*;
use crate::llm::db::ModelId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "subscription_usage_meters")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub subscription_usage_id: i32,
pub model_id: ModelId,
pub requests: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::subscription_usage::Entity",
from = "Column::SubscriptionUsageId",
to = "super::subscription_usage::Column::Id"
)]
SubscriptionUsage,
#[sea_orm(
belongs_to = "super::model::Entity",
from = "Column::ModelId",
to = "super::model::Column::Id"
)]
Model,
}
impl Related<super::subscription_usage::Entity> for Entity {
fn to() -> RelationDef {
Relation::SubscriptionUsage.def()
}
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Model.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -8,7 +8,9 @@ use axum::{
};
use collab::api::CloudflareIpCountryHeader;
use collab::api::billing::sync_llm_token_usage_with_stripe_periodically;
use collab::api::billing::{
sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically,
};
use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
@ -152,6 +154,7 @@ async fn main() -> Result<()> {
if let Some(mut llm_db) = llm_db {
llm_db.initialize().await?;
sync_llm_request_usage_with_stripe_periodically(state.clone());
sync_llm_token_usage_with_stripe_periodically(state.clone());
}

View file

@ -1,12 +1,13 @@
use std::sync::Arc;
use crate::{Cents, Result, llm};
use anyhow::Context as _;
use anyhow::{Context as _, anyhow};
use chrono::{Datelike, Utc};
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::PriceId;
use tokio::sync::RwLock;
use uuid::Uuid;
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
@ -17,9 +18,10 @@ pub struct StripeBilling {
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
prices_by_lookup_key: HashMap<String, stripe::Price>,
}
pub struct StripeModel {
pub struct StripeModelTokenPrices {
input_tokens_price: StripeBillingPrice,
input_cache_creation_tokens_price: StripeBillingPrice,
input_cache_read_tokens_price: StripeBillingPrice,
@ -62,6 +64,10 @@ impl StripeBilling {
}
for price in prices.data {
if let Some(lookup_key) = price.lookup_key.clone() {
state.prices_by_lookup_key.insert(lookup_key, price.clone());
}
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
state.price_ids_by_meter_id.insert(meter, price.id);
@ -74,36 +80,49 @@ impl StripeBilling {
Ok(())
}
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.cloned()
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn register_model_for_token_based_usage(
&self,
model: &llm::db::model::Model,
) -> Result<StripeModelTokenPrices> {
let input_tokens_price = self
.get_or_insert_price(
.get_or_insert_token_price(
&format!("model_{}/input_tokens", model.id),
&format!("{} (Input Tokens)", model.name),
Cents::new(model.price_per_million_input_tokens as u32),
)
.await?;
let input_cache_creation_tokens_price = self
.get_or_insert_price(
.get_or_insert_token_price(
&format!("model_{}/input_cache_creation_tokens", model.id),
&format!("{} (Input Cache Creation Tokens)", model.name),
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
)
.await?;
let input_cache_read_tokens_price = self
.get_or_insert_price(
.get_or_insert_token_price(
&format!("model_{}/input_cache_read_tokens", model.id),
&format!("{} (Input Cache Read Tokens)", model.name),
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
)
.await?;
let output_tokens_price = self
.get_or_insert_price(
.get_or_insert_token_price(
&format!("model_{}/output_tokens", model.id),
&format!("{} (Output Tokens)", model.name),
Cents::new(model.price_per_million_output_tokens as u32),
)
.await?;
Ok(StripeModel {
Ok(StripeModelTokenPrices {
input_tokens_price,
input_cache_creation_tokens_price,
input_cache_read_tokens_price,
@ -111,7 +130,7 @@ impl StripeBilling {
})
}
async fn get_or_insert_price(
async fn get_or_insert_token_price(
&self,
meter_event_name: &str,
price_description: &str,
@ -207,10 +226,43 @@ impl StripeBilling {
})
}
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
price_id: &stripe::PriceId,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, price_id) {
return Ok(());
}
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
price: Some(price_id.to_string()),
..Default::default()
}]),
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
},
}),
..Default::default()
},
)
.await?;
Ok(())
}
pub async fn subscribe_to_model(
&self,
subscription_id: &stripe::SubscriptionId,
model: &StripeModel,
model: &StripeModelTokenPrices,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
@ -271,7 +323,7 @@ impl StripeBilling {
pub async fn bill_model_token_usage(
&self,
customer_id: &stripe::CustomerId,
model: &StripeModel,
model: &StripeModelTokenPrices,
event: &llm::db::billing_event::Model,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
@ -343,11 +395,37 @@ impl StripeBilling {
Ok(())
}
pub async fn bill_model_request_usage(
&self,
customer_id: &stripe::CustomerId,
event_name: &str,
requests: i32,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
let idempotency_key = Uuid::new_v4();
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
payload: StripeCreateMeterEventPayload {
value: requests as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
Ok(())
}
pub async fn checkout(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
model: &StripeModel,
model: &StripeModelTokenPrices,
success_url: &str,
) -> Result<String> {
let first_of_next_month = Utc::now()