diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 39d124dd96..ff8849aff5 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -301,13 +301,6 @@ async fn create_billing_subscription( "not supported".into(), ))? }; - let Some(llm_db) = app.llm_db.clone() else { - log::error!("failed to retrieve LLM database"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; if app.db.has_active_billing_subscription(user.id).await? { return Err(Error::http( @@ -399,16 +392,10 @@ async fn create_billing_subscription( .await? } None => { - let default_model = llm_db.model( - zed_llm_client::LanguageModelProvider::Anthropic, - "claude-3-7-sonnet", - )?; - let stripe_model = stripe_billing - .register_model_for_token_based_usage(default_model) - .await?; - stripe_billing - .checkout(customer_id, &user.github_login, &stripe_model, &success_url) - .await? + return Err(Error::http( + StatusCode::BAD_REQUEST, + "No product selected".into(), + )); } }; @@ -1381,81 +1368,6 @@ async fn find_or_create_billing_customer( Ok(Some(billing_customer)) } -const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); - -pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc) { - let Some(stripe_billing) = app.stripe_billing.clone() else { - log::warn!("failed to retrieve Stripe billing object"); - return; - }; - let Some(llm_db) = app.llm_db.clone() else { - log::warn!("failed to retrieve LLM database"); - return; - }; - - let executor = app.executor.clone(); - executor.spawn_detached({ - let executor = executor.clone(); - async move { - loop { - sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing) - .await - .context("failed to sync LLM usage to Stripe") - .trace_err(); - executor - .sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL) - .await; - } - } - }); -} - -async fn sync_token_usage_with_stripe( - app: &Arc, - llm_db: &Arc, - stripe_billing: &Arc, -) -> anyhow::Result<()> { - let events = llm_db.get_billing_events().await?; - let user_ids = events - .iter() - .map(|(event, _)| event.user_id) - .collect::>(); - let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?; - - for (event, model) in events { - let Some((stripe_db_customer, stripe_db_subscription)) = - stripe_subscriptions.get(&event.user_id) - else { - tracing::warn!( - user_id = event.user_id.0, - "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side." - ); - continue; - }; - let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription - .stripe_subscription_id - .parse() - .context("failed to parse stripe subscription id from db")?; - let stripe_customer_id: stripe::CustomerId = stripe_db_customer - .stripe_customer_id - .parse() - .context("failed to parse stripe customer id from db")?; - - let stripe_model = stripe_billing - .register_model_for_token_based_usage(&model) - .await?; - stripe_billing - .subscribe_to_model(&stripe_subscription_id, &stripe_model) - .await?; - stripe_billing - .bill_model_token_usage(&stripe_customer_id, &stripe_model, &event) - .await?; - llm_db.consume_billing_event(event.id).await?; - } - - Ok(()) -} - const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc) { diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 9820502f40..3565366fdd 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -1,6 +1,5 @@ use super::*; -pub mod billing_events; pub mod providers; pub mod subscription_usage_meters; pub mod subscription_usages; diff --git a/crates/collab/src/llm/db/queries/billing_events.rs b/crates/collab/src/llm/db/queries/billing_events.rs deleted file mode 100644 index 400477f234..0000000000 --- a/crates/collab/src/llm/db/queries/billing_events.rs +++ /dev/null @@ -1,31 +0,0 @@ -use super::*; -use crate::Result; -use anyhow::Context as _; - -impl LlmDatabase { - pub async fn get_billing_events(&self) -> Result> { - self.transaction(|tx| async move { - let events_with_models = billing_event::Entity::find() - .find_also_related(model::Entity) - .all(&*tx) - .await?; - events_with_models - .into_iter() - .map(|(event, model)| { - let model = - model.context("could not find model associated with billing event")?; - Ok((event, model)) - }) - .collect() - }) - .await - } - - pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> { - self.transaction(|tx| async move { - billing_event::Entity::delete_by_id(id).exec(&*tx).await?; - Ok(()) - }) - .await - } -} diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index a21616da98..d178fb10d3 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -1,4 +1,3 @@ -pub mod billing_event; pub mod model; pub mod monthly_usage; pub mod provider; diff --git a/crates/collab/src/llm/db/tables/billing_event.rs b/crates/collab/src/llm/db/tables/billing_event.rs deleted file mode 100644 index 93987bc71e..0000000000 --- a/crates/collab/src/llm/db/tables/billing_event.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::{ - db::UserId, - llm::db::{BillingEventId, ModelId}, -}; -use sea_orm::entity::prelude::*; - -#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] -#[sea_orm(table_name = "billing_events")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: BillingEventId, - pub idempotency_key: Uuid, - pub user_id: UserId, - pub model_id: ModelId, - pub input_tokens: i64, - pub input_cache_creation_tokens: i64, - pub input_cache_read_tokens: i64, - pub output_tokens: i64, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::model::Entity", - from = "Column::ModelId", - to = "super::model::Column::Id" - )] - Model, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Model.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/model.rs b/crates/collab/src/llm/db/tables/model.rs index 3453c34726..f0a858b4a6 100644 --- a/crates/collab/src/llm/db/tables/model.rs +++ b/crates/collab/src/llm/db/tables/model.rs @@ -31,8 +31,6 @@ pub enum Relation { Provider, #[sea_orm(has_many = "super::usage::Entity")] Usages, - #[sea_orm(has_many = "super::billing_event::Entity")] - BillingEvents, } impl Related for Entity { @@ -47,10 +45,4 @@ impl Related for Entity { } } -impl Related for Entity { - fn to() -> RelationDef { - Relation::BillingEvents.def() - } -} - impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index e30e3f587a..e5240666c4 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -8,9 +8,7 @@ use axum::{ }; use collab::api::CloudflareIpCountryHeader; -use collab::api::billing::{ - sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically, -}; +use collab::api::billing::sync_llm_request_usage_with_stripe_periodically; use collab::llm::db::LlmDatabase; use collab::migrations::run_database_migrations; use collab::user_backfiller::spawn_user_backfiller; @@ -155,7 +153,6 @@ async fn main() -> Result<()> { if let Some(mut llm_db) = llm_db { llm_db.initialize().await?; sync_llm_request_usage_with_stripe_periodically(state.clone()); - sync_llm_token_usage_with_stripe_periodically(state.clone()); } app = app diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 2e83bc116e..6c89cc94be 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG}; -use crate::{Cents, Result}; +use crate::Result; +use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use anyhow::{Context as _, anyhow}; -use chrono::{Datelike, Utc}; +use chrono::Utc; use collections::HashMap; use serde::{Deserialize, Serialize}; use stripe::PriceId; @@ -22,18 +22,6 @@ struct StripeBillingState { prices_by_lookup_key: HashMap, } -pub struct StripeModelTokenPrices { - input_tokens_price: StripeBillingPrice, - input_cache_creation_tokens_price: StripeBillingPrice, - input_cache_read_tokens_price: StripeBillingPrice, - output_tokens_price: StripeBillingPrice, -} - -struct StripeBillingPrice { - id: stripe::PriceId, - meter_event_name: String, -} - impl StripeBilling { pub fn new(client: Arc) -> Self { Self { @@ -109,142 +97,6 @@ impl StripeBilling { .ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}"))) } - pub async fn register_model_for_token_based_usage( - &self, - model: &llm::db::model::Model, - ) -> Result { - let input_tokens_price = self - .get_or_insert_token_price( - &format!("model_{}/input_tokens", model.id), - &format!("{} (Input Tokens)", model.name), - Cents::new(model.price_per_million_input_tokens as u32), - ) - .await?; - let input_cache_creation_tokens_price = self - .get_or_insert_token_price( - &format!("model_{}/input_cache_creation_tokens", model.id), - &format!("{} (Input Cache Creation Tokens)", model.name), - Cents::new(model.price_per_million_cache_creation_input_tokens as u32), - ) - .await?; - let input_cache_read_tokens_price = self - .get_or_insert_token_price( - &format!("model_{}/input_cache_read_tokens", model.id), - &format!("{} (Input Cache Read Tokens)", model.name), - Cents::new(model.price_per_million_cache_read_input_tokens as u32), - ) - .await?; - let output_tokens_price = self - .get_or_insert_token_price( - &format!("model_{}/output_tokens", model.id), - &format!("{} (Output Tokens)", model.name), - Cents::new(model.price_per_million_output_tokens as u32), - ) - .await?; - Ok(StripeModelTokenPrices { - input_tokens_price, - input_cache_creation_tokens_price, - input_cache_read_tokens_price, - output_tokens_price, - }) - } - - async fn get_or_insert_token_price( - &self, - meter_event_name: &str, - price_description: &str, - price_per_million_tokens: Cents, - ) -> Result { - // Fast code path when the meter and the price already exist. - { - let state = self.state.read().await; - if let Some(meter) = state.meters_by_event_name.get(meter_event_name) { - if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) { - return Ok(StripeBillingPrice { - id: price_id.clone(), - meter_event_name: meter_event_name.to_string(), - }); - } - } - } - - let mut state = self.state.write().await; - let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) { - meter.clone() - } else { - let meter = StripeMeter::create( - &self.client, - StripeCreateMeterParams { - default_aggregation: DefaultAggregation { formula: "sum" }, - display_name: price_description.to_string(), - event_name: meter_event_name, - }, - ) - .await?; - state - .meters_by_event_name - .insert(meter_event_name.to_string(), meter.clone()); - meter - }; - - let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) { - price_id.clone() - } else { - let price = stripe::Price::create( - &self.client, - stripe::CreatePrice { - active: Some(true), - billing_scheme: Some(stripe::PriceBillingScheme::PerUnit), - currency: stripe::Currency::USD, - currency_options: None, - custom_unit_amount: None, - expand: &[], - lookup_key: None, - metadata: None, - nickname: None, - product: None, - product_data: Some(stripe::CreatePriceProductData { - id: None, - active: Some(true), - metadata: None, - name: price_description.to_string(), - statement_descriptor: None, - tax_code: None, - unit_label: None, - }), - recurring: Some(stripe::CreatePriceRecurring { - aggregate_usage: None, - interval: stripe::CreatePriceRecurringInterval::Month, - interval_count: None, - trial_period_days: None, - usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), - meter: Some(meter.id.clone()), - }), - tax_behavior: None, - tiers: None, - tiers_mode: None, - transfer_lookup_key: None, - transform_quantity: None, - unit_amount: None, - unit_amount_decimal: Some(&format!( - "{:.12}", - price_per_million_tokens.0 as f64 / 1_000_000f64 - )), - }, - ) - .await?; - state - .price_ids_by_meter_id - .insert(meter.id, price.id.clone()); - price.id - }; - - Ok(StripeBillingPrice { - id: price_id, - meter_event_name: meter_event_name.to_string(), - }) - } - pub async fn subscribe_to_price( &self, subscription_id: &stripe::SubscriptionId, @@ -283,142 +135,6 @@ impl StripeBilling { Ok(()) } - pub async fn subscribe_to_model( - &self, - subscription_id: &stripe::SubscriptionId, - model: &StripeModelTokenPrices, - ) -> Result<()> { - let subscription = - stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?; - - let mut items = Vec::new(); - - if !subscription_contains_price(&subscription, &model.input_tokens_price.id) { - items.push(stripe::UpdateSubscriptionItems { - price: Some(model.input_tokens_price.id.to_string()), - ..Default::default() - }); - } - - if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id) - { - items.push(stripe::UpdateSubscriptionItems { - price: Some(model.input_cache_creation_tokens_price.id.to_string()), - ..Default::default() - }); - } - - if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) { - items.push(stripe::UpdateSubscriptionItems { - price: Some(model.input_cache_read_tokens_price.id.to_string()), - ..Default::default() - }); - } - - if !subscription_contains_price(&subscription, &model.output_tokens_price.id) { - items.push(stripe::UpdateSubscriptionItems { - price: Some(model.output_tokens_price.id.to_string()), - ..Default::default() - }); - } - - if !items.is_empty() { - items.extend(subscription.items.data.iter().map(|item| { - stripe::UpdateSubscriptionItems { - id: Some(item.id.to_string()), - ..Default::default() - } - })); - - stripe::Subscription::update( - &self.client, - subscription_id, - stripe::UpdateSubscription { - items: Some(items), - ..Default::default() - }, - ) - .await?; - } - - Ok(()) - } - - pub async fn bill_model_token_usage( - &self, - customer_id: &stripe::CustomerId, - model: &StripeModelTokenPrices, - event: &llm::db::billing_event::Model, - ) -> Result<()> { - let timestamp = Utc::now().timestamp(); - - if event.input_tokens > 0 { - StripeMeterEvent::create( - &self.client, - StripeCreateMeterEventParams { - identifier: &format!("input_tokens/{}", event.idempotency_key), - event_name: &model.input_tokens_price.meter_event_name, - payload: StripeCreateMeterEventPayload { - value: event.input_tokens as u64, - stripe_customer_id: customer_id, - }, - timestamp: Some(timestamp), - }, - ) - .await?; - } - - if event.input_cache_creation_tokens > 0 { - StripeMeterEvent::create( - &self.client, - StripeCreateMeterEventParams { - identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key), - event_name: &model.input_cache_creation_tokens_price.meter_event_name, - payload: StripeCreateMeterEventPayload { - value: event.input_cache_creation_tokens as u64, - stripe_customer_id: customer_id, - }, - timestamp: Some(timestamp), - }, - ) - .await?; - } - - if event.input_cache_read_tokens > 0 { - StripeMeterEvent::create( - &self.client, - StripeCreateMeterEventParams { - identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key), - event_name: &model.input_cache_read_tokens_price.meter_event_name, - payload: StripeCreateMeterEventPayload { - value: event.input_cache_read_tokens as u64, - stripe_customer_id: customer_id, - }, - timestamp: Some(timestamp), - }, - ) - .await?; - } - - if event.output_tokens > 0 { - StripeMeterEvent::create( - &self.client, - StripeCreateMeterEventParams { - identifier: &format!("output_tokens/{}", event.idempotency_key), - event_name: &model.output_tokens_price.meter_event_name, - payload: StripeCreateMeterEventPayload { - value: event.output_tokens as u64, - stripe_customer_id: customer_id, - }, - timestamp: Some(timestamp), - }, - ) - .await?; - } - - Ok(()) - } - pub async fn bill_model_request_usage( &self, customer_id: &stripe::CustomerId, @@ -445,47 +161,6 @@ impl StripeBilling { Ok(()) } - pub async fn checkout( - &self, - customer_id: stripe::CustomerId, - github_login: &str, - model: &StripeModelTokenPrices, - success_url: &str, - ) -> Result { - let first_of_next_month = Utc::now() - .checked_add_months(chrono::Months::new(1)) - .unwrap() - .with_day(1) - .unwrap(); - - let mut params = stripe::CreateCheckoutSession::new(); - params.mode = Some(stripe::CheckoutSessionMode::Subscription); - params.customer = Some(customer_id); - params.client_reference_id = Some(github_login); - params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData { - billing_cycle_anchor: Some(first_of_next_month.timestamp()), - ..Default::default() - }); - params.line_items = Some( - [ - &model.input_tokens_price.id, - &model.input_cache_creation_tokens_price.id, - &model.input_cache_read_tokens_price.id, - &model.output_tokens_price.id, - ] - .into_iter() - .map(|price_id| stripe::CreateCheckoutSessionLineItems { - price: Some(price_id.to_string()), - ..Default::default() - }) - .collect(), - ); - params.success_url = Some(success_url); - - let session = stripe::CheckoutSession::create(&self.client, params).await?; - Ok(session.url.context("no checkout session URL")?) - } - pub async fn checkout_with_zed_pro( &self, customer_id: stripe::CustomerId, @@ -587,18 +262,6 @@ impl StripeBilling { } } -#[derive(Serialize)] -struct DefaultAggregation { - formula: &'static str, -} - -#[derive(Serialize)] -struct StripeCreateMeterParams<'a> { - default_aggregation: DefaultAggregation, - display_name: String, - event_name: &'a str, -} - #[derive(Clone, Deserialize)] struct StripeMeter { id: String, @@ -606,13 +269,6 @@ struct StripeMeter { } impl StripeMeter { - pub fn create( - client: &stripe::Client, - params: StripeCreateMeterParams, - ) -> stripe::Response { - client.post_form("/billing/meters", params) - } - pub fn list(client: &stripe::Client) -> stripe::Response> { #[derive(Serialize)] struct Params {