diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 507c3435d0..64cbe8422f 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -393,7 +393,9 @@ async fn create_billing_subscription( zed_llm_client::LanguageModelProvider::Anthropic, "claude-3-7-sonnet", )?; - let stripe_model = stripe_billing.register_model(default_model).await?; + let stripe_model = stripe_billing + .register_model_for_token_based_usage(default_model) + .await?; stripe_billing .checkout(customer_id, &user.github_login, &stripe_model, &success_url) .await? @@ -1303,7 +1305,9 @@ async fn sync_token_usage_with_stripe( .parse() .context("failed to parse stripe customer id from db")?; - let stripe_model = stripe_billing.register_model(&model).await?; + let stripe_model = stripe_billing + .register_model_for_token_based_usage(&model) + .await?; stripe_billing .subscribe_to_model(&stripe_subscription_id, &stripe_model) .await?; @@ -1315,3 +1319,106 @@ async fn sync_token_usage_with_stripe( Ok(()) } + +const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); + +pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc) { + let Some(stripe_billing) = app.stripe_billing.clone() else { + log::warn!("failed to retrieve Stripe billing object"); + return; + }; + let Some(llm_db) = app.llm_db.clone() else { + log::warn!("failed to retrieve LLM database"); + return; + }; + + let executor = app.executor.clone(); + executor.spawn_detached({ + let executor = executor.clone(); + async move { + loop { + sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing) + .await + .context("failed to sync LLM request usage to Stripe") + .trace_err(); + executor + .sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL) + .await; + } + } + }); +} + +async fn sync_model_request_usage_with_stripe( + app: &Arc, + llm_db: &Arc, + stripe_billing: &Arc, +) -> anyhow::Result<()> { + let usage_meters = llm_db + .get_current_subscription_usage_meters(Utc::now()) + .await?; + let user_ids = usage_meters + .iter() + .map(|(_, usage)| usage.user_id) + .collect::>(); + let billing_subscriptions = app + .db + .get_active_zed_pro_billing_subscriptions(user_ids) + .await?; + + let claude_3_5_sonnet = stripe_billing + .find_price_by_lookup_key("claude-3-5-sonnet-requests") + .await?; + let claude_3_7_sonnet = stripe_billing + .find_price_by_lookup_key("claude-3-7-sonnet-requests") + .await?; + + for (usage_meter, usage) in usage_meters { + maybe!(async { + let Some((billing_customer, billing_subscription)) = + billing_subscriptions.get(&usage.user_id) + else { + bail!( + "Attempted to sync usage meter for user who is not a Stripe customer: {}", + usage.user_id + ); + }; + + let stripe_customer_id = billing_customer + .stripe_customer_id + .parse::() + .context("failed to parse Stripe customer ID from database")?; + let stripe_subscription_id = billing_subscription + .stripe_subscription_id + .parse::() + .context("failed to parse Stripe subscription ID from database")?; + + let model = llm_db.model_by_id(usage_meter.model_id)?; + + let (price_id, meter_event_name) = match model.name.as_str() { + "claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"), + "claude-3-7-sonnet" => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"), + model_name => { + bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") + } + }; + + stripe_billing + .subscribe_to_price(&stripe_subscription_id, price_id) + .await?; + stripe_billing + .bill_model_request_usage( + &stripe_customer_id, + meter_event_name, + usage_meter.requests, + ) + .await?; + + Ok(()) + }) + .await + .log_err(); + } + + Ok(()) +} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 1107d0753e..e7fc8d208d 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -191,6 +191,38 @@ impl Database { .await } + pub async fn get_active_zed_pro_billing_subscriptions( + &self, + user_ids: HashSet, + ) -> Result> { + self.transaction(|tx| { + let user_ids = user_ids.clone(); + async move { + let mut rows = billing_subscription::Entity::find() + .inner_join(billing_customer::Entity) + .select_also(billing_customer::Entity) + .filter(billing_customer::Column::UserId.is_in(user_ids)) + .filter( + billing_subscription::Column::StripeSubscriptionStatus + .eq(StripeSubscriptionStatus::Active), + ) + .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro)) + .order_by_asc(billing_subscription::Column::Id) + .stream(&*tx) + .await?; + + let mut subscriptions = HashMap::default(); + while let Some(row) = rows.next().await { + if let (subscription, Some(customer)) = row? { + subscriptions.insert(customer.user_id, (customer, subscription)); + } + } + Ok(subscriptions) + } + }) + .await + } + /// Returns whether the user has an active billing subscription. pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result { Ok(self.count_active_billing_subscriptions(user_id).await? > 0) diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 6f9aab4a68..9820502f40 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -2,5 +2,6 @@ use super::*; pub mod billing_events; pub mod providers; +pub mod subscription_usage_meters; pub mod subscription_usages; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/subscription_usage_meters.rs b/crates/collab/src/llm/db/queries/subscription_usage_meters.rs new file mode 100644 index 0000000000..5666b1cda9 --- /dev/null +++ b/crates/collab/src/llm/db/queries/subscription_usage_meters.rs @@ -0,0 +1,37 @@ +use crate::llm::db::queries::subscription_usages::convert_chrono_to_time; + +use super::*; + +impl LlmDatabase { + /// Returns all current subscription usage meters as of the given timestamp. + pub async fn get_current_subscription_usage_meters( + &self, + now: DateTimeUtc, + ) -> Result> { + let now = convert_chrono_to_time(now)?; + + self.transaction(|tx| async move { + let result = subscription_usage_meter::Entity::find() + .inner_join(subscription_usage::Entity) + .filter( + subscription_usage::Column::PeriodStartAt + .lte(now) + .and(subscription_usage::Column::PeriodEndAt.gte(now)), + ) + .select_also(subscription_usage::Entity) + .all(&*tx) + .await?; + + let result = result + .into_iter() + .filter_map(|(meter, usage)| { + let usage = usage?; + Some((meter, usage)) + }) + .collect(); + + Ok(result) + }) + .await + } +} diff --git a/crates/collab/src/llm/db/queries/subscription_usages.rs b/crates/collab/src/llm/db/queries/subscription_usages.rs index e26e8d3799..162bbd274d 100644 --- a/crates/collab/src/llm/db/queries/subscription_usages.rs +++ b/crates/collab/src/llm/db/queries/subscription_usages.rs @@ -6,7 +6,7 @@ use crate::db::{UserId, billing_subscription}; use super::*; -fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result { +pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result { use chrono::{Datelike as _, Timelike as _}; let date = time::Date::from_calendar_date( diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index 0e99e01144..a21616da98 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -3,5 +3,6 @@ pub mod model; pub mod monthly_usage; pub mod provider; pub mod subscription_usage; +pub mod subscription_usage_meter; pub mod usage; pub mod usage_measure; diff --git a/crates/collab/src/llm/db/tables/subscription_usage_meter.rs b/crates/collab/src/llm/db/tables/subscription_usage_meter.rs new file mode 100644 index 0000000000..a7241e8f95 --- /dev/null +++ b/crates/collab/src/llm/db/tables/subscription_usage_meter.rs @@ -0,0 +1,43 @@ +use sea_orm::entity::prelude::*; + +use crate::llm::db::ModelId; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "subscription_usage_meters")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub subscription_usage_id: i32, + pub model_id: ModelId, + pub requests: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::subscription_usage::Entity", + from = "Column::SubscriptionUsageId", + to = "super::subscription_usage::Column::Id" + )] + SubscriptionUsage, + #[sea_orm( + belongs_to = "super::model::Entity", + from = "Column::ModelId", + to = "super::model::Column::Id" + )] + Model, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::SubscriptionUsage.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Model.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 25befc2845..e30e3f587a 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -8,7 +8,9 @@ use axum::{ }; use collab::api::CloudflareIpCountryHeader; -use collab::api::billing::sync_llm_token_usage_with_stripe_periodically; +use collab::api::billing::{ + sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically, +}; use collab::llm::db::LlmDatabase; use collab::migrations::run_database_migrations; use collab::user_backfiller::spawn_user_backfiller; @@ -152,6 +154,7 @@ async fn main() -> Result<()> { if let Some(mut llm_db) = llm_db { llm_db.initialize().await?; + sync_llm_request_usage_with_stripe_periodically(state.clone()); sync_llm_token_usage_with_stripe_periodically(state.clone()); } diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index e992187a9c..880b887dfe 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -1,12 +1,13 @@ use std::sync::Arc; use crate::{Cents, Result, llm}; -use anyhow::Context as _; +use anyhow::{Context as _, anyhow}; use chrono::{Datelike, Utc}; use collections::HashMap; use serde::{Deserialize, Serialize}; use stripe::PriceId; use tokio::sync::RwLock; +use uuid::Uuid; pub struct StripeBilling { state: RwLock, @@ -17,9 +18,10 @@ pub struct StripeBilling { struct StripeBillingState { meters_by_event_name: HashMap, price_ids_by_meter_id: HashMap, + prices_by_lookup_key: HashMap, } -pub struct StripeModel { +pub struct StripeModelTokenPrices { input_tokens_price: StripeBillingPrice, input_cache_creation_tokens_price: StripeBillingPrice, input_cache_read_tokens_price: StripeBillingPrice, @@ -62,6 +64,10 @@ impl StripeBilling { } for price in prices.data { + if let Some(lookup_key) = price.lookup_key.clone() { + state.prices_by_lookup_key.insert(lookup_key, price.clone()); + } + if let Some(recurring) = price.recurring { if let Some(meter) = recurring.meter { state.price_ids_by_meter_id.insert(meter, price.id); @@ -74,36 +80,49 @@ impl StripeBilling { Ok(()) } - pub async fn register_model(&self, model: &llm::db::model::Model) -> Result { + pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result { + self.state + .read() + .await + .prices_by_lookup_key + .get(lookup_key) + .cloned() + .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}"))) + } + + pub async fn register_model_for_token_based_usage( + &self, + model: &llm::db::model::Model, + ) -> Result { let input_tokens_price = self - .get_or_insert_price( + .get_or_insert_token_price( &format!("model_{}/input_tokens", model.id), &format!("{} (Input Tokens)", model.name), Cents::new(model.price_per_million_input_tokens as u32), ) .await?; let input_cache_creation_tokens_price = self - .get_or_insert_price( + .get_or_insert_token_price( &format!("model_{}/input_cache_creation_tokens", model.id), &format!("{} (Input Cache Creation Tokens)", model.name), Cents::new(model.price_per_million_cache_creation_input_tokens as u32), ) .await?; let input_cache_read_tokens_price = self - .get_or_insert_price( + .get_or_insert_token_price( &format!("model_{}/input_cache_read_tokens", model.id), &format!("{} (Input Cache Read Tokens)", model.name), Cents::new(model.price_per_million_cache_read_input_tokens as u32), ) .await?; let output_tokens_price = self - .get_or_insert_price( + .get_or_insert_token_price( &format!("model_{}/output_tokens", model.id), &format!("{} (Output Tokens)", model.name), Cents::new(model.price_per_million_output_tokens as u32), ) .await?; - Ok(StripeModel { + Ok(StripeModelTokenPrices { input_tokens_price, input_cache_creation_tokens_price, input_cache_read_tokens_price, @@ -111,7 +130,7 @@ impl StripeBilling { }) } - async fn get_or_insert_price( + async fn get_or_insert_token_price( &self, meter_event_name: &str, price_description: &str, @@ -207,10 +226,43 @@ impl StripeBilling { }) } + pub async fn subscribe_to_price( + &self, + subscription_id: &stripe::SubscriptionId, + price_id: &stripe::PriceId, + ) -> Result<()> { + let subscription = + stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?; + + if subscription_contains_price(&subscription, price_id) { + return Ok(()); + } + + stripe::Subscription::update( + &self.client, + subscription_id, + stripe::UpdateSubscription { + items: Some(vec![stripe::UpdateSubscriptionItems { + price: Some(price_id.to_string()), + ..Default::default() + }]), + trial_settings: Some(stripe::UpdateSubscriptionTrialSettings { + end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior { + missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel, + }, + }), + ..Default::default() + }, + ) + .await?; + + Ok(()) + } + pub async fn subscribe_to_model( &self, subscription_id: &stripe::SubscriptionId, - model: &StripeModel, + model: &StripeModelTokenPrices, ) -> Result<()> { let subscription = stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?; @@ -271,7 +323,7 @@ impl StripeBilling { pub async fn bill_model_token_usage( &self, customer_id: &stripe::CustomerId, - model: &StripeModel, + model: &StripeModelTokenPrices, event: &llm::db::billing_event::Model, ) -> Result<()> { let timestamp = Utc::now().timestamp(); @@ -343,11 +395,37 @@ impl StripeBilling { Ok(()) } + pub async fn bill_model_request_usage( + &self, + customer_id: &stripe::CustomerId, + event_name: &str, + requests: i32, + ) -> Result<()> { + let timestamp = Utc::now().timestamp(); + let idempotency_key = Uuid::new_v4(); + + StripeMeterEvent::create( + &self.client, + StripeCreateMeterEventParams { + identifier: &format!("model_requests/{}", idempotency_key), + event_name, + payload: StripeCreateMeterEventPayload { + value: requests as u64, + stripe_customer_id: customer_id, + }, + timestamp: Some(timestamp), + }, + ) + .await?; + + Ok(()) + } + pub async fn checkout( &self, customer_id: stripe::CustomerId, github_login: &str, - model: &StripeModel, + model: &StripeModelTokenPrices, success_url: &str, ) -> Result { let first_of_next_month = Utc::now()