diff --git a/Cargo.lock b/Cargo.lock index 2be2380956..65deb35f66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2553,6 +2553,7 @@ dependencies = [ "collections", "ctor", "dashmap 6.0.1", + "derive_more", "dev_server_projects", "editor", "env_logger", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 48482bd435..4513e02af3 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -32,6 +32,7 @@ clickhouse.workspace = true clock.workspace = true collections.workspace = true dashmap.workspace = true +derive_more.workspace = true envy = "0.4.2" futures.workspace = true google_ai.workspace = true diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index b70fc1e3ba..0a0bebbc56 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -29,7 +29,7 @@ use crate::db::{ UpdateBillingSubscriptionParams, }; use crate::llm::db::LlmDatabase; -use crate::llm::MONTHLY_SPENDING_LIMIT_IN_CENTS; +use crate::llm::MONTHLY_SPENDING_LIMIT; use crate::rpc::ResultExt as _; use crate::{AppState, Error, Result}; @@ -703,10 +703,9 @@ async fn update_stripe_subscription( let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id) .context("failed to parse subscription ID")?; - let monthly_spending_over_free_tier = - monthly_spending.saturating_sub(MONTHLY_SPENDING_LIMIT_IN_CENTS); + let monthly_spending_over_free_tier = monthly_spending.saturating_sub(MONTHLY_SPENDING_LIMIT); - let new_quantity = (monthly_spending_over_free_tier as f32 / 100.).ceil(); + let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil(); Subscription::update( stripe_client, &subscription_id, diff --git a/crates/collab/src/cents.rs b/crates/collab/src/cents.rs new file mode 100644 index 0000000000..917177bc51 --- /dev/null +++ b/crates/collab/src/cents.rs @@ -0,0 +1,78 @@ +/// A number of cents. +#[derive( + Debug, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Clone, + Copy, + derive_more::Add, + derive_more::AddAssign, +)] +pub struct Cents(pub u32); + +impl Cents { + pub const ZERO: Self = Self(0); + + pub const fn new(cents: u32) -> Self { + Self(cents) + } + + pub const fn from_dollars(dollars: u32) -> Self { + Self(dollars * 100) + } + + pub fn saturating_sub(self, other: Cents) -> Self { + Self(self.0.saturating_sub(other.0)) + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_cents_new() { + assert_eq!(Cents::new(50), Cents(50)); + } + + #[test] + fn test_cents_from_dollars() { + assert_eq!(Cents::from_dollars(1), Cents(100)); + assert_eq!(Cents::from_dollars(5), Cents(500)); + } + + #[test] + fn test_cents_zero() { + assert_eq!(Cents::ZERO, Cents(0)); + } + + #[test] + fn test_cents_add() { + assert_eq!(Cents(50) + Cents(30), Cents(80)); + } + + #[test] + fn test_cents_add_assign() { + let mut cents = Cents(50); + cents += Cents(30); + assert_eq!(cents, Cents(80)); + } + + #[test] + fn test_cents_saturating_sub() { + assert_eq!(Cents(50).saturating_sub(Cents(30)), Cents(20)); + assert_eq!(Cents(30).saturating_sub(Cents(50)), Cents(0)); + } + + #[test] + fn test_cents_ordering() { + assert!(Cents(50) > Cents(30)); + assert!(Cents(30) < Cents(50)); + assert_eq!(Cents(50), Cents(50)); + } +} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index ccecf80087..a6141abb88 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -1,5 +1,6 @@ pub mod api; pub mod auth; +mod cents; pub mod clickhouse; pub mod db; pub mod env; @@ -20,6 +21,7 @@ use axum::{ http::{HeaderMap, StatusCode}, response::IntoResponse, }; +pub use cents::*; use db::{ChannelId, Database}; use executor::Executor; pub use rate_limiter::*; diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 86563a766c..c475de3587 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -4,7 +4,7 @@ mod telemetry; mod token; use crate::{ - api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, + api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result, }; use anyhow::{anyhow, Context as _}; @@ -439,12 +439,10 @@ fn normalize_model_name(known_models: Vec, name: String) -> String { } /// The maximum monthly spending an individual user can reach before they have to pay. -pub const MONTHLY_SPENDING_LIMIT_IN_CENTS: usize = 5 * 100; +pub const MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(5); /// The maximum lifetime spending an individual user can reach before being cut off. -/// -/// Represented in cents. -const LIFETIME_SPENDING_LIMIT_IN_CENTS: usize = 1_000 * 100; +const LIFETIME_SPENDING_LIMIT: Cents = Cents::from_dollars(1_000); async fn check_usage_limit( state: &Arc, @@ -464,7 +462,7 @@ async fn check_usage_limit( .await?; if state.config.is_llm_billing_enabled() { - if usage.spending_this_month >= MONTHLY_SPENDING_LIMIT_IN_CENTS { + if usage.spending_this_month >= MONTHLY_SPENDING_LIMIT { if !claims.has_llm_subscription.unwrap_or(false) { return Err(Error::http( StatusCode::PAYMENT_REQUIRED, @@ -475,7 +473,7 @@ async fn check_usage_limit( } // TODO: Remove this once we've rolled out monthly spending limits. - if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT_IN_CENTS { + if usage.lifetime_spending >= LIFETIME_SPENDING_LIMIT { return Err(Error::http( StatusCode::FORBIDDEN, "Maximum spending limit reached.".to_string(), @@ -690,8 +688,8 @@ impl Drop for TokenCountingStream { .cache_read_input_tokens_this_month as u64, output_tokens_this_month: usage.output_tokens_this_month as u64, - spending_this_month: usage.spending_this_month as u64, - lifetime_spending: usage.lifetime_spending as u64, + spending_this_month: usage.spending_this_month.0 as u64, + lifetime_spending: usage.lifetime_spending.0 as u64, }, ) .await diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 1a98685bcd..3d6ab18415 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,4 +1,5 @@ use crate::db::UserId; +use crate::llm::Cents; use chrono::{Datelike, Duration}; use futures::StreamExt as _; use rpc::LanguageModelProvider; @@ -17,8 +18,8 @@ pub struct Usage { pub cache_creation_input_tokens_this_month: usize, pub cache_read_input_tokens_this_month: usize, pub output_tokens_this_month: usize, - pub spending_this_month: usize, - pub lifetime_spending: usize, + pub spending_this_month: Cents, + pub lifetime_spending: Cents, } #[derive(Debug, PartialEq, Clone)] @@ -144,7 +145,7 @@ impl LlmDatabase { &self, user_id: UserId, now: DateTimeUtc, - ) -> Result { + ) -> Result { self.transaction(|tx| async move { let month = now.date_naive().month() as i32; let year = now.date_naive().year(); @@ -158,7 +159,7 @@ impl LlmDatabase { ) .stream(&*tx) .await?; - let mut monthly_spending_in_cents = 0; + let mut monthly_spending = Cents::ZERO; while let Some(usage) = monthly_usages.next().await { let usage = usage?; @@ -166,7 +167,7 @@ impl LlmDatabase { continue; }; - monthly_spending_in_cents += calculate_spending( + monthly_spending += calculate_spending( model, usage.input_tokens as usize, usage.cache_creation_input_tokens as usize, @@ -175,7 +176,7 @@ impl LlmDatabase { ); } - Ok(monthly_spending_in_cents) + Ok(monthly_spending) }) .await } @@ -238,7 +239,7 @@ impl LlmDatabase { monthly_usage.output_tokens as usize, ) } else { - 0 + Cents::ZERO }; let lifetime_spending = if let Some(lifetime_usage) = &lifetime_usage { calculate_spending( @@ -249,7 +250,7 @@ impl LlmDatabase { lifetime_usage.output_tokens as usize, ) } else { - 0 + Cents::ZERO }; Ok(Usage { @@ -637,7 +638,7 @@ fn calculate_spending( cache_creation_input_tokens_this_month: usize, cache_read_input_tokens_this_month: usize, output_tokens_this_month: usize, -) -> usize { +) -> Cents { let input_token_cost = input_tokens_this_month * model.price_per_million_input_tokens as usize / 1_000_000; let cache_creation_input_token_cost = cache_creation_input_tokens_this_month @@ -648,10 +649,11 @@ fn calculate_spending( / 1_000_000; let output_token_cost = output_tokens_this_month * model.price_per_million_output_tokens as usize / 1_000_000; - input_token_cost + let spending = input_token_cost + cache_creation_input_token_cost + cache_read_input_token_cost - + output_token_cost + + output_token_cost; + Cents::new(spending as u32) } const MINUTE_BUCKET_COUNT: usize = 12; diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs index 8e8dc0ff6b..2730a03046 100644 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ b/crates/collab/src/llm/db/tests/usage_tests.rs @@ -4,7 +4,7 @@ use crate::{ queries::{providers::ModelParams, usages::Usage}, LlmDatabase, }, - test_llm_db, + test_llm_db, Cents, }; use chrono::{DateTime, Duration, Utc}; use pretty_assertions::assert_eq; @@ -56,8 +56,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { cache_creation_input_tokens_this_month: 0, cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, - spending_this_month: 0, - lifetime_spending: 0, + spending_this_month: Cents::ZERO, + lifetime_spending: Cents::ZERO, } ); @@ -73,8 +73,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { cache_creation_input_tokens_this_month: 0, cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, - spending_this_month: 0, - lifetime_spending: 0, + spending_this_month: Cents::ZERO, + lifetime_spending: Cents::ZERO, } ); @@ -94,8 +94,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { cache_creation_input_tokens_this_month: 0, cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, - spending_this_month: 0, - lifetime_spending: 0, + spending_this_month: Cents::ZERO, + lifetime_spending: Cents::ZERO, } ); @@ -112,8 +112,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { cache_creation_input_tokens_this_month: 0, cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, - spending_this_month: 0, - lifetime_spending: 0, + spending_this_month: Cents::ZERO, + lifetime_spending: Cents::ZERO, } ); @@ -132,8 +132,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { cache_creation_input_tokens_this_month: 0, cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, - spending_this_month: 0, - lifetime_spending: 0, + spending_this_month: Cents::ZERO, + lifetime_spending: Cents::ZERO, } ); @@ -158,8 +158,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { cache_creation_input_tokens_this_month: 500, cache_read_input_tokens_this_month: 0, output_tokens_this_month: 0, - spending_this_month: 0, - lifetime_spending: 0, + spending_this_month: Cents::ZERO, + lifetime_spending: Cents::ZERO, } ); @@ -179,8 +179,8 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { cache_creation_input_tokens_this_month: 500, cache_read_input_tokens_this_month: 300, output_tokens_this_month: 0, - spending_this_month: 0, - lifetime_spending: 0, + spending_this_month: Cents::ZERO, + lifetime_spending: Cents::ZERO, } ); }