collab: Add usage-based billing for LLM interactions (#19081)

This PR adds usage-based billing for LLM interactions in the Assistant.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
Marshall Bowers 2024-10-11 13:36:54 -04:00 committed by GitHub
parent f1c45d988e
commit 22ea7cef7a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 918 additions and 280 deletions

5
Cargo.lock generated
View file

@ -839,9 +839,8 @@ dependencies = [
[[package]] [[package]]
name = "async-stripe" name = "async-stripe"
version = "0.39.1" version = "0.40.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735"
checksum = "58d670cf4d47a1b8ffef54286a5625382e360a34ee76902fd93ad8c7032a0c30"
dependencies = [ dependencies = [
"chrono", "chrono",
"futures-util", "futures-util",

View file

@ -480,7 +480,8 @@ which = "6.0.0"
wit-component = "0.201" wit-component = "0.201"
[workspace.dependencies.async-stripe] [workspace.dependencies.async-stripe]
version = "0.39" git = "https://github.com/zed-industries/async-stripe"
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
default-features = false default-features = false
features = [ features = [
"runtime-tokio-hyper-rustls", "runtime-tokio-hyper-rustls",

View file

@ -0,0 +1,12 @@
create table billing_events (
id serial primary key,
idempotency_key uuid not null default gen_random_uuid(),
user_id integer not null,
model_id integer not null references models (id) on delete cascade,
input_tokens bigint not null default 0,
input_cache_creation_tokens bigint not null default 0,
input_cache_read_tokens bigint not null default 0,
output_tokens bigint not null default 0
);
create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);

View file

@ -1,7 +1,3 @@
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, bail, Context}; use anyhow::{anyhow, bail, Context};
use axum::{ use axum::{
extract::{self, Query}, extract::{self, Query},
@ -9,28 +5,35 @@ use axum::{
Extension, Json, Router, Extension, Json, Router,
}; };
use chrono::{DateTime, SecondsFormat, Utc}; use chrono::{DateTime, SecondsFormat, Utc};
use collections::HashSet;
use reqwest::StatusCode; use reqwest::StatusCode;
use sea_orm::ActiveValue; use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{str::FromStr, sync::Arc, time::Duration};
use stripe::{ use stripe::{
BillingPortalSession, CheckoutSession, CreateBillingPortalSession, BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect, CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems, CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents, EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
Subscription, SubscriptionId, SubscriptionStatus,
}; };
use util::ResultExt; use util::ResultExt;
use crate::db::billing_subscription::{self, StripeSubscriptionStatus}; use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
use crate::db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
};
use crate::llm::db::LlmDatabase;
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::rpc::ResultExt as _; use crate::rpc::ResultExt as _;
use crate::{
db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
UpdateBillingSubscriptionParams,
},
stripe_billing::StripeBilling,
};
use crate::{
db::{billing_subscription::StripeSubscriptionStatus, UserId},
llm::db::LlmDatabase,
};
use crate::{AppState, Error, Result}; use crate::{AppState, Error, Result};
pub fn router() -> Router { pub fn router() -> Router {
@ -87,6 +90,7 @@ struct UpdateBillingPreferencesBody {
async fn update_billing_preferences( async fn update_billing_preferences(
Extension(app): Extension<Arc<AppState>>, Extension(app): Extension<Arc<AppState>>,
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>, extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
) -> Result<Json<BillingPreferencesResponse>> { ) -> Result<Json<BillingPreferencesResponse>> {
let user = app let user = app
@ -119,6 +123,8 @@ async fn update_billing_preferences(
.await? .await?
}; };
rpc_server.refresh_llm_tokens_for_user(user.id).await;
Ok(Json(BillingPreferencesResponse { Ok(Json(BillingPreferencesResponse {
max_monthly_llm_usage_spending_in_cents: billing_preferences max_monthly_llm_usage_spending_in_cents: billing_preferences
.max_monthly_llm_usage_spending_in_cents, .max_monthly_llm_usage_spending_in_cents,
@ -197,12 +203,15 @@ async fn create_billing_subscription(
.await? .await?
.ok_or_else(|| anyhow!("user not found"))?; .ok_or_else(|| anyhow!("user not found"))?;
let Some((stripe_client, stripe_access_price_id)) = app let Some(stripe_client) = app.stripe_client.clone() else {
.stripe_client log::error!("failed to retrieve Stripe client");
.clone() Err(Error::http(
.zip(app.config.stripe_llm_access_price_id.clone()) StatusCode::NOT_IMPLEMENTED,
else { "not supported".into(),
log::error!("failed to retrieve Stripe client or price ID"); ))?
};
let Some(llm_db) = app.llm_db.clone() else {
log::error!("failed to retrieve LLM database");
Err(Error::http( Err(Error::http(
StatusCode::NOT_IMPLEMENTED, StatusCode::NOT_IMPLEMENTED,
"not supported".into(), "not supported".into(),
@ -226,26 +235,15 @@ async fn create_billing_subscription(
customer.id customer.id
}; };
let checkout_session = { let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let mut params = CreateCheckoutSession::new(); let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
params.mode = Some(stripe::CheckoutSessionMode::Subscription); let stripe_model = stripe_billing.register_model(default_model).await?;
params.customer = Some(customer_id); let success_url = format!("{}/account", app.config.zed_dot_dev_url());
params.client_reference_id = Some(user.github_login.as_str()); let checkout_session_url = stripe_billing
params.line_items = Some(vec![CreateCheckoutSessionLineItems { .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
price: Some(stripe_access_price_id.to_string()), .await?;
quantity: Some(1),
..Default::default()
}]);
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
params.success_url = Some(&success_url);
CheckoutSession::create(&stripe_client, params).await?
};
Ok(Json(CreateBillingSubscriptionResponse { Ok(Json(CreateBillingSubscriptionResponse {
checkout_session_url: checkout_session checkout_session_url,
.url
.ok_or_else(|| anyhow!("no checkout session URL"))?,
})) }))
} }
@ -715,15 +713,15 @@ async fn find_or_create_billing_customer(
Ok(Some(billing_customer)) Ok(Some(billing_customer))
} }
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60); const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) { pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_client) = app.stripe_client.clone() else { let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client"); log::warn!("failed to retrieve Stripe client");
return; return;
}; };
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else { let Some(llm_db) = app.llm_db.clone() else {
log::warn!("failed to retrieve Stripe LLM usage price ID"); log::warn!("failed to retrieve LLM database");
return; return;
}; };
@ -732,15 +730,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
let executor = executor.clone(); let executor = executor.clone();
async move { async move {
loop { loop {
sync_with_stripe( sync_with_stripe(&app, &llm_db, &stripe_client)
&app, .await
&llm_db, .trace_err();
&stripe_client,
stripe_llm_usage_price_id.clone(),
)
.await
.trace_err();
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await; executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
} }
} }
@ -749,71 +741,46 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
async fn sync_with_stripe( async fn sync_with_stripe(
app: &Arc<AppState>, app: &Arc<AppState>,
llm_db: &LlmDatabase, llm_db: &Arc<LlmDatabase>,
stripe_client: &stripe::Client, stripe_client: &Arc<stripe::Client>,
stripe_llm_usage_price_id: Arc<str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let subscriptions = app.db.get_active_billing_subscriptions().await?; let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
for (customer, subscription) in subscriptions { let events = llm_db.get_billing_events().await?;
update_stripe_subscription( let user_ids = events
llm_db, .iter()
stripe_client, .map(|(event, _)| event.user_id)
&stripe_llm_usage_price_id, .collect::<HashSet<UserId>>();
customer, let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
subscription,
) for (event, model) in events {
.await let Some((stripe_db_customer, stripe_db_subscription)) =
.log_err(); stripe_subscriptions.get(&event.user_id)
else {
tracing::warn!(
user_id = event.user_id.0,
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
);
continue;
};
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
.stripe_subscription_id
.parse()
.context("failed to parse stripe subscription id from db")?;
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
.stripe_customer_id
.parse()
.context("failed to parse stripe customer id from db")?;
let stripe_model = stripe_billing.register_model(&model).await?;
stripe_billing
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
.await?;
stripe_billing
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
.await?;
llm_db.consume_billing_event(event.id).await?;
} }
Ok(()) Ok(())
} }
async fn update_stripe_subscription(
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: &Arc<str>,
customer: billing_customer::Model,
subscription: billing_subscription::Model,
) -> Result<(), anyhow::Error> {
let monthly_spending = llm_db
.get_user_spending_for_month(customer.user_id, Utc::now())
.await?;
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;
let monthly_spending_over_free_tier =
monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
let mut update_params = stripe::UpdateSubscription {
proration_behavior: Some(
stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
),
..Default::default()
};
if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
item.price.as_ref().map_or(false, |price| {
price.id == stripe_llm_usage_price_id.as_ref()
})
}) {
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
id: Some(existing_item.id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]);
} else {
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
price: Some(stripe_llm_usage_price_id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]);
}
Subscription::update(stripe_client, &subscription_id, update_params).await?;
Ok(())
}

View file

@ -114,23 +114,31 @@ impl Database {
pub async fn get_active_billing_subscriptions( pub async fn get_active_billing_subscriptions(
&self, &self,
) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> { user_ids: HashSet<UserId>,
self.transaction(|tx| async move { ) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
let mut result = Vec::new(); self.transaction(|tx| {
let mut rows = billing_subscription::Entity::find() let user_ids = user_ids.clone();
.inner_join(billing_customer::Entity) async move {
.select_also(billing_customer::Entity) let mut rows = billing_subscription::Entity::find()
.order_by_asc(billing_subscription::Column::Id) .inner_join(billing_customer::Entity)
.stream(&*tx) .select_also(billing_customer::Entity)
.await?; .filter(billing_customer::Column::UserId.is_in(user_ids))
.filter(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;
while let Some(row) = rows.next().await { let mut subscriptions = HashMap::default();
if let (subscription, Some(customer)) = row? { while let Some(row) = rows.next().await {
result.push((customer, subscription)); if let (subscription, Some(customer)) = row? {
subscriptions.insert(customer.user_id, (customer, subscription));
}
} }
Ok(subscriptions)
} }
Ok(result)
}) })
.await .await
} }

View file

@ -10,6 +10,7 @@ pub mod migrations;
mod rate_limiter; mod rate_limiter;
pub mod rpc; pub mod rpc;
pub mod seed; pub mod seed;
pub mod stripe_billing;
pub mod user_backfiller; pub mod user_backfiller;
#[cfg(test)] #[cfg(test)]
@ -24,6 +25,7 @@ use axum::{
pub use cents::*; pub use cents::*;
use db::{ChannelId, Database}; use db::{ChannelId, Database};
use executor::Executor; use executor::Executor;
use llm::db::LlmDatabase;
pub use rate_limiter::*; pub use rate_limiter::*;
use serde::Deserialize; use serde::Deserialize;
use std::{path::PathBuf, sync::Arc}; use std::{path::PathBuf, sync::Arc};
@ -176,8 +178,6 @@ pub struct Config {
pub slack_panics_webhook: Option<String>, pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>, pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>, pub stripe_api_key: Option<String>,
pub stripe_llm_access_price_id: Option<Arc<str>>,
pub stripe_llm_usage_price_id: Option<Arc<str>>,
pub supermaven_admin_api_key: Option<Arc<str>>, pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>, pub user_backfiller_github_access_token: Option<Arc<str>>,
} }
@ -197,7 +197,7 @@ impl Config {
} }
pub fn is_llm_billing_enabled(&self) -> bool { pub fn is_llm_billing_enabled(&self) -> bool {
self.stripe_llm_usage_price_id.is_some() self.stripe_api_key.is_some()
} }
#[cfg(test)] #[cfg(test)]
@ -238,8 +238,6 @@ impl Config {
migrations_path: None, migrations_path: None,
seed_path: None, seed_path: None,
stripe_api_key: None, stripe_api_key: None,
stripe_llm_access_price_id: None,
stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None, supermaven_admin_api_key: None,
user_backfiller_github_access_token: None, user_backfiller_github_access_token: None,
} }
@ -272,6 +270,7 @@ impl ServiceMode {
pub struct AppState { pub struct AppState {
pub db: Arc<Database>, pub db: Arc<Database>,
pub llm_db: Option<Arc<LlmDatabase>>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>, pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>, pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>, pub stripe_client: Option<Arc<stripe::Client>>,
@ -288,6 +287,20 @@ impl AppState {
let mut db = Database::new(db_options, Executor::Production).await?; let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?; db.initialize_notification_kinds().await?;
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
.llm_database_url
.clone()
.zip(config.llm_database_max_connections)
{
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
llm_db_options.max_connections(llm_database_max_connections);
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
llm_db.initialize().await?;
Some(Arc::new(llm_db))
} else {
None
};
let live_kit_client = if let Some(((server, key), secret)) = config let live_kit_client = if let Some(((server, key), secret)) = config
.live_kit_server .live_kit_server
.as_ref() .as_ref()
@ -306,9 +319,10 @@ impl AppState {
let db = Arc::new(db); let db = Arc::new(db);
let this = Self { let this = Self {
db: db.clone(), db: db.clone(),
llm_db,
live_kit_client, live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(), blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(), stripe_client: build_stripe_client(&config).map(Arc::new).log_err(),
rate_limiter: Arc::new(RateLimiter::new(db)), rate_limiter: Arc::new(RateLimiter::new(db)),
executor, executor,
clickhouse_client: config clickhouse_client: config
@ -321,12 +335,11 @@ impl AppState {
} }
} }
async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> { fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
let api_key = config let api_key = config
.stripe_api_key .stripe_api_key
.as_ref() .as_ref()
.ok_or_else(|| anyhow!("missing stripe_api_key"))?; .ok_or_else(|| anyhow!("missing stripe_api_key"))?;
Ok(stripe::Client::new(api_key)) Ok(stripe::Client::new(api_key))
} }

View file

@ -20,13 +20,14 @@ use axum::{
}; };
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use collections::HashMap; use collections::HashMap;
use db::TokenUsage;
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _}; use futures::{Stream, StreamExt as _};
use isahc_http_client::IsahcHttpClient; use isahc_http_client::IsahcHttpClient;
use rpc::ListModelsResponse;
use rpc::{ use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
}; };
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use std::{ use std::{
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
@ -418,10 +419,7 @@ async fn perform_completion(
claims, claims,
provider: params.provider, provider: params.provider,
model, model,
input_tokens: 0, tokens: TokenUsage::default(),
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
inner_stream: stream, inner_stream: stream,
}))) })))
} }
@ -476,6 +474,19 @@ async fn check_usage_limit(
"Maximum spending limit reached for this month.".to_string(), "Maximum spending limit reached for this month.".to_string(),
)); ));
} }
if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
return Err(Error::Http(
StatusCode::FORBIDDEN,
"Maximum spending limit reached for this month.".to_string(),
[(
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
HeaderValue::from_static("true"),
)]
.into_iter()
.collect(),
));
}
} }
} }
@ -598,10 +609,7 @@ struct TokenCountingStream<S> {
claims: LlmTokenClaims, claims: LlmTokenClaims,
provider: LanguageModelProvider, provider: LanguageModelProvider,
model: String, model: String,
input_tokens: usize, tokens: TokenUsage,
output_tokens: usize,
cache_creation_input_tokens: usize,
cache_read_input_tokens: usize,
inner_stream: S, inner_stream: S,
} }
@ -615,10 +623,10 @@ where
match Pin::new(&mut self.inner_stream).poll_next(cx) { match Pin::new(&mut self.inner_stream).poll_next(cx) {
Poll::Ready(Some(Ok(mut chunk))) => { Poll::Ready(Some(Ok(mut chunk))) => {
chunk.bytes.push(b'\n'); chunk.bytes.push(b'\n');
self.input_tokens += chunk.input_tokens; self.tokens.input += chunk.input_tokens;
self.output_tokens += chunk.output_tokens; self.tokens.output += chunk.output_tokens;
self.cache_creation_input_tokens += chunk.cache_creation_input_tokens; self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
self.cache_read_input_tokens += chunk.cache_read_input_tokens; self.tokens.input_cache_read += chunk.cache_read_input_tokens;
Poll::Ready(Some(Ok(chunk.bytes))) Poll::Ready(Some(Ok(chunk.bytes)))
} }
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
@ -634,10 +642,7 @@ impl<S> Drop for TokenCountingStream<S> {
let claims = self.claims.clone(); let claims = self.claims.clone();
let provider = self.provider; let provider = self.provider;
let model = std::mem::take(&mut self.model); let model = std::mem::take(&mut self.model);
let input_token_count = self.input_tokens; let tokens = self.tokens;
let output_token_count = self.output_tokens;
let cache_creation_input_token_count = self.cache_creation_input_tokens;
let cache_read_input_token_count = self.cache_read_input_tokens;
self.state.executor.spawn_detached(async move { self.state.executor.spawn_detached(async move {
let usage = state let usage = state
.db .db
@ -646,10 +651,9 @@ impl<S> Drop for TokenCountingStream<S> {
claims.is_staff, claims.is_staff,
provider, provider,
&model, &model,
input_token_count, tokens,
cache_creation_input_token_count, claims.has_llm_subscription,
cache_read_input_token_count, Cents(claims.max_monthly_spend_in_cents),
output_token_count,
Utc::now(), Utc::now(),
) )
.await .await
@ -679,22 +683,23 @@ impl<S> Drop for TokenCountingStream<S> {
}, },
model, model,
provider: provider.to_string(), provider: provider.to_string(),
input_token_count: input_token_count as u64, input_token_count: tokens.input as u64,
cache_creation_input_token_count: cache_creation_input_token_count cache_creation_input_token_count: tokens.input_cache_creation as u64,
as u64, cache_read_input_token_count: tokens.input_cache_read as u64,
cache_read_input_token_count: cache_read_input_token_count as u64, output_token_count: tokens.output as u64,
output_token_count: output_token_count as u64,
requests_this_minute: usage.requests_this_minute as u64, requests_this_minute: usage.requests_this_minute as u64,
tokens_this_minute: usage.tokens_this_minute as u64, tokens_this_minute: usage.tokens_this_minute as u64,
tokens_this_day: usage.tokens_this_day as u64, tokens_this_day: usage.tokens_this_day as u64,
input_tokens_this_month: usage.input_tokens_this_month as u64, input_tokens_this_month: usage.tokens_this_month.input as u64,
cache_creation_input_tokens_this_month: usage cache_creation_input_tokens_this_month: usage
.cache_creation_input_tokens_this_month .tokens_this_month
.input_cache_creation
as u64, as u64,
cache_read_input_tokens_this_month: usage cache_read_input_tokens_this_month: usage
.cache_read_input_tokens_this_month .tokens_this_month
.input_cache_read
as u64, as u64,
output_tokens_this_month: usage.output_tokens_this_month as u64, output_tokens_this_month: usage.tokens_this_month.output as u64,
spending_this_month: usage.spending_this_month.0 as u64, spending_this_month: usage.spending_this_month.0 as u64,
lifetime_spending: usage.lifetime_spending.0 as u64, lifetime_spending: usage.lifetime_spending.0 as u64,
}, },

View file

@ -20,7 +20,7 @@ use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use anyhow::anyhow; use anyhow::anyhow;
pub use queries::usages::ActiveUserCount; pub use queries::usages::{ActiveUserCount, TokenUsage};
use sea_orm::prelude::*; use sea_orm::prelude::*;
pub use sea_orm::ConnectOptions; pub use sea_orm::ConnectOptions;
use sea_orm::{ use sea_orm::{

View file

@ -8,3 +8,4 @@ id_type!(ProviderId);
id_type!(UsageId); id_type!(UsageId);
id_type!(UsageMeasureId); id_type!(UsageMeasureId);
id_type!(RevokedAccessTokenId); id_type!(RevokedAccessTokenId);
id_type!(BillingEventId);

View file

@ -1,5 +1,6 @@
use super::*; use super::*;
pub mod billing_events;
pub mod providers; pub mod providers;
pub mod revoked_access_tokens; pub mod revoked_access_tokens;
pub mod usages; pub mod usages;

View file

@ -0,0 +1,31 @@
use super::*;
use crate::Result;
use anyhow::Context as _;
impl LlmDatabase {
pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
self.transaction(|tx| async move {
let events_with_models = billing_event::Entity::find()
.find_also_related(model::Entity)
.all(&*tx)
.await?;
events_with_models
.into_iter()
.map(|(event, model)| {
let model =
model.context("could not find model associated with billing event")?;
Ok((event, model))
})
.collect()
})
.await
}
pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
self.transaction(|tx| async move {
billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
Ok(())
})
.await
}
}

View file

@ -1,5 +1,5 @@
use crate::db::UserId;
use crate::llm::Cents; use crate::llm::Cents;
use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT};
use chrono::{Datelike, Duration}; use chrono::{Datelike, Duration};
use futures::StreamExt as _; use futures::StreamExt as _;
use rpc::LanguageModelProvider; use rpc::LanguageModelProvider;
@ -9,15 +9,26 @@ use strum::IntoEnumIterator as _;
use super::*; use super::*;
#[derive(Debug, PartialEq, Clone, Copy, Default)]
pub struct TokenUsage {
pub input: usize,
pub input_cache_creation: usize,
pub input_cache_read: usize,
pub output: usize,
}
impl TokenUsage {
pub fn total(&self) -> usize {
self.input + self.input_cache_creation + self.input_cache_read + self.output
}
}
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug, PartialEq, Clone, Copy)]
pub struct Usage { pub struct Usage {
pub requests_this_minute: usize, pub requests_this_minute: usize,
pub tokens_this_minute: usize, pub tokens_this_minute: usize,
pub tokens_this_day: usize, pub tokens_this_day: usize,
pub input_tokens_this_month: usize, pub tokens_this_month: TokenUsage,
pub cache_creation_input_tokens_this_month: usize,
pub cache_read_input_tokens_this_month: usize,
pub output_tokens_this_month: usize,
pub spending_this_month: Cents, pub spending_this_month: Cents,
pub lifetime_spending: Cents, pub lifetime_spending: Cents,
} }
@ -257,18 +268,20 @@ impl LlmDatabase {
requests_this_minute, requests_this_minute,
tokens_this_minute, tokens_this_minute,
tokens_this_day, tokens_this_day,
input_tokens_this_month: monthly_usage tokens_this_month: TokenUsage {
.as_ref() input: monthly_usage
.map_or(0, |usage| usage.input_tokens as usize), .as_ref()
cache_creation_input_tokens_this_month: monthly_usage .map_or(0, |usage| usage.input_tokens as usize),
.as_ref() input_cache_creation: monthly_usage
.map_or(0, |usage| usage.cache_creation_input_tokens as usize), .as_ref()
cache_read_input_tokens_this_month: monthly_usage .map_or(0, |usage| usage.cache_creation_input_tokens as usize),
.as_ref() input_cache_read: monthly_usage
.map_or(0, |usage| usage.cache_read_input_tokens as usize), .as_ref()
output_tokens_this_month: monthly_usage .map_or(0, |usage| usage.cache_read_input_tokens as usize),
.as_ref() output: monthly_usage
.map_or(0, |usage| usage.output_tokens as usize), .as_ref()
.map_or(0, |usage| usage.output_tokens as usize),
},
spending_this_month, spending_this_month,
lifetime_spending, lifetime_spending,
}) })
@ -283,10 +296,9 @@ impl LlmDatabase {
is_staff: bool, is_staff: bool,
provider: LanguageModelProvider, provider: LanguageModelProvider,
model_name: &str, model_name: &str,
input_token_count: usize, tokens: TokenUsage,
cache_creation_input_tokens: usize, has_llm_subscription: bool,
cache_read_input_tokens: usize, max_monthly_spend: Cents,
output_token_count: usize,
now: DateTimeUtc, now: DateTimeUtc,
) -> Result<Usage> { ) -> Result<Usage> {
self.transaction(|tx| async move { self.transaction(|tx| async move {
@ -313,10 +325,6 @@ impl LlmDatabase {
&tx, &tx,
) )
.await?; .await?;
let total_token_count = input_token_count
+ cache_read_input_tokens
+ cache_creation_input_tokens
+ output_token_count;
let tokens_this_minute = self let tokens_this_minute = self
.update_usage_for_measure( .update_usage_for_measure(
user_id, user_id,
@ -325,7 +333,7 @@ impl LlmDatabase {
&usages, &usages,
UsageMeasure::TokensPerMinute, UsageMeasure::TokensPerMinute,
now, now,
total_token_count, tokens.total(),
&tx, &tx,
) )
.await?; .await?;
@ -337,7 +345,7 @@ impl LlmDatabase {
&usages, &usages,
UsageMeasure::TokensPerDay, UsageMeasure::TokensPerDay,
now, now,
total_token_count, tokens.total(),
&tx, &tx,
) )
.await?; .await?;
@ -361,18 +369,14 @@ impl LlmDatabase {
Some(usage) => { Some(usage) => {
monthly_usage::Entity::update(monthly_usage::ActiveModel { monthly_usage::Entity::update(monthly_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id), id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set( input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
usage.input_tokens + input_token_count as i64,
),
cache_creation_input_tokens: ActiveValue::set( cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64, usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
), ),
cache_read_input_tokens: ActiveValue::set( cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + cache_read_input_tokens as i64, usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
output_tokens: ActiveValue::set(
usage.output_tokens + output_token_count as i64,
), ),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default() ..Default::default()
}) })
.exec(&*tx) .exec(&*tx)
@ -384,12 +388,12 @@ impl LlmDatabase {
model_id: ActiveValue::set(model.id), model_id: ActiveValue::set(model.id),
month: ActiveValue::set(month), month: ActiveValue::set(month),
year: ActiveValue::set(year), year: ActiveValue::set(year),
input_tokens: ActiveValue::set(input_token_count as i64), input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set( cache_creation_input_tokens: ActiveValue::set(
cache_creation_input_tokens as i64, tokens.input_cache_creation as i64,
), ),
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64), cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(output_token_count as i64), output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default() ..Default::default()
} }
.insert(&*tx) .insert(&*tx)
@ -405,6 +409,26 @@ impl LlmDatabase {
monthly_usage.output_tokens as usize, monthly_usage.output_tokens as usize,
); );
if spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT
&& has_llm_subscription
&& spending_this_month <= max_monthly_spend
{
billing_event::ActiveModel {
id: ActiveValue::not_set(),
idempotency_key: ActiveValue::not_set(),
user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(tokens.input as i64),
input_cache_creation_tokens: ActiveValue::set(
tokens.input_cache_creation as i64,
),
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(tokens.output as i64),
}
.insert(&*tx)
.await?;
}
// Update lifetime usage // Update lifetime usage
let lifetime_usage = lifetime_usage::Entity::find() let lifetime_usage = lifetime_usage::Entity::find()
.filter( .filter(
@ -419,18 +443,14 @@ impl LlmDatabase {
Some(usage) => { Some(usage) => {
lifetime_usage::Entity::update(lifetime_usage::ActiveModel { lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
id: ActiveValue::unchanged(usage.id), id: ActiveValue::unchanged(usage.id),
input_tokens: ActiveValue::set( input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
usage.input_tokens + input_token_count as i64,
),
cache_creation_input_tokens: ActiveValue::set( cache_creation_input_tokens: ActiveValue::set(
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64, usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
), ),
cache_read_input_tokens: ActiveValue::set( cache_read_input_tokens: ActiveValue::set(
usage.cache_read_input_tokens + cache_read_input_tokens as i64, usage.cache_read_input_tokens + tokens.input_cache_read as i64,
),
output_tokens: ActiveValue::set(
usage.output_tokens + output_token_count as i64,
), ),
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
..Default::default() ..Default::default()
}) })
.exec(&*tx) .exec(&*tx)
@ -440,12 +460,12 @@ impl LlmDatabase {
lifetime_usage::ActiveModel { lifetime_usage::ActiveModel {
user_id: ActiveValue::set(user_id), user_id: ActiveValue::set(user_id),
model_id: ActiveValue::set(model.id), model_id: ActiveValue::set(model.id),
input_tokens: ActiveValue::set(input_token_count as i64), input_tokens: ActiveValue::set(tokens.input as i64),
cache_creation_input_tokens: ActiveValue::set( cache_creation_input_tokens: ActiveValue::set(
cache_creation_input_tokens as i64, tokens.input_cache_creation as i64,
), ),
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64), cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
output_tokens: ActiveValue::set(output_token_count as i64), output_tokens: ActiveValue::set(tokens.output as i64),
..Default::default() ..Default::default()
} }
.insert(&*tx) .insert(&*tx)
@ -465,11 +485,12 @@ impl LlmDatabase {
requests_this_minute, requests_this_minute,
tokens_this_minute, tokens_this_minute,
tokens_this_day, tokens_this_day,
input_tokens_this_month: monthly_usage.input_tokens as usize, tokens_this_month: TokenUsage {
cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens input: monthly_usage.input_tokens as usize,
as usize, input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize, input_cache_read: monthly_usage.cache_read_input_tokens as usize,
output_tokens_this_month: monthly_usage.output_tokens as usize, output: monthly_usage.output_tokens as usize,
},
spending_this_month, spending_this_month,
lifetime_spending, lifetime_spending,
}) })

View file

@ -1,3 +1,4 @@
pub mod billing_event;
pub mod lifetime_usage; pub mod lifetime_usage;
pub mod model; pub mod model;
pub mod monthly_usage; pub mod monthly_usage;

View file

@ -0,0 +1,37 @@
use crate::{
db::UserId,
llm::db::{BillingEventId, ModelId},
};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_events")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingEventId,
pub idempotency_key: Uuid,
pub user_id: UserId,
pub model_id: ModelId,
pub input_tokens: i64,
pub input_cache_creation_tokens: i64,
pub input_cache_read_tokens: i64,
pub output_tokens: i64,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::model::Entity",
from = "Column::ModelId",
to = "super::model::Column::Id"
)]
Model,
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Model.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -29,6 +29,8 @@ pub enum Relation {
Provider, Provider,
#[sea_orm(has_many = "super::usage::Entity")] #[sea_orm(has_many = "super::usage::Entity")]
Usages, Usages,
#[sea_orm(has_many = "super::billing_event::Entity")]
BillingEvents,
} }
impl Related<super::provider::Entity> for Entity { impl Related<super::provider::Entity> for Entity {
@ -43,4 +45,10 @@ impl Related<super::usage::Entity> for Entity {
} }
} }
impl Related<super::billing_event::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingEvents.def()
}
}
impl ActiveModelBehavior for ActiveModel {} impl ActiveModelBehavior for ActiveModel {}

View file

@ -2,7 +2,7 @@ use crate::{
db::UserId, db::UserId,
llm::db::{ llm::db::{
queries::{providers::ModelParams, usages::Usage}, queries::{providers::ModelParams, usages::Usage},
LlmDatabase, LlmDatabase, TokenUsage,
}, },
test_llm_db, Cents, test_llm_db, Cents,
}; };
@ -36,14 +36,42 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
let user_id = UserId::from_proto(123); let user_id = UserId::from_proto(123);
let now = t0; let now = t0;
db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now) db.record_usage(
.await user_id,
.unwrap(); false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
let now = t0 + Duration::seconds(10); let now = t0 + Duration::seconds(10);
db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now) db.record_usage(
.await user_id,
.unwrap(); false,
provider,
model,
TokenUsage {
input: 2000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!( assert_eq!(
@ -52,10 +80,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2, requests_this_minute: 2,
tokens_this_minute: 3000, tokens_this_minute: 3000,
tokens_this_day: 3000, tokens_this_day: 3000,
input_tokens_this_month: 3000, tokens_this_month: TokenUsage {
cache_creation_input_tokens_this_month: 0, input: 3000,
cache_read_input_tokens_this_month: 0, input_cache_creation: 0,
output_tokens_this_month: 0, input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO, spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO, lifetime_spending: Cents::ZERO,
} }
@ -69,19 +99,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1, requests_this_minute: 1,
tokens_this_minute: 2000, tokens_this_minute: 2000,
tokens_this_day: 3000, tokens_this_day: 3000,
input_tokens_this_month: 3000, tokens_this_month: TokenUsage {
cache_creation_input_tokens_this_month: 0, input: 3000,
cache_read_input_tokens_this_month: 0, input_cache_creation: 0,
output_tokens_this_month: 0, input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO, spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO, lifetime_spending: Cents::ZERO,
} }
); );
let now = t0 + Duration::seconds(60); let now = t0 + Duration::seconds(60);
db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now) db.record_usage(
.await user_id,
.unwrap(); false,
provider,
model,
TokenUsage {
input: 3000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!( assert_eq!(
@ -90,10 +136,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2, requests_this_minute: 2,
tokens_this_minute: 5000, tokens_this_minute: 5000,
tokens_this_day: 6000, tokens_this_day: 6000,
input_tokens_this_month: 6000, tokens_this_month: TokenUsage {
cache_creation_input_tokens_this_month: 0, input: 6000,
cache_read_input_tokens_this_month: 0, input_cache_creation: 0,
output_tokens_this_month: 0, input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO, spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO, lifetime_spending: Cents::ZERO,
} }
@ -108,18 +156,34 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 0, requests_this_minute: 0,
tokens_this_minute: 0, tokens_this_minute: 0,
tokens_this_day: 5000, tokens_this_day: 5000,
input_tokens_this_month: 6000, tokens_this_month: TokenUsage {
cache_creation_input_tokens_this_month: 0, input: 6000,
cache_read_input_tokens_this_month: 0, input_cache_creation: 0,
output_tokens_this_month: 0, input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO, spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO, lifetime_spending: Cents::ZERO,
} }
); );
db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now) db.record_usage(
.await user_id,
.unwrap(); false,
provider,
model,
TokenUsage {
input: 4000,
input_cache_creation: 0,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!( assert_eq!(
@ -128,10 +192,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1, requests_this_minute: 1,
tokens_this_minute: 4000, tokens_this_minute: 4000,
tokens_this_day: 9000, tokens_this_day: 9000,
input_tokens_this_month: 10000, tokens_this_month: TokenUsage {
cache_creation_input_tokens_this_month: 0, input: 10000,
cache_read_input_tokens_this_month: 0, input_cache_creation: 0,
output_tokens_this_month: 0, input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO, spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO, lifetime_spending: Cents::ZERO,
} }
@ -143,9 +209,23 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
.with_timezone(&Utc); .with_timezone(&Utc);
// Test cache creation input tokens // Test cache creation input tokens
db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now) db.record_usage(
.await user_id,
.unwrap(); false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 500,
input_cache_read: 0,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!( assert_eq!(
@ -154,19 +234,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 1, requests_this_minute: 1,
tokens_this_minute: 1500, tokens_this_minute: 1500,
tokens_this_day: 1500, tokens_this_day: 1500,
input_tokens_this_month: 1000, tokens_this_month: TokenUsage {
cache_creation_input_tokens_this_month: 500, input: 1000,
cache_read_input_tokens_this_month: 0, input_cache_creation: 500,
output_tokens_this_month: 0, input_cache_read: 0,
output: 0,
},
spending_this_month: Cents::ZERO, spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO, lifetime_spending: Cents::ZERO,
} }
); );
// Test cache read input tokens // Test cache read input tokens
db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now) db.record_usage(
.await user_id,
.unwrap(); false,
provider,
model,
TokenUsage {
input: 1000,
input_cache_creation: 0,
input_cache_read: 300,
output: 0,
},
false,
Cents::ZERO,
now,
)
.await
.unwrap();
let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap();
assert_eq!( assert_eq!(
@ -175,10 +271,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
requests_this_minute: 2, requests_this_minute: 2,
tokens_this_minute: 2800, tokens_this_minute: 2800,
tokens_this_day: 2800, tokens_this_day: 2800,
input_tokens_this_month: 2000, tokens_this_month: TokenUsage {
cache_creation_input_tokens_this_month: 500, input: 2000,
cache_read_input_tokens_this_month: 300, input_cache_creation: 500,
output_tokens_this_month: 0, input_cache_read: 300,
output: 0,
},
spending_this_month: Cents::ZERO, spending_this_month: Cents::ZERO,
lifetime_spending: Cents::ZERO, lifetime_spending: Cents::ZERO,
} }

View file

@ -157,7 +157,7 @@ async fn main() -> Result<()> {
if let Some(mut llm_db) = llm_db { if let Some(mut llm_db) = llm_db {
llm_db.initialize().await?; llm_db.initialize().await?;
sync_llm_usage_with_stripe_periodically(state.clone(), llm_db); sync_llm_usage_with_stripe_periodically(state.clone());
} }
app = app app = app

View file

@ -1218,6 +1218,15 @@ impl Server {
Ok(()) Ok(())
} }
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
self.peer
.send(connection_id, proto::RefreshLlmToken {})
.trace_err();
}
}
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> { pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
ServerSnapshot { ServerSnapshot {
connection_pool: ConnectionPoolGuard { connection_pool: ConnectionPoolGuard {

View file

@ -0,0 +1,427 @@
use std::sync::Arc;
use crate::{llm, Cents, Result};
use anyhow::Context;
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
pub struct StripeBilling {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
client: Arc<stripe::Client>,
}
pub struct StripeModel {
input_tokens_price: StripeBillingPrice,
input_cache_creation_tokens_price: StripeBillingPrice,
input_cache_read_tokens_price: StripeBillingPrice,
output_tokens_price: StripeBillingPrice,
}
struct StripeBillingPrice {
id: stripe::PriceId,
meter_event_name: String,
}
impl StripeBilling {
pub async fn new(client: Arc<stripe::Client>) -> Result<Self> {
let mut meters_by_event_name = HashMap::default();
for meter in StripeMeter::list(&client).await?.data {
meters_by_event_name.insert(meter.event_name.clone(), meter);
}
let mut price_ids_by_meter_id = HashMap::default();
for price in stripe::Price::list(&client, &stripe::ListPrices::default())
.await?
.data
{
if let Some(recurring) = price.recurring {
if let Some(meter) = recurring.meter {
price_ids_by_meter_id.insert(meter, price.id);
}
}
}
Ok(Self {
meters_by_event_name,
price_ids_by_meter_id,
client,
})
}
pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> {
let input_tokens_price = self
.get_or_insert_price(
&format!("model_{}/input_tokens", model.id),
&format!("{} (Input Tokens)", model.name),
Cents::new(model.price_per_million_input_tokens as u32),
)
.await?;
let input_cache_creation_tokens_price = self
.get_or_insert_price(
&format!("model_{}/input_cache_creation_tokens", model.id),
&format!("{} (Input Cache Creation Tokens)", model.name),
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
)
.await?;
let input_cache_read_tokens_price = self
.get_or_insert_price(
&format!("model_{}/input_cache_read_tokens", model.id),
&format!("{} (Input Cache Read Tokens)", model.name),
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
)
.await?;
let output_tokens_price = self
.get_or_insert_price(
&format!("model_{}/output_tokens", model.id),
&format!("{} (Output Tokens)", model.name),
Cents::new(model.price_per_million_output_tokens as u32),
)
.await?;
Ok(StripeModel {
input_tokens_price,
input_cache_creation_tokens_price,
input_cache_read_tokens_price,
output_tokens_price,
})
}
async fn get_or_insert_price(
&mut self,
meter_event_name: &str,
price_description: &str,
price_per_million_tokens: Cents,
) -> Result<StripeBillingPrice> {
let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) {
meter.clone()
} else {
let meter = StripeMeter::create(
&self.client,
StripeCreateMeterParams {
default_aggregation: DefaultAggregation { formula: "sum" },
display_name: price_description.to_string(),
event_name: meter_event_name,
},
)
.await?;
self.meters_by_event_name
.insert(meter_event_name.to_string(), meter.clone());
meter
};
let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) {
price_id.clone()
} else {
let price = stripe::Price::create(
&self.client,
stripe::CreatePrice {
active: Some(true),
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
currency: stripe::Currency::USD,
currency_options: None,
custom_unit_amount: None,
expand: &[],
lookup_key: None,
metadata: None,
nickname: None,
product: None,
product_data: Some(stripe::CreatePriceProductData {
id: None,
active: Some(true),
metadata: None,
name: price_description.to_string(),
statement_descriptor: None,
tax_code: None,
unit_label: None,
}),
recurring: Some(stripe::CreatePriceRecurring {
aggregate_usage: None,
interval: stripe::CreatePriceRecurringInterval::Month,
interval_count: None,
trial_period_days: None,
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
meter: Some(meter.id.clone()),
}),
tax_behavior: None,
tiers: None,
tiers_mode: None,
transfer_lookup_key: None,
transform_quantity: None,
unit_amount: None,
unit_amount_decimal: Some(&format!(
"{:.12}",
price_per_million_tokens.0 as f64 / 1_000_000f64
)),
},
)
.await?;
self.price_ids_by_meter_id
.insert(meter.id, price.id.clone());
price.id
};
Ok(StripeBillingPrice {
id: price_id,
meter_event_name: meter_event_name.to_string(),
})
}
pub async fn subscribe_to_model(
&self,
subscription_id: &stripe::SubscriptionId,
model: &StripeModel,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
let mut items = Vec::new();
if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
{
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_cache_creation_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.input_cache_read_tokens_price.id.to_string()),
..Default::default()
});
}
if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
items.push(stripe::UpdateSubscriptionItems {
price: Some(model.output_tokens_price.id.to_string()),
..Default::default()
});
}
if !items.is_empty() {
items.extend(subscription.items.data.iter().map(|item| {
stripe::UpdateSubscriptionItems {
id: Some(item.id.to_string()),
..Default::default()
}
}));
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(items),
..Default::default()
},
)
.await?;
}
Ok(())
}
pub async fn bill_model_usage(
&self,
customer_id: &stripe::CustomerId,
model: &StripeModel,
event: &llm::db::billing_event::Model,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
if event.input_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_tokens/{}", event.idempotency_key),
event_name: &model.input_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.input_cache_creation_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
event_name: &model.input_cache_creation_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_cache_creation_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.input_cache_read_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
event_name: &model.input_cache_read_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.input_cache_read_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
if event.output_tokens > 0 {
StripeMeterEvent::create(
&self.client,
StripeCreateMeterEventParams {
identifier: &format!("output_tokens/{}", event.idempotency_key),
event_name: &model.output_tokens_price.meter_event_name,
payload: StripeCreateMeterEventPayload {
value: event.output_tokens as u64,
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
},
)
.await?;
}
Ok(())
}
pub async fn checkout(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
model: &StripeModel,
success_url: &str,
) -> Result<String> {
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(
[
&model.input_tokens_price.id,
&model.input_cache_creation_tokens_price.id,
&model.input_cache_read_tokens_price.id,
&model.output_tokens_price.id,
]
.into_iter()
.map(|price_id| stripe::CreateCheckoutSessionLineItems {
price: Some(price_id.to_string()),
..Default::default()
})
.collect(),
);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
}
#[derive(Serialize)]
struct DefaultAggregation {
formula: &'static str,
}
#[derive(Serialize)]
struct StripeCreateMeterParams<'a> {
default_aggregation: DefaultAggregation,
display_name: String,
event_name: &'a str,
}
#[derive(Clone, Deserialize)]
struct StripeMeter {
id: String,
event_name: String,
}
impl StripeMeter {
pub fn create(
client: &stripe::Client,
params: StripeCreateMeterParams,
) -> stripe::Response<Self> {
client.post_form("/billing/meters", params)
}
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
#[derive(Serialize)]
struct Params {}
client.get_query("/billing/meters", Params {})
}
}
#[derive(Deserialize)]
struct StripeMeterEvent {
identifier: String,
}
impl StripeMeterEvent {
pub async fn create(
client: &stripe::Client,
params: StripeCreateMeterEventParams<'_>,
) -> Result<Self, stripe::StripeError> {
let identifier = params.identifier;
match client.post_form("/billing/meter_events", params).await {
Ok(event) => Ok(event),
Err(stripe::StripeError::Stripe(error)) => {
if error.http_status == 400
&& error
.message
.as_ref()
.map_or(false, |message| message.contains(identifier))
{
Ok(Self {
identifier: identifier.to_string(),
})
} else {
Err(stripe::StripeError::Stripe(error))
}
}
Err(error) => Err(error),
}
}
}
#[derive(Serialize)]
struct StripeCreateMeterEventParams<'a> {
identifier: &'a str,
event_name: &'a str,
payload: StripeCreateMeterEventPayload<'a>,
timestamp: Option<i64>,
}
#[derive(Serialize)]
struct StripeCreateMeterEventPayload<'a> {
value: u64,
stripe_customer_id: &'a stripe::CustomerId,
}
fn subscription_contains_price(
subscription: &stripe::Subscription,
price_id: &stripe::PriceId,
) -> bool {
subscription.items.data.iter().any(|item| {
item.price
.as_ref()
.map_or(false, |price| price.id == *price_id)
})
}

View file

@ -635,6 +635,7 @@ impl TestServer {
) -> Arc<AppState> { ) -> Arc<AppState> {
Arc::new(AppState { Arc::new(AppState {
db: test_db.db().clone(), db: test_db.db().clone(),
llm_db: None,
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())), live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
blob_store_client: None, blob_store_client: None,
stripe_client: None, stripe_client: None,
@ -677,8 +678,6 @@ impl TestServer {
migrations_path: None, migrations_path: None,
seed_path: None, seed_path: None,
stripe_api_key: None, stripe_api_key: None,
stripe_llm_access_price_id: None,
stripe_llm_usage_price_id: None,
supermaven_admin_api_key: None, supermaven_admin_api_key: None,
user_backfiller_github_access_token: None, user_backfiller_github_access_token: None,
}, },