collab: Add usage-based billing for LLM interactions (#19081)
This PR adds usage-based billing for LLM interactions in the Assistant. Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra <me@as-cii.com> Co-authored-by: Antonio <antonio@zed.dev> Co-authored-by: Richard <richard@zed.dev> Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent
f1c45d988e
commit
22ea7cef7a
20 changed files with 918 additions and 280 deletions
|
@ -1,7 +1,3 @@
|
|||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use axum::{
|
||||
extract::{self, Query},
|
||||
|
@ -9,28 +5,35 @@ use axum::{
|
|||
Extension, Json, Router,
|
||||
};
|
||||
use chrono::{DateTime, SecondsFormat, Utc};
|
||||
use collections::HashSet;
|
||||
use reqwest::StatusCode;
|
||||
use sea_orm::ActiveValue;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||
use stripe::{
|
||||
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
|
||||
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
||||
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
|
||||
Subscription, SubscriptionId, SubscriptionStatus,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
||||
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
|
||||
use crate::db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
|
||||
UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
|
||||
};
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
|
||||
use crate::rpc::ResultExt as _;
|
||||
use crate::{
|
||||
db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
|
||||
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
|
||||
UpdateBillingSubscriptionParams,
|
||||
},
|
||||
stripe_billing::StripeBilling,
|
||||
};
|
||||
use crate::{
|
||||
db::{billing_subscription::StripeSubscriptionStatus, UserId},
|
||||
llm::db::LlmDatabase,
|
||||
};
|
||||
use crate::{AppState, Error, Result};
|
||||
|
||||
pub fn router() -> Router {
|
||||
|
@ -87,6 +90,7 @@ struct UpdateBillingPreferencesBody {
|
|||
|
||||
async fn update_billing_preferences(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
|
||||
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
|
||||
) -> Result<Json<BillingPreferencesResponse>> {
|
||||
let user = app
|
||||
|
@ -119,6 +123,8 @@ async fn update_billing_preferences(
|
|||
.await?
|
||||
};
|
||||
|
||||
rpc_server.refresh_llm_tokens_for_user(user.id).await;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: billing_preferences
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
|
@ -197,12 +203,15 @@ async fn create_billing_subscription(
|
|||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let Some((stripe_client, stripe_access_price_id)) = app
|
||||
.stripe_client
|
||||
.clone()
|
||||
.zip(app.config.stripe_llm_access_price_id.clone())
|
||||
else {
|
||||
log::error!("failed to retrieve Stripe client or price ID");
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::error!("failed to retrieve LLM database");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
|
@ -226,26 +235,15 @@ async fn create_billing_subscription(
|
|||
customer.id
|
||||
};
|
||||
|
||||
let checkout_session = {
|
||||
let mut params = CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(user.github_login.as_str());
|
||||
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
|
||||
price: Some(stripe_access_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
params.success_url = Some(&success_url);
|
||||
|
||||
CheckoutSession::create(&stripe_client, params).await?
|
||||
};
|
||||
|
||||
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
|
||||
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
|
||||
let stripe_model = stripe_billing.register_model(default_model).await?;
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
let checkout_session_url = stripe_billing
|
||||
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
|
||||
.await?;
|
||||
Ok(Json(CreateBillingSubscriptionResponse {
|
||||
checkout_session_url: checkout_session
|
||||
.url
|
||||
.ok_or_else(|| anyhow!("no checkout session URL"))?,
|
||||
checkout_session_url,
|
||||
}))
|
||||
}
|
||||
|
||||
|
@ -715,15 +713,15 @@ async fn find_or_create_billing_customer(
|
|||
Ok(Some(billing_customer))
|
||||
}
|
||||
|
||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
|
||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
};
|
||||
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
|
||||
log::warn!("failed to retrieve Stripe LLM usage price ID");
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::warn!("failed to retrieve LLM database");
|
||||
return;
|
||||
};
|
||||
|
||||
|
@ -732,15 +730,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
|
|||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_with_stripe(
|
||||
&app,
|
||||
&llm_db,
|
||||
&stripe_client,
|
||||
stripe_llm_usage_price_id.clone(),
|
||||
)
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
sync_with_stripe(&app, &llm_db, &stripe_client)
|
||||
.await
|
||||
.trace_err();
|
||||
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
|
@ -749,71 +741,46 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
|
|||
|
||||
async fn sync_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &LlmDatabase,
|
||||
stripe_client: &stripe::Client,
|
||||
stripe_llm_usage_price_id: Arc<str>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_client: &Arc<stripe::Client>,
|
||||
) -> anyhow::Result<()> {
|
||||
let subscriptions = app.db.get_active_billing_subscriptions().await?;
|
||||
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
|
||||
|
||||
for (customer, subscription) in subscriptions {
|
||||
update_stripe_subscription(
|
||||
llm_db,
|
||||
stripe_client,
|
||||
&stripe_llm_usage_price_id,
|
||||
customer,
|
||||
subscription,
|
||||
)
|
||||
.await
|
||||
.log_err();
|
||||
let events = llm_db.get_billing_events().await?;
|
||||
let user_ids = events
|
||||
.iter()
|
||||
.map(|(event, _)| event.user_id)
|
||||
.collect::<HashSet<UserId>>();
|
||||
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
|
||||
|
||||
for (event, model) in events {
|
||||
let Some((stripe_db_customer, stripe_db_subscription)) =
|
||||
stripe_subscriptions.get(&event.user_id)
|
||||
else {
|
||||
tracing::warn!(
|
||||
user_id = event.user_id.0,
|
||||
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
|
||||
.stripe_subscription_id
|
||||
.parse()
|
||||
.context("failed to parse stripe subscription id from db")?;
|
||||
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
|
||||
.stripe_customer_id
|
||||
.parse()
|
||||
.context("failed to parse stripe customer id from db")?;
|
||||
|
||||
let stripe_model = stripe_billing.register_model(&model).await?;
|
||||
stripe_billing
|
||||
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
|
||||
.await?;
|
||||
llm_db.consume_billing_event(event.id).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_stripe_subscription(
|
||||
llm_db: &LlmDatabase,
|
||||
stripe_client: &stripe::Client,
|
||||
stripe_llm_usage_price_id: &Arc<str>,
|
||||
customer: billing_customer::Model,
|
||||
subscription: billing_subscription::Model,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let monthly_spending = llm_db
|
||||
.get_user_spending_for_month(customer.user_id, Utc::now())
|
||||
.await?;
|
||||
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
|
||||
.context("failed to parse subscription ID")?;
|
||||
|
||||
let monthly_spending_over_free_tier =
|
||||
monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
|
||||
|
||||
let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
|
||||
let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
|
||||
|
||||
let mut update_params = stripe::UpdateSubscription {
|
||||
proration_behavior: Some(
|
||||
stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
|
||||
),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
|
||||
item.price.as_ref().map_or(false, |price| {
|
||||
price.id == stripe_llm_usage_price_id.as_ref()
|
||||
})
|
||||
}) {
|
||||
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
|
||||
id: Some(existing_item.id.to_string()),
|
||||
quantity: Some(new_quantity as u64),
|
||||
..Default::default()
|
||||
}]);
|
||||
} else {
|
||||
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
|
||||
price: Some(stripe_llm_usage_price_id.to_string()),
|
||||
quantity: Some(new_quantity as u64),
|
||||
..Default::default()
|
||||
}]);
|
||||
}
|
||||
|
||||
Subscription::update(stripe_client, &subscription_id, update_params).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue