collab: Remove code for syncing token-based billing events (#30130)

This PR removes the code related to syncing token-based billing events
to Stripe.

We don't need this anymore with the new billing.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-07 12:11:51 -04:00 committed by GitHub
parent a34fb6f6b1
commit d50562ed81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 8 additions and 521 deletions

View file

@ -301,13 +301,6 @@ async fn create_billing_subscription(
"not supported".into(),
))?
};
let Some(llm_db) = app.llm_db.clone() else {
log::error!("failed to retrieve LLM database");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
if app.db.has_active_billing_subscription(user.id).await? {
return Err(Error::http(
@ -399,16 +392,10 @@ async fn create_billing_subscription(
.await?
}
None => {
let default_model = llm_db.model(
zed_llm_client::LanguageModelProvider::Anthropic,
"claude-3-7-sonnet",
)?;
let stripe_model = stripe_billing
.register_model_for_token_based_usage(default_model)
.await?;
stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?
return Err(Error::http(
StatusCode::BAD_REQUEST,
"No product selected".into(),
));
}
};
@ -1381,81 +1368,6 @@ async fn find_or_create_billing_customer(
Ok(Some(billing_customer))
}
const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::warn!("failed to retrieve Stripe billing object");
return;
};
let Some(llm_db) = app.llm_db.clone() else {
log::warn!("failed to retrieve LLM database");
return;
};
let executor = app.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
loop {
sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
.await
.context("failed to sync LLM usage to Stripe")
.trace_err();
executor
.sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
.await;
}
}
});
}
async fn sync_token_usage_with_stripe(
app: &Arc<AppState>,
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
let events = llm_db.get_billing_events().await?;
let user_ids = events
.iter()
.map(|(event, _)| event.user_id)
.collect::<HashSet<UserId>>();
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
for (event, model) in events {
let Some((stripe_db_customer, stripe_db_subscription)) =
stripe_subscriptions.get(&event.user_id)
else {
tracing::warn!(
user_id = event.user_id.0,
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
);
continue;
};
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
.stripe_subscription_id
.parse()
.context("failed to parse stripe subscription id from db")?;
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
.stripe_customer_id
.parse()
.context("failed to parse stripe customer id from db")?;
let stripe_model = stripe_billing
.register_model_for_token_based_usage(&model)
.await?;
stripe_billing
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
.await?;
stripe_billing
.bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
.await?;
llm_db.consume_billing_event(event.id).await?;
}
Ok(())
}
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {

View file

@ -1,6 +1,5 @@
use super::*;
pub mod billing_events;
pub mod providers;
pub mod subscription_usage_meters;
pub mod subscription_usages;

View file

@ -1,31 +0,0 @@
use super::*;
use crate::Result;
use anyhow::Context as _;
impl LlmDatabase {
pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
self.transaction(|tx| async move {
let events_with_models = billing_event::Entity::find()
.find_also_related(model::Entity)
.all(&*tx)
.await?;
events_with_models
.into_iter()
.map(|(event, model)| {
let model =
model.context("could not find model associated with billing event")?;
Ok((event, model))
})
.collect()
})
.await
}
pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
self.transaction(|tx| async move {
billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
Ok(())
})
.await
}
}

View file

@ -1,4 +1,3 @@
pub mod billing_event;
pub mod model;
pub mod monthly_usage;
pub mod provider;

View file

@ -1,37 +0,0 @@
use crate::{
db::UserId,
llm::db::{BillingEventId, ModelId},
};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_events")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingEventId,
pub idempotency_key: Uuid,
pub user_id: UserId,
pub model_id: ModelId,
pub input_tokens: i64,
pub input_cache_creation_tokens: i64,
pub input_cache_read_tokens: i64,
pub output_tokens: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::model::Entity",
from = "Column::ModelId",
to = "super::model::Column::Id"
)]
Model,
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Model.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -31,8 +31,6 @@ pub enum Relation {
Provider,
#[sea_orm(has_many = "super::usage::Entity")]
Usages,
#[sea_orm(has_many = "super::billing_event::Entity")]
BillingEvents,
}
impl Related<super::provider::Entity> for Entity {
@ -47,10 +45,4 @@ impl Related<super::usage::Entity> for Entity {
}
}
impl Related<super::billing_event::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingEvents.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -8,9 +8,7 @@ use axum::{
};
use collab::api::CloudflareIpCountryHeader;
use collab::api::billing::{
sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically,
};
use collab::api::billing::sync_llm_request_usage_with_stripe_periodically;
use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
@ -155,7 +153,6 @@ async fn main() -> Result<()> {
if let Some(mut llm_db) = llm_db {
llm_db.initialize().await?;
sync_llm_request_usage_with_stripe_periodically(state.clone());
sync_llm_token_usage_with_stripe_periodically(state.clone());
}
app = app

View file

@ -1,9 +1,9 @@
use std::sync::Arc;
use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
use crate::{Cents, Result};
use crate::Result;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use anyhow::{Context as _, anyhow};
use chrono::{Datelike, Utc};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::PriceId;
@ -22,18 +22,6 @@ struct StripeBillingState {
prices_by_lookup_key: HashMap<String, stripe::Price>,
}
pub struct StripeModelTokenPrices {
input_tokens_price: StripeBillingPrice,
input_cache_creation_tokens_price: StripeBillingPrice,
input_cache_read_tokens_price: StripeBillingPrice,
output_tokens_price: StripeBillingPrice,
}
struct StripeBillingPrice {
id: stripe::PriceId,
meter_event_name: String,
}
impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self {
Self {
@ -109,142 +97,6 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
}
pub async fn register_model_for_token_based_usage(
&self,
model: &llm::db::model::Model,
) -> Result<StripeModelTokenPrices> {
let input_tokens_price = self
.get_or_insert_token_price(
&format!("model_{}/input_tokens", model.id),
&format!("{} (Input Tokens)", model.name),
Cents::new(model.price_per_million_input_tokens as u32),
)
.await?;
let input_cache_creation_tokens_price = self
.get_or_insert_token_price(
&format!("model_{}/input_cache_creation_tokens", model.id),
&format!("{} (Input Cache Creation Tokens)", model.name),
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
)
.await?;
let input_cache_read_tokens_price = self
.get_or_insert_token_price(
&format!("model_{}/input_cache_read_tokens", model.id),
&format!("{} (Input Cache Read Tokens)", model.name),
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
)
.await?;
let output_tokens_price = self
.get_or_insert_token_price(
&format!("model_{}/output_tokens", model.id),
&format!("{} (Output Tokens)", model.name),
Cents::new(model.price_per_million_output_tokens as u32),
)
.await?;
Ok(StripeModelTokenPrices {
input_tokens_price,
input_cache_creation_tokens_price,
input_cache_read_tokens_price,
output_tokens_price,
})
}
async fn get_or_insert_token_price(
&self,
meter_event_name: &str,
price_description: &str,
price_per_million_tokens: Cents,
) -> Result<StripeBillingPrice> {
// Fast code path when the meter and the price already exist.
{
let state = self.state.read().await;
if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
return Ok(StripeBillingPrice {
id: price_id.clone(),
meter_event_name: meter_event_name.to_string(),
});
}
}
}
let mut state = self.state.write().await;
let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
meter.clone()
} else {
let meter = StripeMeter::create(
&self.client,
StripeCreateMeterParams {
default_aggregation: DefaultAggregation { formula: "sum" },
display_name: price_description.to_string(),
event_name: meter_event_name,
},
)
.await?;
state
.meters_by_event_name
.insert(meter_event_name.to_string(), meter.clone());
meter
};
let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
price_id.clone()
} else {
let price = stripe::Price::create(
&self.client,
stripe::CreatePrice {
active: Some(true),
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None,
nickname: None,
product: None,
product_data: Some(stripe::CreatePriceProductData {
id: None,
active: Some(true),
metadata: None,
name: price_description.to_string(),
statement_descriptor: None,
tax_code: None,
unit_label: None,
}),
recurring: Some(stripe::CreatePriceRecurring {
aggregate_usage: None,
interval: stripe::CreatePriceRecurringInterval::Month,
interval_count: None,
trial_period_days: None,
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
meter: Some(meter.id.clone()),
}),
tax_behavior: None,
tiers: None,
tiers_mode: None,
transfer_lookup_key: None,
transform_quantity: None,
unit_amount: None,
unit_amount_decimal: Some(&format!(
"{:.12}",
price_per_million_tokens.0 as f64 / 1_000_000f64
)),
},
)
.await?;
state
.price_ids_by_meter_id
.insert(meter.id, price.id.clone());
price.id
};
Ok(StripeBillingPrice {
id: price_id,
meter_event_name: meter_event_name.to_string(),
})
}
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
@ -283,142 +135,6 @@ impl StripeBilling {
Ok(())
}
pub async fn subscribe_to_model(
&self,
subscription_id: &stripe::SubscriptionId,
model: &StripeModelTokenPrices,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
let mut items = Vec::new();
if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
{
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_cache_creation_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_cache_read_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.output_tokens_price.id.to_string()),
..Default::default()
});
}
if !items.is_empty() {
items.extend(subscription.items.data.iter().map(|item| {
stripe::UpdateSubscriptionItems {
id: Some(item.id.to_string()),
..Default::default()
}
}));
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(items),
..Default::default()
},
)
.await?;
}
Ok(())
}
pub async fn bill_model_token_usage(
&self,
customer_id: &stripe::CustomerId,
model: &StripeModelTokenPrices,
event: &llm::db::billing_event::Model,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
if event.input_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_tokens/{}", event.idempotency_key),
event_name: &model.input_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.input_cache_creation_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
event_name: &model.input_cache_creation_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_cache_creation_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.input_cache_read_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
event_name: &model.input_cache_read_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_cache_read_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.output_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("output_tokens/{}", event.idempotency_key),
event_name: &model.output_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.output_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
Ok(())
}
pub async fn bill_model_request_usage(
&self,
customer_id: &stripe::CustomerId,
@ -445,47 +161,6 @@ impl StripeBilling {
Ok(())
}
pub async fn checkout(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
model: &StripeModelTokenPrices,
success_url: &str,
) -> Result<String> {
let first_of_next_month = Utc::now()
.checked_add_months(chrono::Months::new(1))
.unwrap()
.with_day(1)
.unwrap();
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
billing_cycle_anchor: Some(first_of_next_month.timestamp()),
..Default::default()
});
params.line_items = Some(
[
&model.input_tokens_price.id,
&model.input_cache_creation_tokens_price.id,
&model.input_cache_read_tokens_price.id,
&model.output_tokens_price.id,
]
.into_iter()
.map(|price_id| stripe::CreateCheckoutSessionLineItems {
price: Some(price_id.to_string()),
..Default::default()
})
.collect(),
);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_zed_pro(
&self,
customer_id: stripe::CustomerId,
@ -587,18 +262,6 @@ impl StripeBilling {
}
}
#[derive(Serialize)]
struct DefaultAggregation {
formula: &'static str,
}
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
default_aggregation: DefaultAggregation,
display_name: String,
event_name: &'a str,
}
#[derive(Clone, Deserialize)]
struct StripeMeter {
id: String,
@ -606,13 +269,6 @@ struct StripeMeter {
}
impl StripeMeter {
pub fn create(
client: &stripe::Client,
params: StripeCreateMeterParams,
) -> stripe::Response<Self> {
client.post_form("/billing/meters", params)
}
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
#[derive(Serialize)]
struct Params {