collab: Sync model overages for all active Zed Pro subscriptions (#34452)
Release Notes: - N/A
This commit is contained in:
parent
52f2b32557
commit
848a86a385
2 changed files with 118 additions and 56 deletions
|
@ -5,7 +5,7 @@ use axum::{
|
|||
routing::{get, post},
|
||||
};
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
use collections::HashSet;
|
||||
use collections::{HashMap, HashSet};
|
||||
use reqwest::StatusCode;
|
||||
use sea_orm::ActiveValue;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -21,12 +21,13 @@ use stripe::{
|
|||
PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::{ResultExt, maybe};
|
||||
use zed_llm_client::LanguageModelProvider;
|
||||
|
||||
use crate::api::events::SnowflakeRow;
|
||||
use crate::db::billing_subscription::{
|
||||
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
|
||||
};
|
||||
use crate::llm::db::subscription_usage_meter::CompletionMode;
|
||||
use crate::llm::db::subscription_usage_meter::{self, CompletionMode};
|
||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::stripe_client::{
|
||||
|
@ -1416,18 +1417,21 @@ async fn sync_model_request_usage_with_stripe(
|
|||
let usage_meters = llm_db
|
||||
.get_current_subscription_usage_meters(Utc::now())
|
||||
.await?;
|
||||
let usage_meters = usage_meters
|
||||
.into_iter()
|
||||
.filter(|(_, usage)| !staff_user_ids.contains(&usage.user_id))
|
||||
.collect::<Vec<_>>();
|
||||
let user_ids = usage_meters
|
||||
.iter()
|
||||
.map(|(_, usage)| usage.user_id)
|
||||
.collect::<HashSet<UserId>>();
|
||||
let billing_subscriptions = app
|
||||
.db
|
||||
.get_active_zed_pro_billing_subscriptions(user_ids)
|
||||
.await?;
|
||||
let mut usage_meters_by_user_id =
|
||||
HashMap::<UserId, Vec<subscription_usage_meter::Model>>::default();
|
||||
for (usage_meter, usage) in usage_meters {
|
||||
let meters = usage_meters_by_user_id.entry(usage.user_id).or_default();
|
||||
meters.push(usage_meter);
|
||||
}
|
||||
|
||||
log::info!("Stripe usage sync: Retrieving Zed Pro subscriptions");
|
||||
let get_zed_pro_subscriptions_started_at = Utc::now();
|
||||
let billing_subscriptions = app.db.get_active_zed_pro_billing_subscriptions().await?;
|
||||
log::info!(
|
||||
"Stripe usage sync: Retrieved {} Zed pro subscriptions in {}",
|
||||
billing_subscriptions.len(),
|
||||
Utc::now() - get_zed_pro_subscriptions_started_at
|
||||
);
|
||||
|
||||
let claude_sonnet_4 = stripe_billing
|
||||
.find_price_by_lookup_key("claude-sonnet-4-requests")
|
||||
|
@ -1451,59 +1455,90 @@ async fn sync_model_request_usage_with_stripe(
|
|||
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
|
||||
.await?;
|
||||
|
||||
let usage_meter_count = usage_meters.len();
|
||||
let model_mode_combinations = [
|
||||
("claude-opus-4", CompletionMode::Max),
|
||||
("claude-opus-4", CompletionMode::Normal),
|
||||
("claude-sonnet-4", CompletionMode::Max),
|
||||
("claude-sonnet-4", CompletionMode::Normal),
|
||||
("claude-3-7-sonnet", CompletionMode::Max),
|
||||
("claude-3-7-sonnet", CompletionMode::Normal),
|
||||
("claude-3-5-sonnet", CompletionMode::Normal),
|
||||
];
|
||||
|
||||
log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters");
|
||||
let billing_subscription_count = billing_subscriptions.len();
|
||||
|
||||
for (usage_meter, usage) in usage_meters {
|
||||
log::info!("Stripe usage sync: Syncing {billing_subscription_count} Zed Pro subscriptions");
|
||||
|
||||
for (user_id, (billing_customer, billing_subscription)) in billing_subscriptions {
|
||||
maybe!(async {
|
||||
let Some((billing_customer, billing_subscription)) =
|
||||
billing_subscriptions.get(&usage.user_id)
|
||||
else {
|
||||
bail!(
|
||||
"Attempted to sync usage meter for user who is not a Stripe customer: {}",
|
||||
usage.user_id
|
||||
);
|
||||
};
|
||||
if staff_user_ids.contains(&user_id) {
|
||||
return anyhow::Ok(());
|
||||
}
|
||||
|
||||
let stripe_customer_id =
|
||||
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||
let stripe_subscription_id =
|
||||
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
|
||||
|
||||
let model = llm_db.model_by_id(usage_meter.model_id)?;
|
||||
let usage_meters = usage_meters_by_user_id.get(&user_id);
|
||||
|
||||
let (price, meter_event_name) = match model.name.as_str() {
|
||||
"claude-opus-4" => match usage_meter.mode {
|
||||
CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
|
||||
CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
|
||||
},
|
||||
"claude-sonnet-4" => match usage_meter.mode {
|
||||
CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
|
||||
CompletionMode::Max => (&claude_sonnet_4_max, "claude_sonnet_4/requests/max"),
|
||||
},
|
||||
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
|
||||
"claude-3-7-sonnet" => match usage_meter.mode {
|
||||
CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
|
||||
CompletionMode::Max => {
|
||||
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
|
||||
for (model, mode) in &model_mode_combinations {
|
||||
let Ok(model) =
|
||||
llm_db.model(LanguageModelProvider::Anthropic, model)
|
||||
else {
|
||||
log::warn!("Failed to load model for user {user_id}: {model}");
|
||||
continue;
|
||||
};
|
||||
|
||||
let (price, meter_event_name) = match model.name.as_str() {
|
||||
"claude-opus-4" => match mode {
|
||||
CompletionMode::Normal => (&claude_opus_4, "claude_opus_4/requests"),
|
||||
CompletionMode::Max => (&claude_opus_4_max, "claude_opus_4/requests/max"),
|
||||
},
|
||||
"claude-sonnet-4" => match mode {
|
||||
CompletionMode::Normal => (&claude_sonnet_4, "claude_sonnet_4/requests"),
|
||||
CompletionMode::Max => {
|
||||
(&claude_sonnet_4_max, "claude_sonnet_4/requests/max")
|
||||
}
|
||||
},
|
||||
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
|
||||
"claude-3-7-sonnet" => match mode {
|
||||
CompletionMode::Normal => {
|
||||
(&claude_3_7_sonnet, "claude_3_7_sonnet/requests")
|
||||
}
|
||||
CompletionMode::Max => {
|
||||
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
|
||||
}
|
||||
},
|
||||
model_name => {
|
||||
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
|
||||
}
|
||||
},
|
||||
model_name => {
|
||||
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
stripe_billing
|
||||
.subscribe_to_price(&stripe_subscription_id, price)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_request_usage(
|
||||
&stripe_customer_id,
|
||||
meter_event_name,
|
||||
usage_meter.requests,
|
||||
)
|
||||
.await?;
|
||||
let model_requests = usage_meters
|
||||
.and_then(|usage_meters| {
|
||||
usage_meters
|
||||
.iter()
|
||||
.find(|meter| meter.model_id == model.id && meter.mode == *mode)
|
||||
})
|
||||
.map(|usage_meter| usage_meter.requests)
|
||||
.unwrap_or(0);
|
||||
|
||||
if model_requests > 0 {
|
||||
stripe_billing
|
||||
.subscribe_to_price(&stripe_subscription_id, price)
|
||||
.await?;
|
||||
}
|
||||
|
||||
stripe_billing
|
||||
.bill_model_request_usage(&stripe_customer_id, meter_event_name, model_requests)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to bill model request usage of {model_requests} for {stripe_customer_id}: {meter_event_name}",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
|
@ -1512,7 +1547,7 @@ async fn sync_model_request_usage_with_stripe(
|
|||
}
|
||||
|
||||
log::info!(
|
||||
"Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}",
|
||||
"Stripe usage sync: Synced {billing_subscription_count} Zed Pro subscriptions in {}",
|
||||
Utc::now() - started_at
|
||||
);
|
||||
|
||||
|
|
|
@ -199,6 +199,33 @@ impl Database {
|
|||
|
||||
pub async fn get_active_zed_pro_billing_subscriptions(
|
||||
&self,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut subscriptions = HashMap::default();
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||
}
|
||||
}
|
||||
Ok(subscriptions)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_zed_pro_billing_subscriptions_for_users(
|
||||
&self,
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue