diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index c8df066cbf..77ed9a9ea8 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -5,7 +5,7 @@ use axum::{ routing::{get, post}, }; use chrono::{DateTime, SecondsFormat, Utc}; -use collections::HashSet; +use collections::{HashMap, HashSet}; use reqwest::StatusCode; use sea_orm::ActiveValue; use serde::{Deserialize, Serialize}; @@ -21,12 +21,13 @@ use stripe::{ PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus, }; use util::{ResultExt, maybe}; +use zed_llm_client::LanguageModelProvider; use crate::api::events::SnowflakeRow; use crate::db::billing_subscription::{ StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, }; -use crate::llm::db::subscription_usage_meter::CompletionMode; +use crate::llm::db::subscription_usage_meter::{self, CompletionMode}; use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND}; use crate::rpc::{ResultExt as _, Server}; use crate::stripe_client::{ @@ -1416,18 +1417,21 @@ async fn sync_model_request_usage_with_stripe( let usage_meters = llm_db .get_current_subscription_usage_meters(Utc::now()) .await?; - let usage_meters = usage_meters - .into_iter() - .filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id)) - .collect::>(); - let user_ids = usage_meters - .iter() - .map(|(_, usage)| usage.user_id) - .collect::>(); - let billing_subscriptions = app - .db - .get_active_zed_pro_billing_subscriptions(user_ids) - .await?; + let mut usage_meters_by_user_id = + HashMap::>::default(); + for (usage_meter, usage) in usage_meters { + let meters = usage_meters_by_user_id.entry(usage.user_id).or_default(); + meters.push(usage_meter); + } + + log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions"); + let get_zed_pro_subscriptions_started_at = Utc::now(); + let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?; + log::info!( + "Stripe usage sync: Retrieved {} Zed pro subscriptions in {}", + billing_subscriptions.len(), + Utc::now() - get_zed_pro_subscriptions_started_at + ); let claude_sonnet_4 = stripe_billing .find_price_by_lookup_key("claude-sonnet-4-requests") @@ -1451,59 +1455,90 @@ async fn sync_model_request_usage_with_stripe( .find_price_by_lookup_key("claude-3-7-sonnet-requests-max") .await?; - let usage_meter_count = usage_meters.len(); + let model_mode_combinations = [ + ("claude-opus-4", CompletionMode::Max), + ("claude-opus-4", CompletionMode::Normal), + ("claude-sonnet-4", CompletionMode::Max), + ("claude-sonnet-4", CompletionMode::Normal), + ("claude-3-7-sonnet", CompletionMode::Max), + ("claude-3-7-sonnet", CompletionMode::Normal), + ("claude-3-5-sonnet", CompletionMode::Normal), + ]; - log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters"); + let billing_subscription_count = billing_subscriptions.len(); - for (usage_meter, usage) in usage_meters { + log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions"); + + for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions { maybe!(async { - let Some((billing_customer, billing_subscription)) = - billing_subscriptions.get(&usage.user_id) - else { - bail!( - "Attempted to sync usage meter for user who is not a Stripe customer: {}", - usage.user_id - ); - }; + if staff_user_ids.contains(&user_id) { + return anyhow::Ok(()); + } let stripe_customer_id = StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); let stripe_subscription_id = StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into()); - let model = llm_db.model_by_id(usage_meter.model_id)?; + let usage_meters = usage_meters_by_user_id.get(&user_id); - let (price, meter_event_name) = match model.name.as_str() { - "claude-opus-4" => match usage_meter.mode { - CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"), - CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"), - }, - "claude-sonnet-4" => match usage_meter.mode { - CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"), - CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"), - }, - "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), - "claude-3-7-sonnet" => match usage_meter.mode { - CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"), - CompletionMode::Max => { - (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") + for (model, mode) in &model_mode_combinations { + let Ok(model) = + llm_db.model(LanguageModelProvider::Anthropic, model) + else { + log::warn!("Failed to load model for user {user_id}: {model}"); + continue; + }; + + let (price, meter_event_name) = match model.name.as_str() { + "claude-opus-4" => match mode { + CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"), + CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"), + }, + "claude-sonnet-4" => match mode { + CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"), + CompletionMode::Max => { + (&claude_sonnet_4_max, "claude_sonnet_4/requests/max") + } + }, + "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), + "claude-3-7-sonnet" => match mode { + CompletionMode::Normal => { + (&claude_3_7_sonnet, "claude_3_7_sonnet/requests") + } + CompletionMode::Max => { + (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") + } + }, + model_name => { + bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") } - }, - model_name => { - bail!("Attempted to sync usage meter for unsupported model: {model_name:?}") - } - }; + }; - stripe_billing - .subscribe_to_price(&stripe_subscription_id, price) - .await?; - stripe_billing - .bill_model_request_usage( - &stripe_customer_id, - meter_event_name, - usage_meter.requests, - ) - .await?; + let model_requests = usage_meters + .and_then(|usage_meters| { + usage_meters + .iter() + .find(|meter| meter.model_id == model.id && meter.mode == *mode) + }) + .map(|usage_meter| usage_meter.requests) + .unwrap_or(0); + + if model_requests > 0 { + stripe_billing + .subscribe_to_price(&stripe_subscription_id, price) + .await?; + } + + stripe_billing + .bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests) + .await + .with_context(|| { + format!( + "Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}", + ) + })?; + } Ok(()) }) @@ -1512,7 +1547,7 @@ async fn sync_model_request_usage_with_stripe( } log::info!( - "Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}", + "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}", Utc::now() - started_at ); diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index f25d0abeaa..9f82e3dbc4 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -199,6 +199,33 @@ impl Database { pub async fn get_active_zed_pro_billing_subscriptions( &self, + ) -> Result> { + self.transaction(|tx| async move { + let mut rows = billing_subscription::Entity::find() + .inner_join(billing_customer::Entity) + .select_also(billing_customer::Entity) + .filter( + billing_subscription::Column::StripeSubscriptionStatus + .eq(StripeSubscriptionStatus::Active), + ) + .filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro)) + .order_by_asc(billing_subscription::Column::Id) + .stream(&*tx) + .await?; + + let mut subscriptions = HashMap::default(); + while let Some(row) = rows.next().await { + if let (subscription, Some(customer)) = row? { + subscriptions.insert(customer.user_id, (customer, subscription)); + } + } + Ok(subscriptions) + }) + .await + } + + pub async fn get_active_zed_pro_billing_subscriptions_for_users( + &self, user_ids: HashSet, ) -> Result> { self.transaction(|tx| {