collab: Add usage-based billing for LLM interactions (#19081)

This PR adds usage-based billing for LLM interactions in the Assistant.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Marshall Bowers 2024-10-11 13:36:54 -04:00 committed by GitHub
parent f1c45d988e
commit 22ea7cef7a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 918 additions and 280 deletions

View file

@ -1,7 +1,3 @@
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, bail, Context};
use axum::{
extract::{self, Query},
@ -9,28 +5,35 @@ use axum::{
Extension, Json, Router,
};
use chrono::{DateTime, SecondsFormat, Utc};
use collections::HashSet;
use reqwest::StatusCode;
use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize};
use std::{str::FromStr, sync::Arc, time::Duration};
use stripe::{
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
Subscription, SubscriptionId, SubscriptionStatus,
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::ResultExt;
use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
use crate::db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
};
use crate::llm::db::LlmDatabase;
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
use crate::rpc::ResultExt as _;
use crate::{
db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
UpdateBillingSubscriptionParams,
},
stripe_billing::StripeBilling,
};
use crate::{
db::{billing_subscription::StripeSubscriptionStatus, UserId},
llm::db::LlmDatabase,
};
use crate::{AppState, Error, Result};
pub fn router() -> Router {
@ -87,6 +90,7 @@ struct UpdateBillingPreferencesBody {
async fn update_billing_preferences(
Extension(app): Extension<Arc<AppState>>,
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
) -> Result<Json<BillingPreferencesResponse>> {
let user = app
@ -119,6 +123,8 @@ async fn update_billing_preferences(
.await?
};
rpc_server.refresh_llm_tokens_for_user(user.id).await;
Ok(Json(BillingPreferencesResponse {
max_monthly_llm_usage_spending_in_cents: billing_preferences
.max_monthly_llm_usage_spending_in_cents,
@ -197,12 +203,15 @@ async fn create_billing_subscription(
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let Some((stripe_client, stripe_access_price_id)) = app
.stripe_client
.clone()
.zip(app.config.stripe_llm_access_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(llm_db) = app.llm_db.clone() else {
log::error!("failed to retrieve LLM database");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
@ -226,26 +235,15 @@ async fn create_billing_subscription(
customer.id
};
let checkout_session = {
let mut params = CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(user.github_login.as_str());
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
price: Some(stripe_access_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
params.success_url = Some(&success_url);
CheckoutSession::create(&stripe_client, params).await?
};
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
let checkout_session_url = stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?;
Ok(Json(CreateBillingSubscriptionResponse {
checkout_session_url: checkout_session
.url
.ok_or_else(|| anyhow!("no checkout session URL"))?,
checkout_session_url,
}))
}
@ -715,15 +713,15 @@ async fn find_or_create_billing_customer(
Ok(Some(billing_customer))
}
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
log::warn!("failed to retrieve Stripe LLM usage price ID");
let Some(llm_db) = app.llm_db.clone() else {
log::warn!("failed to retrieve LLM database");
return;
};
@ -732,15 +730,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
let executor = executor.clone();
async move {
loop {
sync_with_stripe(
&app,
&llm_db,
&stripe_client,
stripe_llm_usage_price_id.clone(),
)
.await
.trace_err();
sync_with_stripe(&app, &llm_db, &stripe_client)
.await
.trace_err();
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
}
}
@ -749,71 +741,46 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
async fn sync_with_stripe(
app: &Arc<AppState>,
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: Arc<str>,
llm_db: &Arc<LlmDatabase>,
stripe_client: &Arc<stripe::Client>,
) -> anyhow::Result<()> {
let subscriptions = app.db.get_active_billing_subscriptions().await?;
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
for (customer, subscription) in subscriptions {
update_stripe_subscription(
llm_db,
stripe_client,
&stripe_llm_usage_price_id,
customer,
subscription,
)
.await
.log_err();
let events = llm_db.get_billing_events().await?;
let user_ids = events
.iter()
.map(|(event, _)| event.user_id)
.collect::<HashSet<UserId>>();
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
for (event, model) in events {
let Some((stripe_db_customer, stripe_db_subscription)) =
stripe_subscriptions.get(&event.user_id)
else {
tracing::warn!(
user_id = event.user_id.0,
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
);
continue;
};
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
.stripe_subscription_id
.parse()
.context("failed to parse stripe subscription id from db")?;
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
.stripe_customer_id
.parse()
.context("failed to parse stripe customer id from db")?;
let stripe_model = stripe_billing.register_model(&model).await?;
stripe_billing
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
.await?;
stripe_billing
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
.await?;
llm_db.consume_billing_event(event.id).await?;
}
Ok(())
}
async fn update_stripe_subscription(
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: &Arc<str>,
customer: billing_customer::Model,
subscription: billing_subscription::Model,
) -> Result<(), anyhow::Error> {
let monthly_spending = llm_db
.get_user_spending_for_month(customer.user_id, Utc::now())
.await?;
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;
let monthly_spending_over_free_tier =
monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
let mut update_params = stripe::UpdateSubscription {
proration_behavior: Some(
stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
),
..Default::default()
};
if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
item.price.as_ref().map_or(false, |price| {
price.id == stripe_llm_usage_price_id.as_ref()
})
}) {
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
id: Some(existing_item.id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]);
} else {
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
price: Some(stripe_llm_usage_price_id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]);
}
Subscription::update(stripe_client, &subscription_id, update_params).await?;
Ok(())
}