diff --git a/Cargo.lock b/Cargo.lock index 0a607155e0..5f6218cb65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -839,9 +839,8 @@ dependencies = [ [[package]] name = "async-stripe" -version = "0.39.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58d670cf4d47a1b8ffef54286a5625382e360a34ee76902fd93ad8c7032a0c30" +version = "0.40.0" +source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735" dependencies = [ "chrono", "futures-util", diff --git a/Cargo.toml b/Cargo.toml index fa4289b9a2..42f935e66b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -480,7 +480,8 @@ which = "6.0.0" wit-component = "0.201" [workspace.dependencies.async-stripe] -version = "0.39" +git = "https://github.com/zed-industries/async-stripe" +rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735" default-features = false features = [ "runtime-tokio-hyper-rustls", diff --git a/crates/collab/migrations_llm/20241010151249_create_billing_events.sql b/crates/collab/migrations_llm/20241010151249_create_billing_events.sql new file mode 100644 index 0000000000..74a270872e --- /dev/null +++ b/crates/collab/migrations_llm/20241010151249_create_billing_events.sql @@ -0,0 +1,12 @@ +create table billing_events ( + id serial primary key, + idempotency_key uuid not null default gen_random_uuid(), + user_id integer not null, + model_id integer not null references models (id) on delete cascade, + input_tokens bigint not null default 0, + input_cache_creation_tokens bigint not null default 0, + input_cache_read_tokens bigint not null default 0, + output_tokens bigint not null default 0 +); + +create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id); diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 838dea1981..c0fc33a643 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1,7 +1,3 @@ -use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; - use anyhow::{anyhow, bail, Context}; use axum::{ extract::{self, Query}, @@ -9,28 +5,35 @@ use axum::{ Extension, Json, Router, }; use chrono::{DateTime, SecondsFormat, Utc}; +use collections::HashSet; use reqwest::StatusCode; use sea_orm::ActiveValue; use serde::{Deserialize, Serialize}; +use std::{str::FromStr, sync::Arc, time::Duration}; use stripe::{ - BillingPortalSession, CheckoutSession, CreateBillingPortalSession, - CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, + BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData, + CreateBillingPortalSessionFlowDataAfterCompletion, CreateBillingPortalSessionFlowDataAfterCompletionRedirect, - CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems, - CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents, - Subscription, SubscriptionId, SubscriptionStatus, + CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject, + EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus, }; use util::ResultExt; -use crate::db::billing_subscription::{self, StripeSubscriptionStatus}; -use crate::db::{ - billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, - CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams, - UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, -}; -use crate::llm::db::LlmDatabase; -use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT}; +use crate::llm::DEFAULT_MAX_MONTHLY_SPEND; use crate::rpc::ResultExt as _; +use crate::{ + db::{ + billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, + CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, + UpdateBillingCustomerParams, UpdateBillingPreferencesParams, + UpdateBillingSubscriptionParams, + }, + stripe_billing::StripeBilling, +}; +use crate::{ + db::{billing_subscription::StripeSubscriptionStatus, UserId}, + llm::db::LlmDatabase, +}; use crate::{AppState, Error, Result}; pub fn router() -> Router { @@ -87,6 +90,7 @@ struct UpdateBillingPreferencesBody { async fn update_billing_preferences( Extension(app): Extension>, + Extension(rpc_server): Extension>, extract::Json(body): extract::Json, ) -> Result> { let user = app @@ -119,6 +123,8 @@ async fn update_billing_preferences( .await? }; + rpc_server.refresh_llm_tokens_for_user(user.id).await; + Ok(Json(BillingPreferencesResponse { max_monthly_llm_usage_spending_in_cents: billing_preferences .max_monthly_llm_usage_spending_in_cents, @@ -197,12 +203,15 @@ async fn create_billing_subscription( .await? .ok_or_else(|| anyhow!("user not found"))?; - let Some((stripe_client, stripe_access_price_id)) = app - .stripe_client - .clone() - .zip(app.config.stripe_llm_access_price_id.clone()) - else { - log::error!("failed to retrieve Stripe client or price ID"); + let Some(stripe_client) = app.stripe_client.clone() else { + log::error!("failed to retrieve Stripe client"); + Err(Error::http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + let Some(llm_db) = app.llm_db.clone() else { + log::error!("failed to retrieve LLM database"); Err(Error::http( StatusCode::NOT_IMPLEMENTED, "not supported".into(), @@ -226,26 +235,15 @@ async fn create_billing_subscription( customer.id }; - let checkout_session = { - let mut params = CreateCheckoutSession::new(); - params.mode = Some(stripe::CheckoutSessionMode::Subscription); - params.customer = Some(customer_id); - params.client_reference_id = Some(user.github_login.as_str()); - params.line_items = Some(vec![CreateCheckoutSessionLineItems { - price: Some(stripe_access_price_id.to_string()), - quantity: Some(1), - ..Default::default() - }]); - let success_url = format!("{}/account", app.config.zed_dot_dev_url()); - params.success_url = Some(&success_url); - - CheckoutSession::create(&stripe_client, params).await? - }; - + let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?; + let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?; + let stripe_model = stripe_billing.register_model(default_model).await?; + let success_url = format!("{}/account", app.config.zed_dot_dev_url()); + let checkout_session_url = stripe_billing + .checkout(customer_id, &user.github_login, &stripe_model, &success_url) + .await?; Ok(Json(CreateBillingSubscriptionResponse { - checkout_session_url: checkout_session - .url - .ok_or_else(|| anyhow!("no checkout session URL"))?, + checkout_session_url, })) } @@ -715,15 +713,15 @@ async fn find_or_create_billing_customer( Ok(Some(billing_customer)) } -const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60); +const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60); -pub fn sync_llm_usage_with_stripe_periodically(app: Arc, llm_db: LlmDatabase) { +pub fn sync_llm_usage_with_stripe_periodically(app: Arc) { let Some(stripe_client) = app.stripe_client.clone() else { log::warn!("failed to retrieve Stripe client"); return; }; - let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else { - log::warn!("failed to retrieve Stripe LLM usage price ID"); + let Some(llm_db) = app.llm_db.clone() else { + log::warn!("failed to retrieve LLM database"); return; }; @@ -732,15 +730,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc, llm_db: LlmDa let executor = executor.clone(); async move { loop { - sync_with_stripe( - &app, - &llm_db, - &stripe_client, - stripe_llm_usage_price_id.clone(), - ) - .await - .trace_err(); - + sync_with_stripe(&app, &llm_db, &stripe_client) + .await + .trace_err(); executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await; } } @@ -749,71 +741,46 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc, llm_db: LlmDa async fn sync_with_stripe( app: &Arc, - llm_db: &LlmDatabase, - stripe_client: &stripe::Client, - stripe_llm_usage_price_id: Arc, + llm_db: &Arc, + stripe_client: &Arc, ) -> anyhow::Result<()> { - let subscriptions = app.db.get_active_billing_subscriptions().await?; + let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?; - for (customer, subscription) in subscriptions { - update_stripe_subscription( - llm_db, - stripe_client, - &stripe_llm_usage_price_id, - customer, - subscription, - ) - .await - .log_err(); + let events = llm_db.get_billing_events().await?; + let user_ids = events + .iter() + .map(|(event, _)| event.user_id) + .collect::>(); + let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?; + + for (event, model) in events { + let Some((stripe_db_customer, stripe_db_subscription)) = + stripe_subscriptions.get(&event.user_id) + else { + tracing::warn!( + user_id = event.user_id.0, + "Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side." + ); + continue; + }; + let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription + .stripe_subscription_id + .parse() + .context("failed to parse stripe subscription id from db")?; + let stripe_customer_id: stripe::CustomerId = stripe_db_customer + .stripe_customer_id + .parse() + .context("failed to parse stripe customer id from db")?; + + let stripe_model = stripe_billing.register_model(&model).await?; + stripe_billing + .subscribe_to_model(&stripe_subscription_id, &stripe_model) + .await?; + stripe_billing + .bill_model_usage(&stripe_customer_id, &stripe_model, &event) + .await?; + llm_db.consume_billing_event(event.id).await?; } Ok(()) } - -async fn update_stripe_subscription( - llm_db: &LlmDatabase, - stripe_client: &stripe::Client, - stripe_llm_usage_price_id: &Arc, - customer: billing_customer::Model, - subscription: billing_subscription::Model, -) -> Result<(), anyhow::Error> { - let monthly_spending = llm_db - .get_user_spending_for_month(customer.user_id, Utc::now()) - .await?; - let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id) - .context("failed to parse subscription ID")?; - - let monthly_spending_over_free_tier = - monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT); - - let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil(); - let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?; - - let mut update_params = stripe::UpdateSubscription { - proration_behavior: Some( - stripe::generated::billing::subscription::SubscriptionProrationBehavior::None, - ), - ..Default::default() - }; - - if let Some(existing_item) = current_subscription.items.data.iter().find(|item| { - item.price.as_ref().map_or(false, |price| { - price.id == stripe_llm_usage_price_id.as_ref() - }) - }) { - update_params.items = Some(vec![stripe::UpdateSubscriptionItems { - id: Some(existing_item.id.to_string()), - quantity: Some(new_quantity as u64), - ..Default::default() - }]); - } else { - update_params.items = Some(vec![stripe::UpdateSubscriptionItems { - price: Some(stripe_llm_usage_price_id.to_string()), - quantity: Some(new_quantity as u64), - ..Default::default() - }]); - } - - Subscription::update(stripe_client, &subscription_id, update_params).await?; - Ok(()) -} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index bcf093bebd..53a17f9c53 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -114,23 +114,31 @@ impl Database { pub async fn get_active_billing_subscriptions( &self, - ) -> Result> { - self.transaction(|tx| async move { - let mut result = Vec::new(); - let mut rows = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .select_also(billing_customer::Entity) - .order_by_asc(billing_subscription::Column::Id) - .stream(&*tx) - .await?; + user_ids: HashSet, + ) -> Result> { + self.transaction(|tx| { + let user_ids = user_ids.clone(); + async move { + let mut rows = billing_subscription::Entity::find() + .inner_join(billing_customer::Entity) + .select_also(billing_customer::Entity) + .filter(billing_customer::Column::UserId.is_in(user_ids)) + .filter( + billing_subscription::Column::StripeSubscriptionStatus + .eq(StripeSubscriptionStatus::Active), + ) + .order_by_asc(billing_subscription::Column::Id) + .stream(&*tx) + .await?; - while let Some(row) = rows.next().await { - if let (subscription, Some(customer)) = row? { - result.push((customer, subscription)); + let mut subscriptions = HashMap::default(); + while let Some(row) = rows.next().await { + if let (subscription, Some(customer)) = row? { + subscriptions.insert(customer.user_id, (customer, subscription)); + } } + Ok(subscriptions) } - - Ok(result) }) .await } diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 3896926f43..0cc50f68f3 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -10,6 +10,7 @@ pub mod migrations; mod rate_limiter; pub mod rpc; pub mod seed; +pub mod stripe_billing; pub mod user_backfiller; #[cfg(test)] @@ -24,6 +25,7 @@ use axum::{ pub use cents::*; use db::{ChannelId, Database}; use executor::Executor; +use llm::db::LlmDatabase; pub use rate_limiter::*; use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; @@ -176,8 +178,6 @@ pub struct Config { pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, pub stripe_api_key: Option, - pub stripe_llm_access_price_id: Option>, - pub stripe_llm_usage_price_id: Option>, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -197,7 +197,7 @@ impl Config { } pub fn is_llm_billing_enabled(&self) -> bool { - self.stripe_llm_usage_price_id.is_some() + self.stripe_api_key.is_some() } #[cfg(test)] @@ -238,8 +238,6 @@ impl Config { migrations_path: None, seed_path: None, stripe_api_key: None, - stripe_llm_access_price_id: None, - stripe_llm_usage_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, } @@ -272,6 +270,7 @@ impl ServiceMode { pub struct AppState { pub db: Arc, + pub llm_db: Option>, pub live_kit_client: Option>, pub blob_store_client: Option, pub stripe_client: Option>, @@ -288,6 +287,20 @@ impl AppState { let mut db = Database::new(db_options, Executor::Production).await?; db.initialize_notification_kinds().await?; + let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config + .llm_database_url + .clone() + .zip(config.llm_database_max_connections) + { + let mut llm_db_options = db::ConnectOptions::new(llm_database_url); + llm_db_options.max_connections(llm_database_max_connections); + let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?; + llm_db.initialize().await?; + Some(Arc::new(llm_db)) + } else { + None + }; + let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() @@ -306,9 +319,10 @@ impl AppState { let db = Arc::new(db); let this = Self { db: db.clone(), + llm_db, live_kit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), - stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(), + stripe_client: build_stripe_client(&config).map(Arc::new).log_err(), rate_limiter: Arc::new(RateLimiter::new(db)), executor, clickhouse_client: config @@ -321,12 +335,11 @@ impl AppState { } } -async fn build_stripe_client(config: &Config) -> anyhow::Result { +fn build_stripe_client(config: &Config) -> anyhow::Result { let api_key = config .stripe_api_key .as_ref() .ok_or_else(|| anyhow!("missing stripe_api_key"))?; - Ok(stripe::Client::new(api_key)) } diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index c389a4de62..70a309367c 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -20,13 +20,14 @@ use axum::{ }; use chrono::{DateTime, Duration, Utc}; use collections::HashMap; +use db::TokenUsage; use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase}; use futures::{Stream, StreamExt as _}; use isahc_http_client::IsahcHttpClient; -use rpc::ListModelsResponse; use rpc::{ proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, }; +use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME}; use std::{ pin::Pin, sync::Arc, @@ -418,10 +419,7 @@ async fn perform_completion( claims, provider: params.provider, model, - input_tokens: 0, - output_tokens: 0, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, + tokens: TokenUsage::default(), inner_stream: stream, }))) } @@ -476,6 +474,19 @@ async fn check_usage_limit( "Maximum spending limit reached for this month.".to_string(), )); } + + if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) { + return Err(Error::Http( + StatusCode::FORBIDDEN, + "Maximum spending limit reached for this month.".to_string(), + [( + HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME), + HeaderValue::from_static("true"), + )] + .into_iter() + .collect(), + )); + } } } @@ -598,10 +609,7 @@ struct TokenCountingStream { claims: LlmTokenClaims, provider: LanguageModelProvider, model: String, - input_tokens: usize, - output_tokens: usize, - cache_creation_input_tokens: usize, - cache_read_input_tokens: usize, + tokens: TokenUsage, inner_stream: S, } @@ -615,10 +623,10 @@ where match Pin::new(&mut self.inner_stream).poll_next(cx) { Poll::Ready(Some(Ok(mut chunk))) => { chunk.bytes.push(b'\n'); - self.input_tokens += chunk.input_tokens; - self.output_tokens += chunk.output_tokens; - self.cache_creation_input_tokens += chunk.cache_creation_input_tokens; - self.cache_read_input_tokens += chunk.cache_read_input_tokens; + self.tokens.input += chunk.input_tokens; + self.tokens.output += chunk.output_tokens; + self.tokens.input_cache_creation += chunk.cache_creation_input_tokens; + self.tokens.input_cache_read += chunk.cache_read_input_tokens; Poll::Ready(Some(Ok(chunk.bytes))) } Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), @@ -634,10 +642,7 @@ impl Drop for TokenCountingStream { let claims = self.claims.clone(); let provider = self.provider; let model = std::mem::take(&mut self.model); - let input_token_count = self.input_tokens; - let output_token_count = self.output_tokens; - let cache_creation_input_token_count = self.cache_creation_input_tokens; - let cache_read_input_token_count = self.cache_read_input_tokens; + let tokens = self.tokens; self.state.executor.spawn_detached(async move { let usage = state .db @@ -646,10 +651,9 @@ impl Drop for TokenCountingStream { claims.is_staff, provider, &model, - input_token_count, - cache_creation_input_token_count, - cache_read_input_token_count, - output_token_count, + tokens, + claims.has_llm_subscription, + Cents(claims.max_monthly_spend_in_cents), Utc::now(), ) .await @@ -679,22 +683,23 @@ impl Drop for TokenCountingStream { }, model, provider: provider.to_string(), - input_token_count: input_token_count as u64, - cache_creation_input_token_count: cache_creation_input_token_count - as u64, - cache_read_input_token_count: cache_read_input_token_count as u64, - output_token_count: output_token_count as u64, + input_token_count: tokens.input as u64, + cache_creation_input_token_count: tokens.input_cache_creation as u64, + cache_read_input_token_count: tokens.input_cache_read as u64, + output_token_count: tokens.output as u64, requests_this_minute: usage.requests_this_minute as u64, tokens_this_minute: usage.tokens_this_minute as u64, tokens_this_day: usage.tokens_this_day as u64, - input_tokens_this_month: usage.input_tokens_this_month as u64, + input_tokens_this_month: usage.tokens_this_month.input as u64, cache_creation_input_tokens_this_month: usage - .cache_creation_input_tokens_this_month + .tokens_this_month + .input_cache_creation as u64, cache_read_input_tokens_this_month: usage - .cache_read_input_tokens_this_month + .tokens_this_month + .input_cache_read as u64, - output_tokens_this_month: usage.output_tokens_this_month as u64, + output_tokens_this_month: usage.tokens_this_month.output as u64, spending_this_month: usage.spending_this_month.0 as u64, lifetime_spending: usage.lifetime_spending.0 as u64, }, diff --git a/crates/collab/src/llm/db.rs b/crates/collab/src/llm/db.rs index 996837116b..4374214c1b 100644 --- a/crates/collab/src/llm/db.rs +++ b/crates/collab/src/llm/db.rs @@ -20,7 +20,7 @@ use std::future::Future; use std::sync::Arc; use anyhow::anyhow; -pub use queries::usages::ActiveUserCount; +pub use queries::usages::{ActiveUserCount, TokenUsage}; use sea_orm::prelude::*; pub use sea_orm::ConnectOptions; use sea_orm::{ diff --git a/crates/collab/src/llm/db/ids.rs b/crates/collab/src/llm/db/ids.rs index 8cc8a0f974..80da32fbf4 100644 --- a/crates/collab/src/llm/db/ids.rs +++ b/crates/collab/src/llm/db/ids.rs @@ -8,3 +8,4 @@ id_type!(ProviderId); id_type!(UsageId); id_type!(UsageMeasureId); id_type!(RevokedAccessTokenId); +id_type!(BillingEventId); diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 907d0589f3..79a17999b7 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -1,5 +1,6 @@ use super::*; +pub mod billing_events; pub mod providers; pub mod revoked_access_tokens; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/billing_events.rs b/crates/collab/src/llm/db/queries/billing_events.rs new file mode 100644 index 0000000000..400477f234 --- /dev/null +++ b/crates/collab/src/llm/db/queries/billing_events.rs @@ -0,0 +1,31 @@ +use super::*; +use crate::Result; +use anyhow::Context as _; + +impl LlmDatabase { + pub async fn get_billing_events(&self) -> Result> { + self.transaction(|tx| async move { + let events_with_models = billing_event::Entity::find() + .find_also_related(model::Entity) + .all(&*tx) + .await?; + events_with_models + .into_iter() + .map(|(event, model)| { + let model = + model.context("could not find model associated with billing event")?; + Ok((event, model)) + }) + .collect() + }) + .await + } + + pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> { + self.transaction(|tx| async move { + billing_event::Entity::delete_by_id(id).exec(&*tx).await?; + Ok(()) + }) + .await + } +} diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 3d6ab18415..c7f8fd163a 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,5 +1,5 @@ -use crate::db::UserId; use crate::llm::Cents; +use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT}; use chrono::{Datelike, Duration}; use futures::StreamExt as _; use rpc::LanguageModelProvider; @@ -9,15 +9,26 @@ use strum::IntoEnumIterator as _; use super::*; +#[derive(Debug, PartialEq, Clone, Copy, Default)] +pub struct TokenUsage { + pub input: usize, + pub input_cache_creation: usize, + pub input_cache_read: usize, + pub output: usize, +} + +impl TokenUsage { + pub fn total(&self) -> usize { + self.input + self.input_cache_creation + self.input_cache_read + self.output + } +} + #[derive(Debug, PartialEq, Clone, Copy)] pub struct Usage { pub requests_this_minute: usize, pub tokens_this_minute: usize, pub tokens_this_day: usize, - pub input_tokens_this_month: usize, - pub cache_creation_input_tokens_this_month: usize, - pub cache_read_input_tokens_this_month: usize, - pub output_tokens_this_month: usize, + pub tokens_this_month: TokenUsage, pub spending_this_month: Cents, pub lifetime_spending: Cents, } @@ -257,18 +268,20 @@ impl LlmDatabase { requests_this_minute, tokens_this_minute, tokens_this_day, - input_tokens_this_month: monthly_usage - .as_ref() - .map_or(0, |usage| usage.input_tokens as usize), - cache_creation_input_tokens_this_month: monthly_usage - .as_ref() - .map_or(0, |usage| usage.cache_creation_input_tokens as usize), - cache_read_input_tokens_this_month: monthly_usage - .as_ref() - .map_or(0, |usage| usage.cache_read_input_tokens as usize), - output_tokens_this_month: monthly_usage - .as_ref() - .map_or(0, |usage| usage.output_tokens as usize), + tokens_this_month: TokenUsage { + input: monthly_usage + .as_ref() + .map_or(0, |usage| usage.input_tokens as usize), + input_cache_creation: monthly_usage + .as_ref() + .map_or(0, |usage| usage.cache_creation_input_tokens as usize), + input_cache_read: monthly_usage + .as_ref() + .map_or(0, |usage| usage.cache_read_input_tokens as usize), + output: monthly_usage + .as_ref() + .map_or(0, |usage| usage.output_tokens as usize), + }, spending_this_month, lifetime_spending, }) @@ -283,10 +296,9 @@ impl LlmDatabase { is_staff: bool, provider: LanguageModelProvider, model_name: &str, - input_token_count: usize, - cache_creation_input_tokens: usize, - cache_read_input_tokens: usize, - output_token_count: usize, + tokens: TokenUsage, + has_llm_subscription: bool, + max_monthly_spend: Cents, now: DateTimeUtc, ) -> Result { self.transaction(|tx| async move { @@ -313,10 +325,6 @@ impl LlmDatabase { &tx, ) .await?; - let total_token_count = input_token_count - + cache_read_input_tokens - + cache_creation_input_tokens - + output_token_count; let tokens_this_minute = self .update_usage_for_measure( user_id, @@ -325,7 +333,7 @@ impl LlmDatabase { &usages, UsageMeasure::TokensPerMinute, now, - total_token_count, + tokens.total(), &tx, ) .await?; @@ -337,7 +345,7 @@ impl LlmDatabase { &usages, UsageMeasure::TokensPerDay, now, - total_token_count, + tokens.total(), &tx, ) .await?; @@ -361,18 +369,14 @@ impl LlmDatabase { Some(usage) => { monthly_usage::Entity::update(monthly_usage::ActiveModel { id: ActiveValue::unchanged(usage.id), - input_tokens: ActiveValue::set( - usage.input_tokens + input_token_count as i64, - ), + input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64), cache_creation_input_tokens: ActiveValue::set( - usage.cache_creation_input_tokens + cache_creation_input_tokens as i64, + usage.cache_creation_input_tokens + tokens.input_cache_creation as i64, ), cache_read_input_tokens: ActiveValue::set( - usage.cache_read_input_tokens + cache_read_input_tokens as i64, - ), - output_tokens: ActiveValue::set( - usage.output_tokens + output_token_count as i64, + usage.cache_read_input_tokens + tokens.input_cache_read as i64, ), + output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64), ..Default::default() }) .exec(&*tx) @@ -384,12 +388,12 @@ impl LlmDatabase { model_id: ActiveValue::set(model.id), month: ActiveValue::set(month), year: ActiveValue::set(year), - input_tokens: ActiveValue::set(input_token_count as i64), + input_tokens: ActiveValue::set(tokens.input as i64), cache_creation_input_tokens: ActiveValue::set( - cache_creation_input_tokens as i64, + tokens.input_cache_creation as i64, ), - cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64), - output_tokens: ActiveValue::set(output_token_count as i64), + cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64), + output_tokens: ActiveValue::set(tokens.output as i64), ..Default::default() } .insert(&*tx) @@ -405,6 +409,26 @@ impl LlmDatabase { monthly_usage.output_tokens as usize, ); + if spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT + && has_llm_subscription + && spending_this_month <= max_monthly_spend + { + billing_event::ActiveModel { + id: ActiveValue::not_set(), + idempotency_key: ActiveValue::not_set(), + user_id: ActiveValue::set(user_id), + model_id: ActiveValue::set(model.id), + input_tokens: ActiveValue::set(tokens.input as i64), + input_cache_creation_tokens: ActiveValue::set( + tokens.input_cache_creation as i64, + ), + input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64), + output_tokens: ActiveValue::set(tokens.output as i64), + } + .insert(&*tx) + .await?; + } + // Update lifetime usage let lifetime_usage = lifetime_usage::Entity::find() .filter( @@ -419,18 +443,14 @@ impl LlmDatabase { Some(usage) => { lifetime_usage::Entity::update(lifetime_usage::ActiveModel { id: ActiveValue::unchanged(usage.id), - input_tokens: ActiveValue::set( - usage.input_tokens + input_token_count as i64, - ), + input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64), cache_creation_input_tokens: ActiveValue::set( - usage.cache_creation_input_tokens + cache_creation_input_tokens as i64, + usage.cache_creation_input_tokens + tokens.input_cache_creation as i64, ), cache_read_input_tokens: ActiveValue::set( - usage.cache_read_input_tokens + cache_read_input_tokens as i64, - ), - output_tokens: ActiveValue::set( - usage.output_tokens + output_token_count as i64, + usage.cache_read_input_tokens + tokens.input_cache_read as i64, ), + output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64), ..Default::default() }) .exec(&*tx) @@ -440,12 +460,12 @@ impl LlmDatabase { lifetime_usage::ActiveModel { user_id: ActiveValue::set(user_id), model_id: ActiveValue::set(model.id), - input_tokens: ActiveValue::set(input_token_count as i64), + input_tokens: ActiveValue::set(tokens.input as i64), cache_creation_input_tokens: ActiveValue::set( - cache_creation_input_tokens as i64, + tokens.input_cache_creation as i64, ), - cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64), - output_tokens: ActiveValue::set(output_token_count as i64), + cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64), + output_tokens: ActiveValue::set(tokens.output as i64), ..Default::default() } .insert(&*tx) @@ -465,11 +485,12 @@ impl LlmDatabase { requests_this_minute, tokens_this_minute, tokens_this_day, - input_tokens_this_month: monthly_usage.input_tokens as usize, - cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens - as usize, - cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize, - output_tokens_this_month: monthly_usage.output_tokens as usize, + tokens_this_month: TokenUsage { + input: monthly_usage.input_tokens as usize, + input_cache_creation: monthly_usage.cache_creation_input_tokens as usize, + input_cache_read: monthly_usage.cache_read_input_tokens as usize, + output: monthly_usage.output_tokens as usize, + }, spending_this_month, lifetime_spending, }) diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index 57aded70e9..407c5c8fd0 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -1,3 +1,4 @@ +pub mod billing_event; pub mod lifetime_usage; pub mod model; pub mod monthly_usage; diff --git a/crates/collab/src/llm/db/tables/billing_event.rs b/crates/collab/src/llm/db/tables/billing_event.rs new file mode 100644 index 0000000000..93987bc71e --- /dev/null +++ b/crates/collab/src/llm/db/tables/billing_event.rs @@ -0,0 +1,37 @@ +use crate::{ + db::UserId, + llm::db::{BillingEventId, ModelId}, +}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "billing_events")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: BillingEventId, + pub idempotency_key: Uuid, + pub user_id: UserId, + pub model_id: ModelId, + pub input_tokens: i64, + pub input_cache_creation_tokens: i64, + pub input_cache_read_tokens: i64, + pub output_tokens: i64, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::model::Entity", + from = "Column::ModelId", + to = "super::model::Column::Id" + )] + Model, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Model.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/model.rs b/crates/collab/src/llm/db/tables/model.rs index 4d7d2d8da9..6c52184185 100644 --- a/crates/collab/src/llm/db/tables/model.rs +++ b/crates/collab/src/llm/db/tables/model.rs @@ -29,6 +29,8 @@ pub enum Relation { Provider, #[sea_orm(has_many = "super::usage::Entity")] Usages, + #[sea_orm(has_many = "super::billing_event::Entity")] + BillingEvents, } impl Related for Entity { @@ -43,4 +45,10 @@ impl Related for Entity { } } +impl Related for Entity { + fn to() -> RelationDef { + Relation::BillingEvents.def() + } +} + impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs index 2730a03046..8e96ac4f54 100644 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ b/crates/collab/src/llm/db/tests/usage_tests.rs @@ -2,7 +2,7 @@ use crate::{ db::UserId, llm::db::{ queries::{providers::ModelParams, usages::Usage}, - LlmDatabase, + LlmDatabase, TokenUsage, }, test_llm_db, Cents, }; @@ -36,14 +36,42 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { let user_id = UserId::from_proto(123); let now = t0; - db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now) - .await - .unwrap(); + db.record_usage( + user_id, + false, + provider, + model, + TokenUsage { + input: 1000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, + false, + Cents::ZERO, + now, + ) + .await + .unwrap(); let now = t0 + Duration::seconds(10); - db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now) - .await - .unwrap(); + db.record_usage( + user_id, + false, + provider, + model, + TokenUsage { + input: 2000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, + false, + Cents::ZERO, + now, + ) + .await + .unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); assert_eq!( @@ -52,10 +80,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 2, tokens_this_minute: 3000, tokens_this_day: 3000, - input_tokens_this_month: 3000, - cache_creation_input_tokens_this_month: 0, - cache_read_input_tokens_this_month: 0, - output_tokens_this_month: 0, + tokens_this_month: TokenUsage { + input: 3000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, spending_this_month: Cents::ZERO, lifetime_spending: Cents::ZERO, } @@ -69,19 +99,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 1, tokens_this_minute: 2000, tokens_this_day: 3000, - input_tokens_this_month: 3000, - cache_creation_input_tokens_this_month: 0, - cache_read_input_tokens_this_month: 0, - output_tokens_this_month: 0, + tokens_this_month: TokenUsage { + input: 3000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, spending_this_month: Cents::ZERO, lifetime_spending: Cents::ZERO, } ); let now = t0 + Duration::seconds(60); - db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now) - .await - .unwrap(); + db.record_usage( + user_id, + false, + provider, + model, + TokenUsage { + input: 3000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, + false, + Cents::ZERO, + now, + ) + .await + .unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); assert_eq!( @@ -90,10 +136,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 2, tokens_this_minute: 5000, tokens_this_day: 6000, - input_tokens_this_month: 6000, - cache_creation_input_tokens_this_month: 0, - cache_read_input_tokens_this_month: 0, - output_tokens_this_month: 0, + tokens_this_month: TokenUsage { + input: 6000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, spending_this_month: Cents::ZERO, lifetime_spending: Cents::ZERO, } @@ -108,18 +156,34 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 0, tokens_this_minute: 0, tokens_this_day: 5000, - input_tokens_this_month: 6000, - cache_creation_input_tokens_this_month: 0, - cache_read_input_tokens_this_month: 0, - output_tokens_this_month: 0, + tokens_this_month: TokenUsage { + input: 6000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, spending_this_month: Cents::ZERO, lifetime_spending: Cents::ZERO, } ); - db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now) - .await - .unwrap(); + db.record_usage( + user_id, + false, + provider, + model, + TokenUsage { + input: 4000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, + false, + Cents::ZERO, + now, + ) + .await + .unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); assert_eq!( @@ -128,10 +192,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 1, tokens_this_minute: 4000, tokens_this_day: 9000, - input_tokens_this_month: 10000, - cache_creation_input_tokens_this_month: 0, - cache_read_input_tokens_this_month: 0, - output_tokens_this_month: 0, + tokens_this_month: TokenUsage { + input: 10000, + input_cache_creation: 0, + input_cache_read: 0, + output: 0, + }, spending_this_month: Cents::ZERO, lifetime_spending: Cents::ZERO, } @@ -143,9 +209,23 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { .with_timezone(&Utc); // Test cache creation input tokens - db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now) - .await - .unwrap(); + db.record_usage( + user_id, + false, + provider, + model, + TokenUsage { + input: 1000, + input_cache_creation: 500, + input_cache_read: 0, + output: 0, + }, + false, + Cents::ZERO, + now, + ) + .await + .unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); assert_eq!( @@ -154,19 +234,35 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 1, tokens_this_minute: 1500, tokens_this_day: 1500, - input_tokens_this_month: 1000, - cache_creation_input_tokens_this_month: 500, - cache_read_input_tokens_this_month: 0, - output_tokens_this_month: 0, + tokens_this_month: TokenUsage { + input: 1000, + input_cache_creation: 500, + input_cache_read: 0, + output: 0, + }, spending_this_month: Cents::ZERO, lifetime_spending: Cents::ZERO, } ); // Test cache read input tokens - db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now) - .await - .unwrap(); + db.record_usage( + user_id, + false, + provider, + model, + TokenUsage { + input: 1000, + input_cache_creation: 0, + input_cache_read: 300, + output: 0, + }, + false, + Cents::ZERO, + now, + ) + .await + .unwrap(); let usage = db.get_usage(user_id, provider, model, now).await.unwrap(); assert_eq!( @@ -175,10 +271,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { requests_this_minute: 2, tokens_this_minute: 2800, tokens_this_day: 2800, - input_tokens_this_month: 2000, - cache_creation_input_tokens_this_month: 500, - cache_read_input_tokens_this_month: 300, - output_tokens_this_month: 0, + tokens_this_month: TokenUsage { + input: 2000, + input_cache_creation: 500, + input_cache_read: 300, + output: 0, + }, spending_this_month: Cents::ZERO, lifetime_spending: Cents::ZERO, } diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index bd227f17c7..02c0baf9de 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -157,7 +157,7 @@ async fn main() -> Result<()> { if let Some(mut llm_db) = llm_db { llm_db.initialize().await?; - sync_llm_usage_with_stripe_periodically(state.clone(), llm_db); + sync_llm_usage_with_stripe_periodically(state.clone()); } app = app diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 2b861dd403..e5dafe80d8 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1218,6 +1218,15 @@ impl Server { Ok(()) } + pub async fn refresh_llm_tokens_for_user(self: &Arc, user_id: UserId) { + let pool = self.connection_pool.lock(); + for connection_id in pool.user_connection_ids(user_id) { + self.peer + .send(connection_id, proto::RefreshLlmToken {}) + .trace_err(); + } + } + pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { ServerSnapshot { connection_pool: ConnectionPoolGuard { diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs new file mode 100644 index 0000000000..878d160666 --- /dev/null +++ b/crates/collab/src/stripe_billing.rs @@ -0,0 +1,427 @@ +use std::sync::Arc; + +use crate::{llm, Cents, Result}; +use anyhow::Context; +use chrono::Utc; +use collections::HashMap; +use serde::{Deserialize, Serialize}; + +pub struct StripeBilling { + meters_by_event_name: HashMap, + price_ids_by_meter_id: HashMap, + client: Arc, +} + +pub struct StripeModel { + input_tokens_price: StripeBillingPrice, + input_cache_creation_tokens_price: StripeBillingPrice, + input_cache_read_tokens_price: StripeBillingPrice, + output_tokens_price: StripeBillingPrice, +} + +struct StripeBillingPrice { + id: stripe::PriceId, + meter_event_name: String, +} + +impl StripeBilling { + pub async fn new(client: Arc) -> Result { + let mut meters_by_event_name = HashMap::default(); + for meter in StripeMeter::list(&client).await?.data { + meters_by_event_name.insert(meter.event_name.clone(), meter); + } + + let mut price_ids_by_meter_id = HashMap::default(); + for price in stripe::Price::list(&client, &stripe::ListPrices::default()) + .await? + .data + { + if let Some(recurring) = price.recurring { + if let Some(meter) = recurring.meter { + price_ids_by_meter_id.insert(meter, price.id); + } + } + } + + Ok(Self { + meters_by_event_name, + price_ids_by_meter_id, + client, + }) + } + + pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result { + let input_tokens_price = self + .get_or_insert_price( + &format!("model_{}/input_tokens", model.id), + &format!("{} (Input Tokens)", model.name), + Cents::new(model.price_per_million_input_tokens as u32), + ) + .await?; + let input_cache_creation_tokens_price = self + .get_or_insert_price( + &format!("model_{}/input_cache_creation_tokens", model.id), + &format!("{} (Input Cache Creation Tokens)", model.name), + Cents::new(model.price_per_million_cache_creation_input_tokens as u32), + ) + .await?; + let input_cache_read_tokens_price = self + .get_or_insert_price( + &format!("model_{}/input_cache_read_tokens", model.id), + &format!("{} (Input Cache Read Tokens)", model.name), + Cents::new(model.price_per_million_cache_read_input_tokens as u32), + ) + .await?; + let output_tokens_price = self + .get_or_insert_price( + &format!("model_{}/output_tokens", model.id), + &format!("{} (Output Tokens)", model.name), + Cents::new(model.price_per_million_output_tokens as u32), + ) + .await?; + Ok(StripeModel { + input_tokens_price, + input_cache_creation_tokens_price, + input_cache_read_tokens_price, + output_tokens_price, + }) + } + + async fn get_or_insert_price( + &mut self, + meter_event_name: &str, + price_description: &str, + price_per_million_tokens: Cents, + ) -> Result { + let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) { + meter.clone() + } else { + let meter = StripeMeter::create( + &self.client, + StripeCreateMeterParams { + default_aggregation: DefaultAggregation { formula: "sum" }, + display_name: price_description.to_string(), + event_name: meter_event_name, + }, + ) + .await?; + self.meters_by_event_name + .insert(meter_event_name.to_string(), meter.clone()); + meter + }; + + let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) { + price_id.clone() + } else { + let price = stripe::Price::create( + &self.client, + stripe::CreatePrice { + active: Some(true), + billing_scheme: Some(stripe::PriceBillingScheme::PerUnit), + currency: stripe::Currency::USD, + currency_options: None, + custom_unit_amount: None, + expand: &[], + lookup_key: None, + metadata: None, + nickname: None, + product: None, + product_data: Some(stripe::CreatePriceProductData { + id: None, + active: Some(true), + metadata: None, + name: price_description.to_string(), + statement_descriptor: None, + tax_code: None, + unit_label: None, + }), + recurring: Some(stripe::CreatePriceRecurring { + aggregate_usage: None, + interval: stripe::CreatePriceRecurringInterval::Month, + interval_count: None, + trial_period_days: None, + usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered), + meter: Some(meter.id.clone()), + }), + tax_behavior: None, + tiers: None, + tiers_mode: None, + transfer_lookup_key: None, + transform_quantity: None, + unit_amount: None, + unit_amount_decimal: Some(&format!( + "{:.12}", + price_per_million_tokens.0 as f64 / 1_000_000f64 + )), + }, + ) + .await?; + self.price_ids_by_meter_id + .insert(meter.id, price.id.clone()); + price.id + }; + + Ok(StripeBillingPrice { + id: price_id, + meter_event_name: meter_event_name.to_string(), + }) + } + + pub async fn subscribe_to_model( + &self, + subscription_id: &stripe::SubscriptionId, + model: &StripeModel, + ) -> Result<()> { + let subscription = + stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?; + + let mut items = Vec::new(); + + if !subscription_contains_price(&subscription, &model.input_tokens_price.id) { + items.push(stripe::UpdateSubscriptionItems { + price: Some(model.input_tokens_price.id.to_string()), + ..Default::default() + }); + } + + if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id) + { + items.push(stripe::UpdateSubscriptionItems { + price: Some(model.input_cache_creation_tokens_price.id.to_string()), + ..Default::default() + }); + } + + if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) { + items.push(stripe::UpdateSubscriptionItems { + price: Some(model.input_cache_read_tokens_price.id.to_string()), + ..Default::default() + }); + } + + if !subscription_contains_price(&subscription, &model.output_tokens_price.id) { + items.push(stripe::UpdateSubscriptionItems { + price: Some(model.output_tokens_price.id.to_string()), + ..Default::default() + }); + } + + if !items.is_empty() { + items.extend(subscription.items.data.iter().map(|item| { + stripe::UpdateSubscriptionItems { + id: Some(item.id.to_string()), + ..Default::default() + } + })); + + stripe::Subscription::update( + &self.client, + subscription_id, + stripe::UpdateSubscription { + items: Some(items), + ..Default::default() + }, + ) + .await?; + } + + Ok(()) + } + + pub async fn bill_model_usage( + &self, + customer_id: &stripe::CustomerId, + model: &StripeModel, + event: &llm::db::billing_event::Model, + ) -> Result<()> { + let timestamp = Utc::now().timestamp(); + + if event.input_tokens > 0 { + StripeMeterEvent::create( + &self.client, + StripeCreateMeterEventParams { + identifier: &format!("input_tokens/{}", event.idempotency_key), + event_name: &model.input_tokens_price.meter_event_name, + payload: StripeCreateMeterEventPayload { + value: event.input_tokens as u64, + stripe_customer_id: customer_id, + }, + timestamp: Some(timestamp), + }, + ) + .await?; + } + + if event.input_cache_creation_tokens > 0 { + StripeMeterEvent::create( + &self.client, + StripeCreateMeterEventParams { + identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key), + event_name: &model.input_cache_creation_tokens_price.meter_event_name, + payload: StripeCreateMeterEventPayload { + value: event.input_cache_creation_tokens as u64, + stripe_customer_id: customer_id, + }, + timestamp: Some(timestamp), + }, + ) + .await?; + } + + if event.input_cache_read_tokens > 0 { + StripeMeterEvent::create( + &self.client, + StripeCreateMeterEventParams { + identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key), + event_name: &model.input_cache_read_tokens_price.meter_event_name, + payload: StripeCreateMeterEventPayload { + value: event.input_cache_read_tokens as u64, + stripe_customer_id: customer_id, + }, + timestamp: Some(timestamp), + }, + ) + .await?; + } + + if event.output_tokens > 0 { + StripeMeterEvent::create( + &self.client, + StripeCreateMeterEventParams { + identifier: &format!("output_tokens/{}", event.idempotency_key), + event_name: &model.output_tokens_price.meter_event_name, + payload: StripeCreateMeterEventPayload { + value: event.output_tokens as u64, + stripe_customer_id: customer_id, + }, + timestamp: Some(timestamp), + }, + ) + .await?; + } + + Ok(()) + } + + pub async fn checkout( + &self, + customer_id: stripe::CustomerId, + github_login: &str, + model: &StripeModel, + success_url: &str, + ) -> Result { + let mut params = stripe::CreateCheckoutSession::new(); + params.mode = Some(stripe::CheckoutSessionMode::Subscription); + params.customer = Some(customer_id); + params.client_reference_id = Some(github_login); + params.line_items = Some( + [ + &model.input_tokens_price.id, + &model.input_cache_creation_tokens_price.id, + &model.input_cache_read_tokens_price.id, + &model.output_tokens_price.id, + ] + .into_iter() + .map(|price_id| stripe::CreateCheckoutSessionLineItems { + price: Some(price_id.to_string()), + ..Default::default() + }) + .collect(), + ); + params.success_url = Some(success_url); + + let session = stripe::CheckoutSession::create(&self.client, params).await?; + Ok(session.url.context("no checkout session URL")?) + } +} + +#[derive(Serialize)] +struct DefaultAggregation { + formula: &'static str, +} + +#[derive(Serialize)] +struct StripeCreateMeterParams<'a> { + default_aggregation: DefaultAggregation, + display_name: String, + event_name: &'a str, +} + +#[derive(Clone, Deserialize)] +struct StripeMeter { + id: String, + event_name: String, +} + +impl StripeMeter { + pub fn create( + client: &stripe::Client, + params: StripeCreateMeterParams, + ) -> stripe::Response { + client.post_form("/billing/meters", params) + } + + pub fn list(client: &stripe::Client) -> stripe::Response> { + #[derive(Serialize)] + struct Params {} + + client.get_query("/billing/meters", Params {}) + } +} + +#[derive(Deserialize)] +struct StripeMeterEvent { + identifier: String, +} + +impl StripeMeterEvent { + pub async fn create( + client: &stripe::Client, + params: StripeCreateMeterEventParams<'_>, + ) -> Result { + let identifier = params.identifier; + match client.post_form("/billing/meter_events", params).await { + Ok(event) => Ok(event), + Err(stripe::StripeError::Stripe(error)) => { + if error.http_status == 400 + && error + .message + .as_ref() + .map_or(false, |message| message.contains(identifier)) + { + Ok(Self { + identifier: identifier.to_string(), + }) + } else { + Err(stripe::StripeError::Stripe(error)) + } + } + Err(error) => Err(error), + } + } +} + +#[derive(Serialize)] +struct StripeCreateMeterEventParams<'a> { + identifier: &'a str, + event_name: &'a str, + payload: StripeCreateMeterEventPayload<'a>, + timestamp: Option, +} + +#[derive(Serialize)] +struct StripeCreateMeterEventPayload<'a> { + value: u64, + stripe_customer_id: &'a stripe::CustomerId, +} + +fn subscription_contains_price( + subscription: &stripe::Subscription, + price_id: &stripe::PriceId, +) -> bool { + subscription.items.data.iter().any(|item| { + item.price + .as_ref() + .map_or(false, |price| price.id == *price_id) + }) +} diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 683a53a2f5..484940c527 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -635,6 +635,7 @@ impl TestServer { ) -> Arc { Arc::new(AppState { db: test_db.db().clone(), + llm_db: None, live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())), blob_store_client: None, stripe_client: None, @@ -677,8 +678,6 @@ impl TestServer { migrations_path: None, seed_path: None, stripe_api_key: None, - stripe_llm_access_price_id: None, - stripe_llm_usage_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, },