collab: Add usage-based billing for LLM interactions (#19081)
This PR adds usage-based billing for LLM interactions in the Assistant.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
This commit is contained in:
parent f1c45d988e
commit 22ea7cef7a
20 changed files with 918 additions and 280 deletions
5
Cargo.lock
generated
5
Cargo.lock
generated
|
@ -839,9 +839,8 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-stripe"
|
name = "async-stripe"
|
||||||
version = "0.39.1"
|
version = "0.40.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||||
checksum = "58d670cf4d47a1b8ffef54286a5625382e360a34ee76902fd93ad8c7032a0c30"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"chrono",
|
"chrono",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
|
|
@ -480,7 +480,8 @@ which = "6.0.0"
|
||||||
wit-component = "0.201"
|
wit-component = "0.201"
|
||||||
|
|
||||||
[workspace.dependencies.async-stripe]
|
[workspace.dependencies.async-stripe]
|
||||||
version = "0.39"
|
git = "https://github.com/zed-industries/async-stripe"
|
||||||
|
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = [
|
features = [
|
||||||
"runtime-tokio-hyper-rustls",
|
"runtime-tokio-hyper-rustls",
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
create table billing_events (
|
||||||
|
id serial primary key,
|
||||||
|
idempotency_key uuid not null default gen_random_uuid(),
|
||||||
|
user_id integer not null,
|
||||||
|
model_id integer not null references models (id) on delete cascade,
|
||||||
|
input_tokens bigint not null default 0,
|
||||||
|
input_cache_creation_tokens bigint not null default 0,
|
||||||
|
input_cache_read_tokens bigint not null default 0,
|
||||||
|
output_tokens bigint not null default 0
|
||||||
|
);
|
||||||
|
|
||||||
|
create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);
|
|
@ -1,7 +1,3 @@
|
||||||
use std::str::FromStr;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, bail, Context};
|
use anyhow::{anyhow, bail, Context};
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{self, Query},
|
extract::{self, Query},
|
||||||
|
@ -9,28 +5,35 @@ use axum::{
|
||||||
Extension, Json, Router,
|
Extension, Json, Router,
|
||||||
};
|
};
|
||||||
use chrono::{DateTime, SecondsFormat, Utc};
|
use chrono::{DateTime, SecondsFormat, Utc};
|
||||||
|
use collections::HashSet;
|
||||||
use reqwest::StatusCode;
|
use reqwest::StatusCode;
|
||||||
use sea_orm::ActiveValue;
|
use sea_orm::ActiveValue;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{str::FromStr, sync::Arc, time::Duration};
|
||||||
use stripe::{
|
use stripe::{
|
||||||
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
|
BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
|
||||||
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
|
CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||||
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
|
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
||||||
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
|
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
|
||||||
Subscription, SubscriptionId, SubscriptionStatus,
|
|
||||||
};
|
};
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
|
|
||||||
use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
|
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
|
||||||
use crate::db::{
|
|
||||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
|
||||||
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
|
|
||||||
UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
|
|
||||||
};
|
|
||||||
use crate::llm::db::LlmDatabase;
|
|
||||||
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
|
||||||
use crate::rpc::ResultExt as _;
|
use crate::rpc::ResultExt as _;
|
||||||
|
use crate::{
|
||||||
|
db::{
|
||||||
|
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||||
|
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
|
||||||
|
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
|
||||||
|
UpdateBillingSubscriptionParams,
|
||||||
|
},
|
||||||
|
stripe_billing::StripeBilling,
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
db::{billing_subscription::StripeSubscriptionStatus, UserId},
|
||||||
|
llm::db::LlmDatabase,
|
||||||
|
};
|
||||||
use crate::{AppState, Error, Result};
|
use crate::{AppState, Error, Result};
|
||||||
|
|
||||||
pub fn router() -> Router {
|
pub fn router() -> Router {
|
||||||
|
@ -87,6 +90,7 @@ struct UpdateBillingPreferencesBody {
|
||||||
|
|
||||||
async fn update_billing_preferences(
|
async fn update_billing_preferences(
|
||||||
Extension(app): Extension<Arc<AppState>>,
|
Extension(app): Extension<Arc<AppState>>,
|
||||||
|
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
|
||||||
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
|
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
|
||||||
) -> Result<Json<BillingPreferencesResponse>> {
|
) -> Result<Json<BillingPreferencesResponse>> {
|
||||||
let user = app
|
let user = app
|
||||||
|
@ -119,6 +123,8 @@ async fn update_billing_preferences(
|
||||||
.await?
|
.await?
|
||||||
};
|
};
|
||||||
|
|
||||||
|
rpc_server.refresh_llm_tokens_for_user(user.id).await;
|
||||||
|
|
||||||
Ok(Json(BillingPreferencesResponse {
|
Ok(Json(BillingPreferencesResponse {
|
||||||
max_monthly_llm_usage_spending_in_cents: billing_preferences
|
max_monthly_llm_usage_spending_in_cents: billing_preferences
|
||||||
.max_monthly_llm_usage_spending_in_cents,
|
.max_monthly_llm_usage_spending_in_cents,
|
||||||
|
@ -197,12 +203,15 @@ async fn create_billing_subscription(
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| anyhow!("user not found"))?;
|
.ok_or_else(|| anyhow!("user not found"))?;
|
||||||
|
|
||||||
let Some((stripe_client, stripe_access_price_id)) = app
|
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||||
.stripe_client
|
log::error!("failed to retrieve Stripe client");
|
||||||
.clone()
|
Err(Error::http(
|
||||||
.zip(app.config.stripe_llm_access_price_id.clone())
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
else {
|
"not supported".into(),
|
||||||
log::error!("failed to retrieve Stripe client or price ID");
|
))?
|
||||||
|
};
|
||||||
|
let Some(llm_db) = app.llm_db.clone() else {
|
||||||
|
log::error!("failed to retrieve LLM database");
|
||||||
Err(Error::http(
|
Err(Error::http(
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
"not supported".into(),
|
"not supported".into(),
|
||||||
|
@ -226,26 +235,15 @@ async fn create_billing_subscription(
|
||||||
customer.id
|
customer.id
|
||||||
};
|
};
|
||||||
|
|
||||||
let checkout_session = {
|
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
|
||||||
let mut params = CreateCheckoutSession::new();
|
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
|
||||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
let stripe_model = stripe_billing.register_model(default_model).await?;
|
||||||
params.customer = Some(customer_id);
|
|
||||||
params.client_reference_id = Some(user.github_login.as_str());
|
|
||||||
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
|
|
||||||
price: Some(stripe_access_price_id.to_string()),
|
|
||||||
quantity: Some(1),
|
|
||||||
..Default::default()
|
|
||||||
}]);
|
|
||||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||||
params.success_url = Some(&success_url);
|
let checkout_session_url = stripe_billing
|
||||||
|
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
|
||||||
CheckoutSession::create(&stripe_client, params).await?
|
.await?;
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Json(CreateBillingSubscriptionResponse {
|
Ok(Json(CreateBillingSubscriptionResponse {
|
||||||
checkout_session_url: checkout_session
|
checkout_session_url,
|
||||||
.url
|
|
||||||
.ok_or_else(|| anyhow!("no checkout session URL"))?,
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -715,15 +713,15 @@ async fn find_or_create_billing_customer(
|
||||||
Ok(Some(billing_customer))
|
Ok(Some(billing_customer))
|
||||||
}
|
}
|
||||||
|
|
||||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
|
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||||
|
|
||||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
|
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||||
log::warn!("failed to retrieve Stripe client");
|
log::warn!("failed to retrieve Stripe client");
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
|
let Some(llm_db) = app.llm_db.clone() else {
|
||||||
log::warn!("failed to retrieve Stripe LLM usage price ID");
|
log::warn!("failed to retrieve LLM database");
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -732,15 +730,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
|
||||||
let executor = executor.clone();
|
let executor = executor.clone();
|
||||||
async move {
|
async move {
|
||||||
loop {
|
loop {
|
||||||
sync_with_stripe(
|
sync_with_stripe(&app, &llm_db, &stripe_client)
|
||||||
&app,
|
|
||||||
&llm_db,
|
|
||||||
&stripe_client,
|
|
||||||
stripe_llm_usage_price_id.clone(),
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.trace_err();
|
.trace_err();
|
||||||
|
|
||||||
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
|
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -749,71 +741,46 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
|
||||||
|
|
||||||
async fn sync_with_stripe(
|
async fn sync_with_stripe(
|
||||||
app: &Arc<AppState>,
|
app: &Arc<AppState>,
|
||||||
llm_db: &LlmDatabase,
|
llm_db: &Arc<LlmDatabase>,
|
||||||
stripe_client: &stripe::Client,
|
stripe_client: &Arc<stripe::Client>,
|
||||||
stripe_llm_usage_price_id: Arc<str>,
|
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let subscriptions = app.db.get_active_billing_subscriptions().await?;
|
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
|
||||||
|
|
||||||
for (customer, subscription) in subscriptions {
|
let events = llm_db.get_billing_events().await?;
|
||||||
update_stripe_subscription(
|
let user_ids = events
|
||||||
llm_db,
|
.iter()
|
||||||
stripe_client,
|
.map(|(event, _)| event.user_id)
|
||||||
&stripe_llm_usage_price_id,
|
.collect::<HashSet<UserId>>();
|
||||||
customer,
|
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
|
||||||
subscription,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.log_err();
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
for (event, model) in events {
|
||||||
}
|
let Some((stripe_db_customer, stripe_db_subscription)) =
|
||||||
|
stripe_subscriptions.get(&event.user_id)
|
||||||
async fn update_stripe_subscription(
|
else {
|
||||||
llm_db: &LlmDatabase,
|
tracing::warn!(
|
||||||
stripe_client: &stripe::Client,
|
user_id = event.user_id.0,
|
||||||
stripe_llm_usage_price_id: &Arc<str>,
|
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
|
||||||
customer: billing_customer::Model,
|
);
|
||||||
subscription: billing_subscription::Model,
|
continue;
|
||||||
) -> Result<(), anyhow::Error> {
|
|
||||||
let monthly_spending = llm_db
|
|
||||||
.get_user_spending_for_month(customer.user_id, Utc::now())
|
|
||||||
.await?;
|
|
||||||
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
|
|
||||||
.context("failed to parse subscription ID")?;
|
|
||||||
|
|
||||||
let monthly_spending_over_free_tier =
|
|
||||||
monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);
|
|
||||||
|
|
||||||
let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
|
|
||||||
let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;
|
|
||||||
|
|
||||||
let mut update_params = stripe::UpdateSubscription {
|
|
||||||
proration_behavior: Some(
|
|
||||||
stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
|
|
||||||
),
|
|
||||||
..Default::default()
|
|
||||||
};
|
};
|
||||||
|
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
|
||||||
|
.stripe_subscription_id
|
||||||
|
.parse()
|
||||||
|
.context("failed to parse stripe subscription id from db")?;
|
||||||
|
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
|
||||||
|
.stripe_customer_id
|
||||||
|
.parse()
|
||||||
|
.context("failed to parse stripe customer id from db")?;
|
||||||
|
|
||||||
if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
|
let stripe_model = stripe_billing.register_model(&model).await?;
|
||||||
item.price.as_ref().map_or(false, |price| {
|
stripe_billing
|
||||||
price.id == stripe_llm_usage_price_id.as_ref()
|
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
|
||||||
})
|
.await?;
|
||||||
}) {
|
stripe_billing
|
||||||
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
|
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
|
||||||
id: Some(existing_item.id.to_string()),
|
.await?;
|
||||||
quantity: Some(new_quantity as u64),
|
llm_db.consume_billing_event(event.id).await?;
|
||||||
..Default::default()
|
|
||||||
}]);
|
|
||||||
} else {
|
|
||||||
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
|
|
||||||
price: Some(stripe_llm_usage_price_id.to_string()),
|
|
||||||
quantity: Some(new_quantity as u64),
|
|
||||||
..Default::default()
|
|
||||||
}]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Subscription::update(stripe_client, &subscription_id, update_params).await?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,23 +114,31 @@ impl Database {
|
||||||
|
|
||||||
pub async fn get_active_billing_subscriptions(
|
pub async fn get_active_billing_subscriptions(
|
||||||
&self,
|
&self,
|
||||||
) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
|
user_ids: HashSet<UserId>,
|
||||||
self.transaction(|tx| async move {
|
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||||
let mut result = Vec::new();
|
self.transaction(|tx| {
|
||||||
|
let user_ids = user_ids.clone();
|
||||||
|
async move {
|
||||||
let mut rows = billing_subscription::Entity::find()
|
let mut rows = billing_subscription::Entity::find()
|
||||||
.inner_join(billing_customer::Entity)
|
.inner_join(billing_customer::Entity)
|
||||||
.select_also(billing_customer::Entity)
|
.select_also(billing_customer::Entity)
|
||||||
|
.filter(billing_customer::Column::UserId.is_in(user_ids))
|
||||||
|
.filter(
|
||||||
|
billing_subscription::Column::StripeSubscriptionStatus
|
||||||
|
.eq(StripeSubscriptionStatus::Active),
|
||||||
|
)
|
||||||
.order_by_asc(billing_subscription::Column::Id)
|
.order_by_asc(billing_subscription::Column::Id)
|
||||||
.stream(&*tx)
|
.stream(&*tx)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let mut subscriptions = HashMap::default();
|
||||||
while let Some(row) = rows.next().await {
|
while let Some(row) = rows.next().await {
|
||||||
if let (subscription, Some(customer)) = row? {
|
if let (subscription, Some(customer)) = row? {
|
||||||
result.push((customer, subscription));
|
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Ok(subscriptions)
|
||||||
Ok(result)
|
}
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ pub mod migrations;
|
||||||
mod rate_limiter;
|
mod rate_limiter;
|
||||||
pub mod rpc;
|
pub mod rpc;
|
||||||
pub mod seed;
|
pub mod seed;
|
||||||
|
pub mod stripe_billing;
|
||||||
pub mod user_backfiller;
|
pub mod user_backfiller;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -24,6 +25,7 @@ use axum::{
|
||||||
pub use cents::*;
|
pub use cents::*;
|
||||||
use db::{ChannelId, Database};
|
use db::{ChannelId, Database};
|
||||||
use executor::Executor;
|
use executor::Executor;
|
||||||
|
use llm::db::LlmDatabase;
|
||||||
pub use rate_limiter::*;
|
pub use rate_limiter::*;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::{path::PathBuf, sync::Arc};
|
use std::{path::PathBuf, sync::Arc};
|
||||||
|
@ -176,8 +178,6 @@ pub struct Config {
|
||||||
pub slack_panics_webhook: Option<String>,
|
pub slack_panics_webhook: Option<String>,
|
||||||
pub auto_join_channel_id: Option<ChannelId>,
|
pub auto_join_channel_id: Option<ChannelId>,
|
||||||
pub stripe_api_key: Option<String>,
|
pub stripe_api_key: Option<String>,
|
||||||
pub stripe_llm_access_price_id: Option<Arc<str>>,
|
|
||||||
pub stripe_llm_usage_price_id: Option<Arc<str>>,
|
|
||||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||||
}
|
}
|
||||||
|
@ -197,7 +197,7 @@ impl Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_llm_billing_enabled(&self) -> bool {
|
pub fn is_llm_billing_enabled(&self) -> bool {
|
||||||
self.stripe_llm_usage_price_id.is_some()
|
self.stripe_api_key.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -238,8 +238,6 @@ impl Config {
|
||||||
migrations_path: None,
|
migrations_path: None,
|
||||||
seed_path: None,
|
seed_path: None,
|
||||||
stripe_api_key: None,
|
stripe_api_key: None,
|
||||||
stripe_llm_access_price_id: None,
|
|
||||||
stripe_llm_usage_price_id: None,
|
|
||||||
supermaven_admin_api_key: None,
|
supermaven_admin_api_key: None,
|
||||||
user_backfiller_github_access_token: None,
|
user_backfiller_github_access_token: None,
|
||||||
}
|
}
|
||||||
|
@ -272,6 +270,7 @@ impl ServiceMode {
|
||||||
|
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub db: Arc<Database>,
|
pub db: Arc<Database>,
|
||||||
|
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||||
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
pub stripe_client: Option<Arc<stripe::Client>>,
|
||||||
|
@ -288,6 +287,20 @@ impl AppState {
|
||||||
let mut db = Database::new(db_options, Executor::Production).await?;
|
let mut db = Database::new(db_options, Executor::Production).await?;
|
||||||
db.initialize_notification_kinds().await?;
|
db.initialize_notification_kinds().await?;
|
||||||
|
|
||||||
|
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
|
||||||
|
.llm_database_url
|
||||||
|
.clone()
|
||||||
|
.zip(config.llm_database_max_connections)
|
||||||
|
{
|
||||||
|
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
|
||||||
|
llm_db_options.max_connections(llm_database_max_connections);
|
||||||
|
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
|
||||||
|
llm_db.initialize().await?;
|
||||||
|
Some(Arc::new(llm_db))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let live_kit_client = if let Some(((server, key), secret)) = config
|
let live_kit_client = if let Some(((server, key), secret)) = config
|
||||||
.live_kit_server
|
.live_kit_server
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
@ -306,9 +319,10 @@ impl AppState {
|
||||||
let db = Arc::new(db);
|
let db = Arc::new(db);
|
||||||
let this = Self {
|
let this = Self {
|
||||||
db: db.clone(),
|
db: db.clone(),
|
||||||
|
llm_db,
|
||||||
live_kit_client,
|
live_kit_client,
|
||||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||||
stripe_client: build_stripe_client(&config).await.map(Arc::new).log_err(),
|
stripe_client: build_stripe_client(&config).map(Arc::new).log_err(),
|
||||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||||
executor,
|
executor,
|
||||||
clickhouse_client: config
|
clickhouse_client: config
|
||||||
|
@ -321,12 +335,11 @@ impl AppState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||||
let api_key = config
|
let api_key = config
|
||||||
.stripe_api_key
|
.stripe_api_key
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or_else(|| anyhow!("missing stripe_api_key"))?;
|
.ok_or_else(|| anyhow!("missing stripe_api_key"))?;
|
||||||
|
|
||||||
Ok(stripe::Client::new(api_key))
|
Ok(stripe::Client::new(api_key))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,13 +20,14 @@ use axum::{
|
||||||
};
|
};
|
||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
|
use db::TokenUsage;
|
||||||
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
|
||||||
use futures::{Stream, StreamExt as _};
|
use futures::{Stream, StreamExt as _};
|
||||||
use isahc_http_client::IsahcHttpClient;
|
use isahc_http_client::IsahcHttpClient;
|
||||||
use rpc::ListModelsResponse;
|
|
||||||
use rpc::{
|
use rpc::{
|
||||||
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||||
};
|
};
|
||||||
|
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
|
||||||
use std::{
|
use std::{
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
|
@ -418,10 +419,7 @@ async fn perform_completion(
|
||||||
claims,
|
claims,
|
||||||
provider: params.provider,
|
provider: params.provider,
|
||||||
model,
|
model,
|
||||||
input_tokens: 0,
|
tokens: TokenUsage::default(),
|
||||||
output_tokens: 0,
|
|
||||||
cache_creation_input_tokens: 0,
|
|
||||||
cache_read_input_tokens: 0,
|
|
||||||
inner_stream: stream,
|
inner_stream: stream,
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
@ -476,6 +474,19 @@ async fn check_usage_limit(
|
||||||
"Maximum spending limit reached for this month.".to_string(),
|
"Maximum spending limit reached for this month.".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if usage.spending_this_month >= Cents(claims.max_monthly_spend_in_cents) {
|
||||||
|
return Err(Error::Http(
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
"Maximum spending limit reached for this month.".to_string(),
|
||||||
|
[(
|
||||||
|
HeaderName::from_static(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME),
|
||||||
|
HeaderValue::from_static("true"),
|
||||||
|
)]
|
||||||
|
.into_iter()
|
||||||
|
.collect(),
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -598,10 +609,7 @@ struct TokenCountingStream<S> {
|
||||||
claims: LlmTokenClaims,
|
claims: LlmTokenClaims,
|
||||||
provider: LanguageModelProvider,
|
provider: LanguageModelProvider,
|
||||||
model: String,
|
model: String,
|
||||||
input_tokens: usize,
|
tokens: TokenUsage,
|
||||||
output_tokens: usize,
|
|
||||||
cache_creation_input_tokens: usize,
|
|
||||||
cache_read_input_tokens: usize,
|
|
||||||
inner_stream: S,
|
inner_stream: S,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -615,10 +623,10 @@ where
|
||||||
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
match Pin::new(&mut self.inner_stream).poll_next(cx) {
|
||||||
Poll::Ready(Some(Ok(mut chunk))) => {
|
Poll::Ready(Some(Ok(mut chunk))) => {
|
||||||
chunk.bytes.push(b'\n');
|
chunk.bytes.push(b'\n');
|
||||||
self.input_tokens += chunk.input_tokens;
|
self.tokens.input += chunk.input_tokens;
|
||||||
self.output_tokens += chunk.output_tokens;
|
self.tokens.output += chunk.output_tokens;
|
||||||
self.cache_creation_input_tokens += chunk.cache_creation_input_tokens;
|
self.tokens.input_cache_creation += chunk.cache_creation_input_tokens;
|
||||||
self.cache_read_input_tokens += chunk.cache_read_input_tokens;
|
self.tokens.input_cache_read += chunk.cache_read_input_tokens;
|
||||||
Poll::Ready(Some(Ok(chunk.bytes)))
|
Poll::Ready(Some(Ok(chunk.bytes)))
|
||||||
}
|
}
|
||||||
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
|
||||||
|
@ -634,10 +642,7 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||||
let claims = self.claims.clone();
|
let claims = self.claims.clone();
|
||||||
let provider = self.provider;
|
let provider = self.provider;
|
||||||
let model = std::mem::take(&mut self.model);
|
let model = std::mem::take(&mut self.model);
|
||||||
let input_token_count = self.input_tokens;
|
let tokens = self.tokens;
|
||||||
let output_token_count = self.output_tokens;
|
|
||||||
let cache_creation_input_token_count = self.cache_creation_input_tokens;
|
|
||||||
let cache_read_input_token_count = self.cache_read_input_tokens;
|
|
||||||
self.state.executor.spawn_detached(async move {
|
self.state.executor.spawn_detached(async move {
|
||||||
let usage = state
|
let usage = state
|
||||||
.db
|
.db
|
||||||
|
@ -646,10 +651,9 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||||
claims.is_staff,
|
claims.is_staff,
|
||||||
provider,
|
provider,
|
||||||
&model,
|
&model,
|
||||||
input_token_count,
|
tokens,
|
||||||
cache_creation_input_token_count,
|
claims.has_llm_subscription,
|
||||||
cache_read_input_token_count,
|
Cents(claims.max_monthly_spend_in_cents),
|
||||||
output_token_count,
|
|
||||||
Utc::now(),
|
Utc::now(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
@ -679,22 +683,23 @@ impl<S> Drop for TokenCountingStream<S> {
|
||||||
},
|
},
|
||||||
model,
|
model,
|
||||||
provider: provider.to_string(),
|
provider: provider.to_string(),
|
||||||
input_token_count: input_token_count as u64,
|
input_token_count: tokens.input as u64,
|
||||||
cache_creation_input_token_count: cache_creation_input_token_count
|
cache_creation_input_token_count: tokens.input_cache_creation as u64,
|
||||||
as u64,
|
cache_read_input_token_count: tokens.input_cache_read as u64,
|
||||||
cache_read_input_token_count: cache_read_input_token_count as u64,
|
output_token_count: tokens.output as u64,
|
||||||
output_token_count: output_token_count as u64,
|
|
||||||
requests_this_minute: usage.requests_this_minute as u64,
|
requests_this_minute: usage.requests_this_minute as u64,
|
||||||
tokens_this_minute: usage.tokens_this_minute as u64,
|
tokens_this_minute: usage.tokens_this_minute as u64,
|
||||||
tokens_this_day: usage.tokens_this_day as u64,
|
tokens_this_day: usage.tokens_this_day as u64,
|
||||||
input_tokens_this_month: usage.input_tokens_this_month as u64,
|
input_tokens_this_month: usage.tokens_this_month.input as u64,
|
||||||
cache_creation_input_tokens_this_month: usage
|
cache_creation_input_tokens_this_month: usage
|
||||||
.cache_creation_input_tokens_this_month
|
.tokens_this_month
|
||||||
|
.input_cache_creation
|
||||||
as u64,
|
as u64,
|
||||||
cache_read_input_tokens_this_month: usage
|
cache_read_input_tokens_this_month: usage
|
||||||
.cache_read_input_tokens_this_month
|
.tokens_this_month
|
||||||
|
.input_cache_read
|
||||||
as u64,
|
as u64,
|
||||||
output_tokens_this_month: usage.output_tokens_this_month as u64,
|
output_tokens_this_month: usage.tokens_this_month.output as u64,
|
||||||
spending_this_month: usage.spending_this_month.0 as u64,
|
spending_this_month: usage.spending_this_month.0 as u64,
|
||||||
lifetime_spending: usage.lifetime_spending.0 as u64,
|
lifetime_spending: usage.lifetime_spending.0 as u64,
|
||||||
},
|
},
|
||||||
|
|
|
@ -20,7 +20,7 @@ use std::future::Future;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
pub use queries::usages::ActiveUserCount;
|
pub use queries::usages::{ActiveUserCount, TokenUsage};
|
||||||
use sea_orm::prelude::*;
|
use sea_orm::prelude::*;
|
||||||
pub use sea_orm::ConnectOptions;
|
pub use sea_orm::ConnectOptions;
|
||||||
use sea_orm::{
|
use sea_orm::{
|
||||||
|
|
|
@ -8,3 +8,4 @@ id_type!(ProviderId);
|
||||||
id_type!(UsageId);
|
id_type!(UsageId);
|
||||||
id_type!(UsageMeasureId);
|
id_type!(UsageMeasureId);
|
||||||
id_type!(RevokedAccessTokenId);
|
id_type!(RevokedAccessTokenId);
|
||||||
|
id_type!(BillingEventId);
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
pub mod billing_events;
|
||||||
pub mod providers;
|
pub mod providers;
|
||||||
pub mod revoked_access_tokens;
|
pub mod revoked_access_tokens;
|
||||||
pub mod usages;
|
pub mod usages;
|
||||||
|
|
31
crates/collab/src/llm/db/queries/billing_events.rs
Normal file
31
crates/collab/src/llm/db/queries/billing_events.rs
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
use super::*;
|
||||||
|
use crate::Result;
|
||||||
|
use anyhow::Context as _;
|
||||||
|
|
||||||
|
impl LlmDatabase {
|
||||||
|
pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
|
||||||
|
self.transaction(|tx| async move {
|
||||||
|
let events_with_models = billing_event::Entity::find()
|
||||||
|
.find_also_related(model::Entity)
|
||||||
|
.all(&*tx)
|
||||||
|
.await?;
|
||||||
|
events_with_models
|
||||||
|
.into_iter()
|
||||||
|
.map(|(event, model)| {
|
||||||
|
let model =
|
||||||
|
model.context("could not find model associated with billing event")?;
|
||||||
|
Ok((event, model))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
|
||||||
|
self.transaction(|tx| async move {
|
||||||
|
billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::db::UserId;
|
|
||||||
use crate::llm::Cents;
|
use crate::llm::Cents;
|
||||||
|
use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||||
use chrono::{Datelike, Duration};
|
use chrono::{Datelike, Duration};
|
||||||
use futures::StreamExt as _;
|
use futures::StreamExt as _;
|
||||||
use rpc::LanguageModelProvider;
|
use rpc::LanguageModelProvider;
|
||||||
|
@ -9,15 +9,26 @@ use strum::IntoEnumIterator as _;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
/// Token counts for a request (or an aggregate of requests), broken down by
/// the category each token is tracked under.
#[derive(Debug, PartialEq, Clone, Copy, Default)]
pub struct TokenUsage {
    /// Regular input tokens.
    pub input: usize,
    /// Input tokens counted as cache creation.
    pub input_cache_creation: usize,
    /// Input tokens counted as cache reads.
    pub input_cache_read: usize,
    /// Output tokens.
    pub output: usize,
}

impl TokenUsage {
    /// Returns the sum of all four counters.
    ///
    /// The addition saturates at `usize::MAX`. Plain `+` would wrap silently
    /// in release builds, and a wrapped total would under-report usage.
    pub fn total(&self) -> usize {
        [
            self.input,
            self.input_cache_creation,
            self.input_cache_read,
            self.output,
        ]
        .iter()
        .fold(0usize, |sum, &count| sum.saturating_add(count))
    }
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Clone, Copy)]
|
#[derive(Debug, PartialEq, Clone, Copy)]
|
||||||
pub struct Usage {
|
pub struct Usage {
|
||||||
pub requests_this_minute: usize,
|
pub requests_this_minute: usize,
|
||||||
pub tokens_this_minute: usize,
|
pub tokens_this_minute: usize,
|
||||||
pub tokens_this_day: usize,
|
pub tokens_this_day: usize,
|
||||||
pub input_tokens_this_month: usize,
|
pub tokens_this_month: TokenUsage,
|
||||||
pub cache_creation_input_tokens_this_month: usize,
|
|
||||||
pub cache_read_input_tokens_this_month: usize,
|
|
||||||
pub output_tokens_this_month: usize,
|
|
||||||
pub spending_this_month: Cents,
|
pub spending_this_month: Cents,
|
||||||
pub lifetime_spending: Cents,
|
pub lifetime_spending: Cents,
|
||||||
}
|
}
|
||||||
|
@ -257,18 +268,20 @@ impl LlmDatabase {
|
||||||
requests_this_minute,
|
requests_this_minute,
|
||||||
tokens_this_minute,
|
tokens_this_minute,
|
||||||
tokens_this_day,
|
tokens_this_day,
|
||||||
input_tokens_this_month: monthly_usage
|
tokens_this_month: TokenUsage {
|
||||||
|
input: monthly_usage
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(0, |usage| usage.input_tokens as usize),
|
.map_or(0, |usage| usage.input_tokens as usize),
|
||||||
cache_creation_input_tokens_this_month: monthly_usage
|
input_cache_creation: monthly_usage
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
.map_or(0, |usage| usage.cache_creation_input_tokens as usize),
|
||||||
cache_read_input_tokens_this_month: monthly_usage
|
input_cache_read: monthly_usage
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
.map_or(0, |usage| usage.cache_read_input_tokens as usize),
|
||||||
output_tokens_this_month: monthly_usage
|
output: monthly_usage
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(0, |usage| usage.output_tokens as usize),
|
.map_or(0, |usage| usage.output_tokens as usize),
|
||||||
|
},
|
||||||
spending_this_month,
|
spending_this_month,
|
||||||
lifetime_spending,
|
lifetime_spending,
|
||||||
})
|
})
|
||||||
|
@ -283,10 +296,9 @@ impl LlmDatabase {
|
||||||
is_staff: bool,
|
is_staff: bool,
|
||||||
provider: LanguageModelProvider,
|
provider: LanguageModelProvider,
|
||||||
model_name: &str,
|
model_name: &str,
|
||||||
input_token_count: usize,
|
tokens: TokenUsage,
|
||||||
cache_creation_input_tokens: usize,
|
has_llm_subscription: bool,
|
||||||
cache_read_input_tokens: usize,
|
max_monthly_spend: Cents,
|
||||||
output_token_count: usize,
|
|
||||||
now: DateTimeUtc,
|
now: DateTimeUtc,
|
||||||
) -> Result<Usage> {
|
) -> Result<Usage> {
|
||||||
self.transaction(|tx| async move {
|
self.transaction(|tx| async move {
|
||||||
|
@ -313,10 +325,6 @@ impl LlmDatabase {
|
||||||
&tx,
|
&tx,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let total_token_count = input_token_count
|
|
||||||
+ cache_read_input_tokens
|
|
||||||
+ cache_creation_input_tokens
|
|
||||||
+ output_token_count;
|
|
||||||
let tokens_this_minute = self
|
let tokens_this_minute = self
|
||||||
.update_usage_for_measure(
|
.update_usage_for_measure(
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -325,7 +333,7 @@ impl LlmDatabase {
|
||||||
&usages,
|
&usages,
|
||||||
UsageMeasure::TokensPerMinute,
|
UsageMeasure::TokensPerMinute,
|
||||||
now,
|
now,
|
||||||
total_token_count,
|
tokens.total(),
|
||||||
&tx,
|
&tx,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -337,7 +345,7 @@ impl LlmDatabase {
|
||||||
&usages,
|
&usages,
|
||||||
UsageMeasure::TokensPerDay,
|
UsageMeasure::TokensPerDay,
|
||||||
now,
|
now,
|
||||||
total_token_count,
|
tokens.total(),
|
||||||
&tx,
|
&tx,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
@ -361,18 +369,14 @@ impl LlmDatabase {
|
||||||
Some(usage) => {
|
Some(usage) => {
|
||||||
monthly_usage::Entity::update(monthly_usage::ActiveModel {
|
monthly_usage::Entity::update(monthly_usage::ActiveModel {
|
||||||
id: ActiveValue::unchanged(usage.id),
|
id: ActiveValue::unchanged(usage.id),
|
||||||
input_tokens: ActiveValue::set(
|
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||||
usage.input_tokens + input_token_count as i64,
|
|
||||||
),
|
|
||||||
cache_creation_input_tokens: ActiveValue::set(
|
cache_creation_input_tokens: ActiveValue::set(
|
||||||
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
|
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||||
),
|
),
|
||||||
cache_read_input_tokens: ActiveValue::set(
|
cache_read_input_tokens: ActiveValue::set(
|
||||||
usage.cache_read_input_tokens + cache_read_input_tokens as i64,
|
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||||
),
|
|
||||||
output_tokens: ActiveValue::set(
|
|
||||||
usage.output_tokens + output_token_count as i64,
|
|
||||||
),
|
),
|
||||||
|
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
})
|
})
|
||||||
.exec(&*tx)
|
.exec(&*tx)
|
||||||
|
@ -384,12 +388,12 @@ impl LlmDatabase {
|
||||||
model_id: ActiveValue::set(model.id),
|
model_id: ActiveValue::set(model.id),
|
||||||
month: ActiveValue::set(month),
|
month: ActiveValue::set(month),
|
||||||
year: ActiveValue::set(year),
|
year: ActiveValue::set(year),
|
||||||
input_tokens: ActiveValue::set(input_token_count as i64),
|
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||||
cache_creation_input_tokens: ActiveValue::set(
|
cache_creation_input_tokens: ActiveValue::set(
|
||||||
cache_creation_input_tokens as i64,
|
tokens.input_cache_creation as i64,
|
||||||
),
|
),
|
||||||
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
|
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||||
output_tokens: ActiveValue::set(output_token_count as i64),
|
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
.insert(&*tx)
|
.insert(&*tx)
|
||||||
|
@ -405,6 +409,26 @@ impl LlmDatabase {
|
||||||
monthly_usage.output_tokens as usize,
|
monthly_usage.output_tokens as usize,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT
|
||||||
|
&& has_llm_subscription
|
||||||
|
&& spending_this_month <= max_monthly_spend
|
||||||
|
{
|
||||||
|
billing_event::ActiveModel {
|
||||||
|
id: ActiveValue::not_set(),
|
||||||
|
idempotency_key: ActiveValue::not_set(),
|
||||||
|
user_id: ActiveValue::set(user_id),
|
||||||
|
model_id: ActiveValue::set(model.id),
|
||||||
|
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||||
|
input_cache_creation_tokens: ActiveValue::set(
|
||||||
|
tokens.input_cache_creation as i64,
|
||||||
|
),
|
||||||
|
input_cache_read_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||||
|
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||||
|
}
|
||||||
|
.insert(&*tx)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
// Update lifetime usage
|
// Update lifetime usage
|
||||||
let lifetime_usage = lifetime_usage::Entity::find()
|
let lifetime_usage = lifetime_usage::Entity::find()
|
||||||
.filter(
|
.filter(
|
||||||
|
@ -419,18 +443,14 @@ impl LlmDatabase {
|
||||||
Some(usage) => {
|
Some(usage) => {
|
||||||
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
|
lifetime_usage::Entity::update(lifetime_usage::ActiveModel {
|
||||||
id: ActiveValue::unchanged(usage.id),
|
id: ActiveValue::unchanged(usage.id),
|
||||||
input_tokens: ActiveValue::set(
|
input_tokens: ActiveValue::set(usage.input_tokens + tokens.input as i64),
|
||||||
usage.input_tokens + input_token_count as i64,
|
|
||||||
),
|
|
||||||
cache_creation_input_tokens: ActiveValue::set(
|
cache_creation_input_tokens: ActiveValue::set(
|
||||||
usage.cache_creation_input_tokens + cache_creation_input_tokens as i64,
|
usage.cache_creation_input_tokens + tokens.input_cache_creation as i64,
|
||||||
),
|
),
|
||||||
cache_read_input_tokens: ActiveValue::set(
|
cache_read_input_tokens: ActiveValue::set(
|
||||||
usage.cache_read_input_tokens + cache_read_input_tokens as i64,
|
usage.cache_read_input_tokens + tokens.input_cache_read as i64,
|
||||||
),
|
|
||||||
output_tokens: ActiveValue::set(
|
|
||||||
usage.output_tokens + output_token_count as i64,
|
|
||||||
),
|
),
|
||||||
|
output_tokens: ActiveValue::set(usage.output_tokens + tokens.output as i64),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
})
|
})
|
||||||
.exec(&*tx)
|
.exec(&*tx)
|
||||||
|
@ -440,12 +460,12 @@ impl LlmDatabase {
|
||||||
lifetime_usage::ActiveModel {
|
lifetime_usage::ActiveModel {
|
||||||
user_id: ActiveValue::set(user_id),
|
user_id: ActiveValue::set(user_id),
|
||||||
model_id: ActiveValue::set(model.id),
|
model_id: ActiveValue::set(model.id),
|
||||||
input_tokens: ActiveValue::set(input_token_count as i64),
|
input_tokens: ActiveValue::set(tokens.input as i64),
|
||||||
cache_creation_input_tokens: ActiveValue::set(
|
cache_creation_input_tokens: ActiveValue::set(
|
||||||
cache_creation_input_tokens as i64,
|
tokens.input_cache_creation as i64,
|
||||||
),
|
),
|
||||||
cache_read_input_tokens: ActiveValue::set(cache_read_input_tokens as i64),
|
cache_read_input_tokens: ActiveValue::set(tokens.input_cache_read as i64),
|
||||||
output_tokens: ActiveValue::set(output_token_count as i64),
|
output_tokens: ActiveValue::set(tokens.output as i64),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
.insert(&*tx)
|
.insert(&*tx)
|
||||||
|
@ -465,11 +485,12 @@ impl LlmDatabase {
|
||||||
requests_this_minute,
|
requests_this_minute,
|
||||||
tokens_this_minute,
|
tokens_this_minute,
|
||||||
tokens_this_day,
|
tokens_this_day,
|
||||||
input_tokens_this_month: monthly_usage.input_tokens as usize,
|
tokens_this_month: TokenUsage {
|
||||||
cache_creation_input_tokens_this_month: monthly_usage.cache_creation_input_tokens
|
input: monthly_usage.input_tokens as usize,
|
||||||
as usize,
|
input_cache_creation: monthly_usage.cache_creation_input_tokens as usize,
|
||||||
cache_read_input_tokens_this_month: monthly_usage.cache_read_input_tokens as usize,
|
input_cache_read: monthly_usage.cache_read_input_tokens as usize,
|
||||||
output_tokens_this_month: monthly_usage.output_tokens as usize,
|
output: monthly_usage.output_tokens as usize,
|
||||||
|
},
|
||||||
spending_this_month,
|
spending_this_month,
|
||||||
lifetime_spending,
|
lifetime_spending,
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod billing_event;
|
||||||
pub mod lifetime_usage;
|
pub mod lifetime_usage;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod monthly_usage;
|
pub mod monthly_usage;
|
||||||
|
|
37
crates/collab/src/llm/db/tables/billing_event.rs
Normal file
37
crates/collab/src/llm/db/tables/billing_event.rs
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
use crate::{
|
||||||
|
db::UserId,
|
||||||
|
llm::db::{BillingEventId, ModelId},
|
||||||
|
};
|
||||||
|
use sea_orm::entity::prelude::*;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||||
|
#[sea_orm(table_name = "billing_events")]
|
||||||
|
pub struct Model {
|
||||||
|
#[sea_orm(primary_key)]
|
||||||
|
pub id: BillingEventId,
|
||||||
|
pub idempotency_key: Uuid,
|
||||||
|
pub user_id: UserId,
|
||||||
|
pub model_id: ModelId,
|
||||||
|
pub input_tokens: i64,
|
||||||
|
pub input_cache_creation_tokens: i64,
|
||||||
|
pub input_cache_read_tokens: i64,
|
||||||
|
pub output_tokens: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||||
|
pub enum Relation {
|
||||||
|
#[sea_orm(
|
||||||
|
belongs_to = "super::model::Entity",
|
||||||
|
from = "Column::ModelId",
|
||||||
|
to = "super::model::Column::Id"
|
||||||
|
)]
|
||||||
|
Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Related<super::model::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::Model.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No custom hooks: the trait's default behavior is used as-is.
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -29,6 +29,8 @@ pub enum Relation {
|
||||||
Provider,
|
Provider,
|
||||||
#[sea_orm(has_many = "super::usage::Entity")]
|
#[sea_orm(has_many = "super::usage::Entity")]
|
||||||
Usages,
|
Usages,
|
||||||
|
#[sea_orm(has_many = "super::billing_event::Entity")]
|
||||||
|
BillingEvents,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Related<super::provider::Entity> for Entity {
|
impl Related<super::provider::Entity> for Entity {
|
||||||
|
@ -43,4 +45,10 @@ impl Related<super::usage::Entity> for Entity {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Related<super::billing_event::Entity> for Entity {
|
||||||
|
fn to() -> RelationDef {
|
||||||
|
Relation::BillingEvents.def()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ActiveModelBehavior for ActiveModel {}
|
impl ActiveModelBehavior for ActiveModel {}
|
||||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
||||||
db::UserId,
|
db::UserId,
|
||||||
llm::db::{
|
llm::db::{
|
||||||
queries::{providers::ModelParams, usages::Usage},
|
queries::{providers::ModelParams, usages::Usage},
|
||||||
LlmDatabase,
|
LlmDatabase, TokenUsage,
|
||||||
},
|
},
|
||||||
test_llm_db, Cents,
|
test_llm_db, Cents,
|
||||||
};
|
};
|
||||||
|
@ -36,12 +36,40 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
let user_id = UserId::from_proto(123);
|
let user_id = UserId::from_proto(123);
|
||||||
|
|
||||||
let now = t0;
|
let now = t0;
|
||||||
db.record_usage(user_id, false, provider, model, 1000, 0, 0, 0, now)
|
db.record_usage(
|
||||||
|
user_id,
|
||||||
|
false,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
TokenUsage {
|
||||||
|
input: 1000,
|
||||||
|
input_cache_creation: 0,
|
||||||
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
Cents::ZERO,
|
||||||
|
now,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let now = t0 + Duration::seconds(10);
|
let now = t0 + Duration::seconds(10);
|
||||||
db.record_usage(user_id, false, provider, model, 2000, 0, 0, 0, now)
|
db.record_usage(
|
||||||
|
user_id,
|
||||||
|
false,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
TokenUsage {
|
||||||
|
input: 2000,
|
||||||
|
input_cache_creation: 0,
|
||||||
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
Cents::ZERO,
|
||||||
|
now,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -52,10 +80,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
requests_this_minute: 2,
|
requests_this_minute: 2,
|
||||||
tokens_this_minute: 3000,
|
tokens_this_minute: 3000,
|
||||||
tokens_this_day: 3000,
|
tokens_this_day: 3000,
|
||||||
input_tokens_this_month: 3000,
|
tokens_this_month: TokenUsage {
|
||||||
cache_creation_input_tokens_this_month: 0,
|
input: 3000,
|
||||||
cache_read_input_tokens_this_month: 0,
|
input_cache_creation: 0,
|
||||||
output_tokens_this_month: 0,
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
spending_this_month: Cents::ZERO,
|
spending_this_month: Cents::ZERO,
|
||||||
lifetime_spending: Cents::ZERO,
|
lifetime_spending: Cents::ZERO,
|
||||||
}
|
}
|
||||||
|
@ -69,17 +99,33 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
requests_this_minute: 1,
|
requests_this_minute: 1,
|
||||||
tokens_this_minute: 2000,
|
tokens_this_minute: 2000,
|
||||||
tokens_this_day: 3000,
|
tokens_this_day: 3000,
|
||||||
input_tokens_this_month: 3000,
|
tokens_this_month: TokenUsage {
|
||||||
cache_creation_input_tokens_this_month: 0,
|
input: 3000,
|
||||||
cache_read_input_tokens_this_month: 0,
|
input_cache_creation: 0,
|
||||||
output_tokens_this_month: 0,
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
spending_this_month: Cents::ZERO,
|
spending_this_month: Cents::ZERO,
|
||||||
lifetime_spending: Cents::ZERO,
|
lifetime_spending: Cents::ZERO,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
let now = t0 + Duration::seconds(60);
|
let now = t0 + Duration::seconds(60);
|
||||||
db.record_usage(user_id, false, provider, model, 3000, 0, 0, 0, now)
|
db.record_usage(
|
||||||
|
user_id,
|
||||||
|
false,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
TokenUsage {
|
||||||
|
input: 3000,
|
||||||
|
input_cache_creation: 0,
|
||||||
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
Cents::ZERO,
|
||||||
|
now,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -90,10 +136,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
requests_this_minute: 2,
|
requests_this_minute: 2,
|
||||||
tokens_this_minute: 5000,
|
tokens_this_minute: 5000,
|
||||||
tokens_this_day: 6000,
|
tokens_this_day: 6000,
|
||||||
input_tokens_this_month: 6000,
|
tokens_this_month: TokenUsage {
|
||||||
cache_creation_input_tokens_this_month: 0,
|
input: 6000,
|
||||||
cache_read_input_tokens_this_month: 0,
|
input_cache_creation: 0,
|
||||||
output_tokens_this_month: 0,
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
spending_this_month: Cents::ZERO,
|
spending_this_month: Cents::ZERO,
|
||||||
lifetime_spending: Cents::ZERO,
|
lifetime_spending: Cents::ZERO,
|
||||||
}
|
}
|
||||||
|
@ -108,16 +156,32 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
requests_this_minute: 0,
|
requests_this_minute: 0,
|
||||||
tokens_this_minute: 0,
|
tokens_this_minute: 0,
|
||||||
tokens_this_day: 5000,
|
tokens_this_day: 5000,
|
||||||
input_tokens_this_month: 6000,
|
tokens_this_month: TokenUsage {
|
||||||
cache_creation_input_tokens_this_month: 0,
|
input: 6000,
|
||||||
cache_read_input_tokens_this_month: 0,
|
input_cache_creation: 0,
|
||||||
output_tokens_this_month: 0,
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
spending_this_month: Cents::ZERO,
|
spending_this_month: Cents::ZERO,
|
||||||
lifetime_spending: Cents::ZERO,
|
lifetime_spending: Cents::ZERO,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
db.record_usage(user_id, false, provider, model, 4000, 0, 0, 0, now)
|
db.record_usage(
|
||||||
|
user_id,
|
||||||
|
false,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
TokenUsage {
|
||||||
|
input: 4000,
|
||||||
|
input_cache_creation: 0,
|
||||||
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
Cents::ZERO,
|
||||||
|
now,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -128,10 +192,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
requests_this_minute: 1,
|
requests_this_minute: 1,
|
||||||
tokens_this_minute: 4000,
|
tokens_this_minute: 4000,
|
||||||
tokens_this_day: 9000,
|
tokens_this_day: 9000,
|
||||||
input_tokens_this_month: 10000,
|
tokens_this_month: TokenUsage {
|
||||||
cache_creation_input_tokens_this_month: 0,
|
input: 10000,
|
||||||
cache_read_input_tokens_this_month: 0,
|
input_cache_creation: 0,
|
||||||
output_tokens_this_month: 0,
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
spending_this_month: Cents::ZERO,
|
spending_this_month: Cents::ZERO,
|
||||||
lifetime_spending: Cents::ZERO,
|
lifetime_spending: Cents::ZERO,
|
||||||
}
|
}
|
||||||
|
@ -143,7 +209,21 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
.with_timezone(&Utc);
|
.with_timezone(&Utc);
|
||||||
|
|
||||||
// Test cache creation input tokens
|
// Test cache creation input tokens
|
||||||
db.record_usage(user_id, false, provider, model, 1000, 500, 0, 0, now)
|
db.record_usage(
|
||||||
|
user_id,
|
||||||
|
false,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
TokenUsage {
|
||||||
|
input: 1000,
|
||||||
|
input_cache_creation: 500,
|
||||||
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
Cents::ZERO,
|
||||||
|
now,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -154,17 +234,33 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
requests_this_minute: 1,
|
requests_this_minute: 1,
|
||||||
tokens_this_minute: 1500,
|
tokens_this_minute: 1500,
|
||||||
tokens_this_day: 1500,
|
tokens_this_day: 1500,
|
||||||
input_tokens_this_month: 1000,
|
tokens_this_month: TokenUsage {
|
||||||
cache_creation_input_tokens_this_month: 500,
|
input: 1000,
|
||||||
cache_read_input_tokens_this_month: 0,
|
input_cache_creation: 500,
|
||||||
output_tokens_this_month: 0,
|
input_cache_read: 0,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
spending_this_month: Cents::ZERO,
|
spending_this_month: Cents::ZERO,
|
||||||
lifetime_spending: Cents::ZERO,
|
lifetime_spending: Cents::ZERO,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
// Test cache read input tokens
|
// Test cache read input tokens
|
||||||
db.record_usage(user_id, false, provider, model, 1000, 0, 300, 0, now)
|
db.record_usage(
|
||||||
|
user_id,
|
||||||
|
false,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
TokenUsage {
|
||||||
|
input: 1000,
|
||||||
|
input_cache_creation: 0,
|
||||||
|
input_cache_read: 300,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
Cents::ZERO,
|
||||||
|
now,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -175,10 +271,12 @@ async fn test_tracking_usage(db: &mut LlmDatabase) {
|
||||||
requests_this_minute: 2,
|
requests_this_minute: 2,
|
||||||
tokens_this_minute: 2800,
|
tokens_this_minute: 2800,
|
||||||
tokens_this_day: 2800,
|
tokens_this_day: 2800,
|
||||||
input_tokens_this_month: 2000,
|
tokens_this_month: TokenUsage {
|
||||||
cache_creation_input_tokens_this_month: 500,
|
input: 2000,
|
||||||
cache_read_input_tokens_this_month: 300,
|
input_cache_creation: 500,
|
||||||
output_tokens_this_month: 0,
|
input_cache_read: 300,
|
||||||
|
output: 0,
|
||||||
|
},
|
||||||
spending_this_month: Cents::ZERO,
|
spending_this_month: Cents::ZERO,
|
||||||
lifetime_spending: Cents::ZERO,
|
lifetime_spending: Cents::ZERO,
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,7 +157,7 @@ async fn main() -> Result<()> {
|
||||||
|
|
||||||
if let Some(mut llm_db) = llm_db {
|
if let Some(mut llm_db) = llm_db {
|
||||||
llm_db.initialize().await?;
|
llm_db.initialize().await?;
|
||||||
sync_llm_usage_with_stripe_periodically(state.clone(), llm_db);
|
sync_llm_usage_with_stripe_periodically(state.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
app = app
|
app = app
|
||||||
|
|
|
@ -1218,6 +1218,15 @@ impl Server {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
||||||
|
let pool = self.connection_pool.lock();
|
||||||
|
for connection_id in pool.user_connection_ids(user_id) {
|
||||||
|
self.peer
|
||||||
|
.send(connection_id, proto::RefreshLlmToken {})
|
||||||
|
.trace_err();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
|
pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
|
||||||
ServerSnapshot {
|
ServerSnapshot {
|
||||||
connection_pool: ConnectionPoolGuard {
|
connection_pool: ConnectionPoolGuard {
|
||||||
|
|
427
crates/collab/src/stripe_billing.rs
Normal file
427
crates/collab/src/stripe_billing.rs
Normal file
|
@ -0,0 +1,427 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::{llm, Cents, Result};
|
||||||
|
use anyhow::Context;
|
||||||
|
use chrono::Utc;
|
||||||
|
use collections::HashMap;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
pub struct StripeBilling {
|
||||||
|
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||||
|
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
||||||
|
client: Arc<stripe::Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct StripeModel {
|
||||||
|
input_tokens_price: StripeBillingPrice,
|
||||||
|
input_cache_creation_tokens_price: StripeBillingPrice,
|
||||||
|
input_cache_read_tokens_price: StripeBillingPrice,
|
||||||
|
output_tokens_price: StripeBillingPrice,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct StripeBillingPrice {
|
||||||
|
id: stripe::PriceId,
|
||||||
|
meter_event_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StripeBilling {
|
||||||
|
pub async fn new(client: Arc<stripe::Client>) -> Result<Self> {
|
||||||
|
let mut meters_by_event_name = HashMap::default();
|
||||||
|
for meter in StripeMeter::list(&client).await?.data {
|
||||||
|
meters_by_event_name.insert(meter.event_name.clone(), meter);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut price_ids_by_meter_id = HashMap::default();
|
||||||
|
for price in stripe::Price::list(&client, &stripe::ListPrices::default())
|
||||||
|
.await?
|
||||||
|
.data
|
||||||
|
{
|
||||||
|
if let Some(recurring) = price.recurring {
|
||||||
|
if let Some(meter) = recurring.meter {
|
||||||
|
price_ids_by_meter_id.insert(meter, price.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
meters_by_event_name,
|
||||||
|
price_ids_by_meter_id,
|
||||||
|
client,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> {
|
||||||
|
let input_tokens_price = self
|
||||||
|
.get_or_insert_price(
|
||||||
|
&format!("model_{}/input_tokens", model.id),
|
||||||
|
&format!("{} (Input Tokens)", model.name),
|
||||||
|
Cents::new(model.price_per_million_input_tokens as u32),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
let input_cache_creation_tokens_price = self
|
||||||
|
.get_or_insert_price(
|
||||||
|
&format!("model_{}/input_cache_creation_tokens", model.id),
|
||||||
|
&format!("{} (Input Cache Creation Tokens)", model.name),
|
||||||
|
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
let input_cache_read_tokens_price = self
|
||||||
|
.get_or_insert_price(
|
||||||
|
&format!("model_{}/input_cache_read_tokens", model.id),
|
||||||
|
&format!("{} (Input Cache Read Tokens)", model.name),
|
||||||
|
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
let output_tokens_price = self
|
||||||
|
.get_or_insert_price(
|
||||||
|
&format!("model_{}/output_tokens", model.id),
|
||||||
|
&format!("{} (Output Tokens)", model.name),
|
||||||
|
Cents::new(model.price_per_million_output_tokens as u32),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(StripeModel {
|
||||||
|
input_tokens_price,
|
||||||
|
input_cache_creation_tokens_price,
|
||||||
|
input_cache_read_tokens_price,
|
||||||
|
output_tokens_price,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_or_insert_price(
|
||||||
|
&mut self,
|
||||||
|
meter_event_name: &str,
|
||||||
|
price_description: &str,
|
||||||
|
price_per_million_tokens: Cents,
|
||||||
|
) -> Result<StripeBillingPrice> {
|
||||||
|
let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) {
|
||||||
|
meter.clone()
|
||||||
|
} else {
|
||||||
|
let meter = StripeMeter::create(
|
||||||
|
&self.client,
|
||||||
|
StripeCreateMeterParams {
|
||||||
|
default_aggregation: DefaultAggregation { formula: "sum" },
|
||||||
|
display_name: price_description.to_string(),
|
||||||
|
event_name: meter_event_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
self.meters_by_event_name
|
||||||
|
.insert(meter_event_name.to_string(), meter.clone());
|
||||||
|
meter
|
||||||
|
};
|
||||||
|
|
||||||
|
let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) {
|
||||||
|
price_id.clone()
|
||||||
|
} else {
|
||||||
|
let price = stripe::Price::create(
|
||||||
|
&self.client,
|
||||||
|
stripe::CreatePrice {
|
||||||
|
active: Some(true),
|
||||||
|
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
||||||
|
currency: stripe::Currency::USD,
|
||||||
|
currency_options: None,
|
||||||
|
custom_unit_amount: None,
|
||||||
|
expand: &[],
|
||||||
|
lookup_key: None,
|
||||||
|
metadata: None,
|
||||||
|
nickname: None,
|
||||||
|
product: None,
|
||||||
|
product_data: Some(stripe::CreatePriceProductData {
|
||||||
|
id: None,
|
||||||
|
active: Some(true),
|
||||||
|
metadata: None,
|
||||||
|
name: price_description.to_string(),
|
||||||
|
statement_descriptor: None,
|
||||||
|
tax_code: None,
|
||||||
|
unit_label: None,
|
||||||
|
}),
|
||||||
|
recurring: Some(stripe::CreatePriceRecurring {
|
||||||
|
aggregate_usage: None,
|
||||||
|
interval: stripe::CreatePriceRecurringInterval::Month,
|
||||||
|
interval_count: None,
|
||||||
|
trial_period_days: None,
|
||||||
|
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
||||||
|
meter: Some(meter.id.clone()),
|
||||||
|
}),
|
||||||
|
tax_behavior: None,
|
||||||
|
tiers: None,
|
||||||
|
tiers_mode: None,
|
||||||
|
transfer_lookup_key: None,
|
||||||
|
transform_quantity: None,
|
||||||
|
unit_amount: None,
|
||||||
|
unit_amount_decimal: Some(&format!(
|
||||||
|
"{:.12}",
|
||||||
|
price_per_million_tokens.0 as f64 / 1_000_000f64
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
self.price_ids_by_meter_id
|
||||||
|
.insert(meter.id, price.id.clone());
|
||||||
|
price.id
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(StripeBillingPrice {
|
||||||
|
id: price_id,
|
||||||
|
meter_event_name: meter_event_name.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn subscribe_to_model(
|
||||||
|
&self,
|
||||||
|
subscription_id: &stripe::SubscriptionId,
|
||||||
|
model: &StripeModel,
|
||||||
|
) -> Result<()> {
|
||||||
|
let subscription =
|
||||||
|
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||||
|
|
||||||
|
let mut items = Vec::new();
|
||||||
|
|
||||||
|
if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
|
||||||
|
items.push(stripe::UpdateSubscriptionItems {
|
||||||
|
price: Some(model.input_tokens_price.id.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
|
||||||
|
{
|
||||||
|
items.push(stripe::UpdateSubscriptionItems {
|
||||||
|
price: Some(model.input_cache_creation_tokens_price.id.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
|
||||||
|
items.push(stripe::UpdateSubscriptionItems {
|
||||||
|
price: Some(model.input_cache_read_tokens_price.id.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
|
||||||
|
items.push(stripe::UpdateSubscriptionItems {
|
||||||
|
price: Some(model.output_tokens_price.id.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !items.is_empty() {
|
||||||
|
items.extend(subscription.items.data.iter().map(|item| {
|
||||||
|
stripe::UpdateSubscriptionItems {
|
||||||
|
id: Some(item.id.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
stripe::Subscription::update(
|
||||||
|
&self.client,
|
||||||
|
subscription_id,
|
||||||
|
stripe::UpdateSubscription {
|
||||||
|
items: Some(items),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn bill_model_usage(
|
||||||
|
&self,
|
||||||
|
customer_id: &stripe::CustomerId,
|
||||||
|
model: &StripeModel,
|
||||||
|
event: &llm::db::billing_event::Model,
|
||||||
|
) -> Result<()> {
|
||||||
|
let timestamp = Utc::now().timestamp();
|
||||||
|
|
||||||
|
if event.input_tokens > 0 {
|
||||||
|
StripeMeterEvent::create(
|
||||||
|
&self.client,
|
||||||
|
StripeCreateMeterEventParams {
|
||||||
|
identifier: &format!("input_tokens/{}", event.idempotency_key),
|
||||||
|
event_name: &model.input_tokens_price.meter_event_name,
|
||||||
|
payload: StripeCreateMeterEventPayload {
|
||||||
|
value: event.input_tokens as u64,
|
||||||
|
stripe_customer_id: customer_id,
|
||||||
|
},
|
||||||
|
timestamp: Some(timestamp),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.input_cache_creation_tokens > 0 {
|
||||||
|
StripeMeterEvent::create(
|
||||||
|
&self.client,
|
||||||
|
StripeCreateMeterEventParams {
|
||||||
|
identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
|
||||||
|
event_name: &model.input_cache_creation_tokens_price.meter_event_name,
|
||||||
|
payload: StripeCreateMeterEventPayload {
|
||||||
|
value: event.input_cache_creation_tokens as u64,
|
||||||
|
stripe_customer_id: customer_id,
|
||||||
|
},
|
||||||
|
timestamp: Some(timestamp),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.input_cache_read_tokens > 0 {
|
||||||
|
StripeMeterEvent::create(
|
||||||
|
&self.client,
|
||||||
|
StripeCreateMeterEventParams {
|
||||||
|
identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
|
||||||
|
event_name: &model.input_cache_read_tokens_price.meter_event_name,
|
||||||
|
payload: StripeCreateMeterEventPayload {
|
||||||
|
value: event.input_cache_read_tokens as u64,
|
||||||
|
stripe_customer_id: customer_id,
|
||||||
|
},
|
||||||
|
timestamp: Some(timestamp),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.output_tokens > 0 {
|
||||||
|
StripeMeterEvent::create(
|
||||||
|
&self.client,
|
||||||
|
StripeCreateMeterEventParams {
|
||||||
|
identifier: &format!("output_tokens/{}", event.idempotency_key),
|
||||||
|
event_name: &model.output_tokens_price.meter_event_name,
|
||||||
|
payload: StripeCreateMeterEventPayload {
|
||||||
|
value: event.output_tokens as u64,
|
||||||
|
stripe_customer_id: customer_id,
|
||||||
|
},
|
||||||
|
timestamp: Some(timestamp),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn checkout(
|
||||||
|
&self,
|
||||||
|
customer_id: stripe::CustomerId,
|
||||||
|
github_login: &str,
|
||||||
|
model: &StripeModel,
|
||||||
|
success_url: &str,
|
||||||
|
) -> Result<String> {
|
||||||
|
let mut params = stripe::CreateCheckoutSession::new();
|
||||||
|
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||||
|
params.customer = Some(customer_id);
|
||||||
|
params.client_reference_id = Some(github_login);
|
||||||
|
params.line_items = Some(
|
||||||
|
[
|
||||||
|
&model.input_tokens_price.id,
|
||||||
|
&model.input_cache_creation_tokens_price.id,
|
||||||
|
&model.input_cache_read_tokens_price.id,
|
||||||
|
&model.output_tokens_price.id,
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.map(|price_id| stripe::CreateCheckoutSessionLineItems {
|
||||||
|
price: Some(price_id.to_string()),
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
|
params.success_url = Some(success_url);
|
||||||
|
|
||||||
|
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||||
|
Ok(session.url.context("no checkout session URL")?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct DefaultAggregation {
|
||||||
|
formula: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct StripeCreateMeterParams<'a> {
|
||||||
|
default_aggregation: DefaultAggregation,
|
||||||
|
display_name: String,
|
||||||
|
event_name: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize)]
|
||||||
|
struct StripeMeter {
|
||||||
|
id: String,
|
||||||
|
event_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StripeMeter {
|
||||||
|
pub fn create(
|
||||||
|
client: &stripe::Client,
|
||||||
|
params: StripeCreateMeterParams,
|
||||||
|
) -> stripe::Response<Self> {
|
||||||
|
client.post_form("/billing/meters", params)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct Params {}
|
||||||
|
|
||||||
|
client.get_query("/billing/meters", Params {})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct StripeMeterEvent {
|
||||||
|
identifier: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StripeMeterEvent {
|
||||||
|
pub async fn create(
|
||||||
|
client: &stripe::Client,
|
||||||
|
params: StripeCreateMeterEventParams<'_>,
|
||||||
|
) -> Result<Self, stripe::StripeError> {
|
||||||
|
let identifier = params.identifier;
|
||||||
|
match client.post_form("/billing/meter_events", params).await {
|
||||||
|
Ok(event) => Ok(event),
|
||||||
|
Err(stripe::StripeError::Stripe(error)) => {
|
||||||
|
if error.http_status == 400
|
||||||
|
&& error
|
||||||
|
.message
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |message| message.contains(identifier))
|
||||||
|
{
|
||||||
|
Ok(Self {
|
||||||
|
identifier: identifier.to_string(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(stripe::StripeError::Stripe(error))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(error) => Err(error),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct StripeCreateMeterEventParams<'a> {
|
||||||
|
identifier: &'a str,
|
||||||
|
event_name: &'a str,
|
||||||
|
payload: StripeCreateMeterEventPayload<'a>,
|
||||||
|
timestamp: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct StripeCreateMeterEventPayload<'a> {
|
||||||
|
value: u64,
|
||||||
|
stripe_customer_id: &'a stripe::CustomerId,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn subscription_contains_price(
|
||||||
|
subscription: &stripe::Subscription,
|
||||||
|
price_id: &stripe::PriceId,
|
||||||
|
) -> bool {
|
||||||
|
subscription.items.data.iter().any(|item| {
|
||||||
|
item.price
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |price| price.id == *price_id)
|
||||||
|
})
|
||||||
|
}
|
|
@ -635,6 +635,7 @@ impl TestServer {
|
||||||
) -> Arc<AppState> {
|
) -> Arc<AppState> {
|
||||||
Arc::new(AppState {
|
Arc::new(AppState {
|
||||||
db: test_db.db().clone(),
|
db: test_db.db().clone(),
|
||||||
|
llm_db: None,
|
||||||
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
|
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
|
||||||
blob_store_client: None,
|
blob_store_client: None,
|
||||||
stripe_client: None,
|
stripe_client: None,
|
||||||
|
@ -677,8 +678,6 @@ impl TestServer {
|
||||||
migrations_path: None,
|
migrations_path: None,
|
||||||
seed_path: None,
|
seed_path: None,
|
||||||
stripe_api_key: None,
|
stripe_api_key: None,
|
||||||
stripe_llm_access_price_id: None,
|
|
||||||
stripe_llm_usage_price_id: None,
|
|
||||||
supermaven_admin_api_key: None,
|
supermaven_admin_api_key: None,
|
||||||
user_backfiller_github_access_token: None,
|
user_backfiller_github_access_token: None,
|
||||||
},
|
},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue