collab: Sync model overages for all active Zed Pro subscriptions (#34452)

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-07-15 09:01:01 -04:00 committed by GitHub
parent 52f2b32557
commit 848a86a385
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 118 additions and 56 deletions

View file

@ -5,7 +5,7 @@ use axum::{
routing::{get, post}, routing::{get, post},
}; };
use chrono::{DateTime, SecondsFormat, Utc}; use chrono::{DateTime, SecondsFormat, Utc};
use collections::HashSet; use collections::{HashMap, HashSet};
use reqwest::StatusCode; use reqwest::StatusCode;
use sea_orm::ActiveValue; use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -21,12 +21,13 @@ use stripe::{
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
}; };
use util::{ResultExt, maybe}; use util::{ResultExt, maybe};
use zed_llm_client::LanguageModelProvider;
use crate::api::events::SnowflakeRow; use crate::api::events::SnowflakeRow;
use crate::db::billing_subscription::{ use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind, StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
}; };
use crate::llm::db::subscription_usage_meter::CompletionMode; use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND}; use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
use crate::rpc::{ResultExt as _, Server}; use crate::rpc::{ResultExt as _, Server};
use crate::stripe_client::{ use crate::stripe_client::{
@ -1416,18 +1417,21 @@ async fn sync_model_request_usage_with_stripe(
let usage_meters = llm_db let usage_meters = llm_db
.get_current_subscription_usage_meters(Utc::now()) .get_current_subscription_usage_meters(Utc::now())
.await?; .await?;
let usage_meters = usage_meters let mut usage_meters_by_user_id =
.into_iter() HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
.filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id)) for (usage_meter, usage) in usage_meters {
.collect::<Vec<_>>(); let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
let user_ids = usage_meters meters.push(usage_meter);
.iter() }
.map(|(_, usage)| usage.user_id)
.collect::<HashSet<UserId>>(); log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
let billing_subscriptions = app let get_zed_pro_subscriptions_started_at = Utc::now();
.db let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
.get_active_zed_pro_billing_subscriptions(user_ids) log::info!(
.await?; "Stripe usage sync: Retrieved {} Zed pro subscriptions in {}",
billing_subscriptions.len(),
Utc::now() - get_zed_pro_subscriptions_started_at
);
let claude_sonnet_4 = stripe_billing let claude_sonnet_4 = stripe_billing
.find_price_by_lookup_key("claude-sonnet-4-requests") .find_price_by_lookup_key("claude-sonnet-4-requests")
@ -1451,59 +1455,90 @@ async fn sync_model_request_usage_with_stripe(
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max") .find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?; .await?;
let usage_meter_count = usage_meters.len(); let model_mode_combinations = [
("claude-opus-4", CompletionMode::Max),
("claude-opus-4", CompletionMode::Normal),
("claude-sonnet-4", CompletionMode::Max),
("claude-sonnet-4", CompletionMode::Normal),
("claude-3-7-sonnet", CompletionMode::Max),
("claude-3-7-sonnet", CompletionMode::Normal),
("claude-3-5-sonnet", CompletionMode::Normal),
];
log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters"); let billing_subscription_count = billing_subscriptions.len();
for (usage_meter, usage) in usage_meters { log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
maybe!(async { maybe!(async {
let Some((billing_customer, billing_subscription)) = if staff_user_ids.contains(&user_id) {
billing_subscriptions.get(&usage.user_id) return anyhow::Ok(());
else { }
bail!(
"Attempted to sync usage meter for user who is not a Stripe customer: {}",
usage.user_id
);
};
let stripe_customer_id = let stripe_customer_id =
StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
let stripe_subscription_id = let stripe_subscription_id =
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into()); StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
let model = llm_db.model_by_id(usage_meter.model_id)?; let usage_meters = usage_meters_by_user_id.get(&user_id);
let (price, meter_event_name) = match model.name.as_str() { for (model, mode) in &model_mode_combinations {
"claude-opus-4" => match usage_meter.mode { let Ok(model) =
CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"), llm_db.model(LanguageModelProvider::Anthropic, model)
CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"), else {
}, log::warn!("Failed to load model for user {user_id}: {model}");
"claude-sonnet-4" => match usage_meter.mode { continue;
CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"), };
CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
}, let (price, meter_event_name) = match model.name.as_str() {
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), "claude-opus-4" => match mode {
"claude-3-7-sonnet" => match usage_meter.mode { CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"), CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
CompletionMode::Max => { },
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") "claude-sonnet-4" => match mode {
CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
CompletionMode::Max => {
(&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
}
},
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
"claude-3-7-sonnet" => match mode {
CompletionMode::Normal => {
(&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
}
CompletionMode::Max => {
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
}
},
model_name => {
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
} }
}, };
model_name => {
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
}
};
stripe_billing let model_requests = usage_meters
.subscribe_to_price(&stripe_subscription_id, price) .and_then(|usage_meters| {
.await?; usage_meters
stripe_billing .iter()
.bill_model_request_usage( .find(|meter| meter.model_id == model.id && meter.mode == *mode)
&stripe_customer_id, })
meter_event_name, .map(|usage_meter| usage_meter.requests)
usage_meter.requests, .unwrap_or(0);
)
.await?; if model_requests > 0 {
stripe_billing
.subscribe_to_price(&stripe_subscription_id, price)
.await?;
}
stripe_billing
.bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
.await
.with_context(|| {
format!(
"Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
)
})?;
}
Ok(()) Ok(())
}) })
@ -1512,7 +1547,7 @@ async fn sync_model_request_usage_with_stripe(
} }
log::info!( log::info!(
"Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}", "Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
Utc::now() - started_at Utc::now() - started_at
); );

View file

@ -199,6 +199,33 @@ impl Database {
pub async fn get_active_zed_pro_billing_subscriptions( pub async fn get_active_zed_pro_billing_subscriptions(
&self, &self,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| async move {
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.filter(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;
let mut subscriptions = HashMap::default();
while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
subscriptions.insert(customer.user_id, (customer, subscription));
}
}
Ok(subscriptions)
})
.await
}
pub async fn get_active_zed_pro_billing_subscriptions_for_users(
&self,
user_ids: HashSet<UserId>, user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> { ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| { self.transaction(|tx| {