diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 5ffff0921f..c6bd87a8a5 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -11,7 +11,8 @@ CREATE TABLE "users" ( "metrics_id" TEXT, "github_user_id" INTEGER NOT NULL, "accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE, - "github_user_created_at" TIMESTAMP WITHOUT TIME ZONE + "github_user_created_at" TIMESTAMP WITHOUT TIME ZONE, + "custom_llm_monthly_allowance_in_cents" INTEGER ); CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login"); CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code"); diff --git a/crates/collab/migrations/20241021202606_add_custom_llm_monthly_allowance_in_cents_to_users.sql b/crates/collab/migrations/20241021202606_add_custom_llm_monthly_allowance_in_cents_to_users.sql new file mode 100644 index 0000000000..60a9bfa910 --- /dev/null +++ b/crates/collab/migrations/20241021202606_add_custom_llm_monthly_allowance_in_cents_to_users.sql @@ -0,0 +1 @@ +alter table users add column custom_llm_monthly_allowance_in_cents integer; diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 5e167a668c..1f5ec595d8 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -34,7 +34,7 @@ use crate::{ db::{billing_subscription::StripeSubscriptionStatus, UserId}, llm::db::LlmDatabase, }; -use crate::{AppState, Error, Result}; +use crate::{AppState, Cents, Error, Result}; pub fn router() -> Router { Router::new() @@ -700,10 +700,15 @@ async fn get_monthly_spend( )); }; + let free_tier = user + .custom_llm_monthly_allowance_in_cents + .map(|allowance| Cents(allowance as u32)) + .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT); + let monthly_spend = llm_db .get_user_spending_for_month(user.id, Utc::now()) .await? - .saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT); + .saturating_sub(free_tier); Ok(Json(GetMonthlySpendResponse { monthly_spend_in_cents: monthly_spend.0 as i32, diff --git a/crates/collab/src/db/tables/user.rs b/crates/collab/src/db/tables/user.rs index ff4331331b..3b66225af8 100644 --- a/crates/collab/src/db/tables/user.rs +++ b/crates/collab/src/db/tables/user.rs @@ -21,6 +21,7 @@ pub struct Model { pub metrics_id: Uuid, pub created_at: NaiveDateTime, pub accepted_tos_at: Option, + pub custom_llm_monthly_allowance_in_cents: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 9ee31ab3d1..cb3478879e 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -459,8 +459,9 @@ async fn check_usage_limit( Utc::now(), ) .await?; + let free_tier = claims.free_tier_monthly_spending_limit(); - if usage.spending_this_month >= FREE_TIER_MONTHLY_SPENDING_LIMIT { + if usage.spending_this_month >= free_tier { if !claims.has_llm_subscription { return Err(Error::http( StatusCode::PAYMENT_REQUIRED, @@ -468,9 +469,7 @@ async fn check_usage_limit( )); } - if (usage.spending_this_month - FREE_TIER_MONTHLY_SPENDING_LIMIT) - >= Cents(claims.max_monthly_spend_in_cents) - { + if (usage.spending_this_month - free_tier) >= Cents(claims.max_monthly_spend_in_cents) { return Err(Error::Http( StatusCode::FORBIDDEN, "Maximum spending limit reached for this month.".to_string(), @@ -640,6 +639,7 @@ impl Drop for TokenCountingStream { tokens, claims.has_llm_subscription, Cents(claims.max_monthly_spend_in_cents), + claims.free_tier_monthly_spending_limit(), Utc::now(), ) .await diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 5883bcef57..f262821743 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,5 +1,5 @@ +use crate::db::UserId; use crate::llm::Cents; -use crate::{db::UserId, llm::FREE_TIER_MONTHLY_SPENDING_LIMIT}; use chrono::{Datelike, Duration}; use futures::StreamExt as _; use rpc::LanguageModelProvider; @@ -299,6 +299,7 @@ impl LlmDatabase { tokens: TokenUsage, has_llm_subscription: bool, max_monthly_spend: Cents, + free_tier_monthly_spending_limit: Cents, now: DateTimeUtc, ) -> Result { self.transaction(|tx| async move { @@ -410,9 +411,9 @@ impl LlmDatabase { ); if !is_staff - && spending_this_month > FREE_TIER_MONTHLY_SPENDING_LIMIT + && spending_this_month > free_tier_monthly_spending_limit && has_llm_subscription - && (spending_this_month - FREE_TIER_MONTHLY_SPENDING_LIMIT) <= max_monthly_spend + && (spending_this_month - free_tier_monthly_spending_limit) <= max_monthly_spend { billing_event::ActiveModel { id: ActiveValue::not_set(), diff --git a/crates/collab/src/llm/db/tests/billing_tests.rs b/crates/collab/src/llm/db/tests/billing_tests.rs index 88551dd5f8..b76121887c 100644 --- a/crates/collab/src/llm/db/tests/billing_tests.rs +++ b/crates/collab/src/llm/db/tests/billing_tests.rs @@ -66,6 +66,7 @@ async fn test_billing_limit_exceeded(db: &mut LlmDatabase) { usage, true, max_monthly_spend, + FREE_TIER_MONTHLY_SPENDING_LIMIT, now, ) .await @@ -103,6 +104,7 @@ async fn test_billing_limit_exceeded(db: &mut LlmDatabase) { usage_2, true, max_monthly_spend, + FREE_TIER_MONTHLY_SPENDING_LIMIT, now, ) .await @@ -132,6 +134,7 @@ async fn test_billing_limit_exceeded(db: &mut LlmDatabase) { model, usage_exceeding, true, + FREE_TIER_MONTHLY_SPENDING_LIMIT, max_monthly_spend, now, ) diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs index 8e96ac4f54..3213c26e82 100644 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ b/crates/collab/src/llm/db/tests/usage_tests.rs @@ -1,3 +1,4 @@ +use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT; use crate::{ db::UserId, llm::db::{ @@ -49,6 +50,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { }, false, Cents::ZERO, + FREE_TIER_MONTHLY_SPENDING_LIMIT, now, ) .await @@ -68,6 +70,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { }, false, Cents::ZERO, + FREE_TIER_MONTHLY_SPENDING_LIMIT, now, ) .await @@ -124,6 +127,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { }, false, Cents::ZERO, + FREE_TIER_MONTHLY_SPENDING_LIMIT, now, ) .await @@ -180,6 +184,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { }, false, Cents::ZERO, + FREE_TIER_MONTHLY_SPENDING_LIMIT, now, ) .await @@ -222,6 +227,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { }, false, Cents::ZERO, + FREE_TIER_MONTHLY_SPENDING_LIMIT, now, ) .await @@ -259,6 +265,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { }, false, Cents::ZERO, + FREE_TIER_MONTHLY_SPENDING_LIMIT, now, ) .await diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index 28f52b5164..35f7cf26e7 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -1,8 +1,7 @@ -use crate::llm::DEFAULT_MAX_MONTHLY_SPEND; -use crate::{ - db::{billing_preference, UserId}, - Config, -}; +use crate::db::user; +use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT}; +use crate::Cents; +use crate::{db::billing_preference, Config}; use anyhow::{anyhow, Result}; use chrono::Utc; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; @@ -22,6 +21,7 @@ pub struct LlmTokenClaims { pub has_llm_closed_beta_feature_flag: bool, pub has_llm_subscription: bool, pub max_monthly_spend_in_cents: u32, + pub custom_llm_monthly_allowance_in_cents: Option, pub plan: rpc::proto::Plan, } @@ -30,8 +30,7 @@ const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); impl LlmTokenClaims { #[allow(clippy::too_many_arguments)] pub fn create( - user_id: UserId, - github_user_login: String, + user: &user::Model, is_staff: bool, billing_preferences: Option, has_llm_closed_beta_feature_flag: bool, @@ -49,8 +48,8 @@ impl LlmTokenClaims { iat: now.timestamp() as u64, exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64, jti: uuid::Uuid::new_v4().to_string(), - user_id: user_id.to_proto(), - github_user_login, + user_id: user.id.to_proto(), + github_user_login: user.github_login.clone(), is_staff, has_llm_closed_beta_feature_flag, has_llm_subscription, @@ -58,6 +57,9 @@ impl LlmTokenClaims { .map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| { preferences.max_monthly_llm_usage_spending_in_cents as u32 }), + custom_llm_monthly_allowance_in_cents: user + .custom_llm_monthly_allowance_in_cents + .map(|allowance| allowance as u32), plan, }; @@ -89,6 +91,12 @@ impl LlmTokenClaims { } } } + + pub fn free_tier_monthly_spending_limit(&self) -> Cents { + self.custom_llm_monthly_allowance_in_cents + .map(Cents) + .unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT) + } } #[derive(Error, Debug)] diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 865453555f..e805893567 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -4930,8 +4930,7 @@ async fn get_llm_api_token( let billing_preferences = db.get_billing_preferences(user.id).await?; let token = LlmTokenClaims::create( - user.id, - user.github_login.clone(), + &user, session.is_staff(), billing_preferences, has_llm_closed_beta_feature_flag,