From dbcd06642c77719d59c8084f7c38636f1576bf99 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 12 Aug 2024 17:15:26 -0700 Subject: [PATCH] Track lifetime spending for each user and model (#16137) Release Notes: - N/A Co-authored-by: Marshall --- .../20240812225346_create_lifetime_usages.sql | 9 +++ crates/collab/src/llm.rs | 14 ++-- crates/collab/src/llm/db/queries/usages.rs | 68 ++++++++++++++++++- crates/collab/src/llm/db/tables.rs | 1 + .../src/llm/db/tables/lifetime_usage.rs | 18 +++++ crates/collab/src/llm/db/tables/usage.rs | 7 +- crates/collab/src/llm/db/tests/usage_tests.rs | 14 +++- crates/collab/src/llm/telemetry.rs | 1 + 8 files changed, 121 insertions(+), 11 deletions(-) create mode 100644 crates/collab/migrations_llm/20240812225346_create_lifetime_usages.sql create mode 100644 crates/collab/src/llm/db/tables/lifetime_usage.rs diff --git a/crates/collab/migrations_llm/20240812225346_create_lifetime_usages.sql b/crates/collab/migrations_llm/20240812225346_create_lifetime_usages.sql new file mode 100644 index 0000000000..42047433e5 --- /dev/null +++ b/crates/collab/migrations_llm/20240812225346_create_lifetime_usages.sql @@ -0,0 +1,9 @@ +create table lifetime_usages ( + id serial primary key, + user_id integer not null, + model_id integer not null references models (id) on delete cascade, + input_tokens bigint not null default 0, + output_tokens bigint not null default 0 +); + +create unique index uix_lifetime_usages_on_user_id_model_id on lifetime_usages (user_id, model_id); diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index e7d489e837..42f4db7a38 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -4,8 +4,8 @@ mod telemetry; mod token; use crate::{ - api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error, - Result, + api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, + Config, Error, Result, }; use anyhow::{anyhow, Context as _}; use authorization::authorize_access_to_language_model; @@ -396,7 +396,12 @@ async fn check_usage_limit( let model = state.db.model(provider, model_name)?; let usage = state .db - .get_usage(claims.user_id as i32, provider, model_name, Utc::now()) + .get_usage( + UserId::from_proto(claims.user_id), + provider, + model_name, + Utc::now(), + ) .await?; let active_users = state.get_active_user_count().await?; @@ -523,7 +528,7 @@ impl Drop for TokenCountingStream { let usage = state .db .record_usage( - claims.user_id as i32, + UserId::from_proto(claims.user_id), claims.is_staff, provider, &model, @@ -555,6 +560,7 @@ impl Drop for TokenCountingStream { input_tokens_this_month: usage.input_tokens_this_month as u64, output_tokens_this_month: usage.output_tokens_this_month as u64, spending_this_month: usage.spending_this_month as u64, + lifetime_spending: usage.lifetime_spending as u64, }, ) .await diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index e11adaf2a7..adfd55088f 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -1,3 +1,4 @@ +use crate::db::UserId; use chrono::Duration; use rpc::LanguageModelProvider; use sea_orm::QuerySelect; @@ -14,6 +15,7 @@ pub struct Usage { pub input_tokens_this_month: usize, pub output_tokens_this_month: usize, pub spending_this_month: usize, + pub lifetime_spending: usize, } #[derive(Clone, Copy, Debug, Default)] @@ -63,7 +65,7 @@ impl LlmDatabase { pub async fn get_usage( &self, - user_id: i32, + user_id: UserId, provider: LanguageModelProvider, model_name: &str, now: DateTimeUtc, @@ -83,6 +85,18 @@ impl LlmDatabase { .all(&*tx) .await?; + let (lifetime_input_tokens, lifetime_output_tokens) = lifetime_usage::Entity::find() + .filter( + lifetime_usage::Column::UserId + .eq(user_id) + .and(lifetime_usage::Column::ModelId.eq(model.id)), + ) + .one(&*tx) + .await? + .map_or((0, 0), |usage| { + (usage.input_tokens as usize, usage.output_tokens as usize) + }); + let requests_this_minute = self.get_usage_for_measure(&usages, now, UsageMeasure::RequestsPerMinute)?; let tokens_this_minute = @@ -95,6 +109,8 @@ impl LlmDatabase { self.get_usage_for_measure(&usages, now, UsageMeasure::OutputTokensPerMonth)?; let spending_this_month = calculate_spending(model, input_tokens_this_month, output_tokens_this_month); + let lifetime_spending = + calculate_spending(model, lifetime_input_tokens, lifetime_output_tokens); Ok(Usage { requests_this_minute, @@ -103,6 +119,7 @@ impl LlmDatabase { input_tokens_this_month, output_tokens_this_month, spending_this_month, + lifetime_spending, }) }) .await @@ -111,7 +128,7 @@ impl LlmDatabase { #[allow(clippy::too_many_arguments)] pub async fn record_usage( &self, - user_id: i32, + user_id: UserId, is_staff: bool, provider: LanguageModelProvider, model_name: &str, @@ -194,6 +211,50 @@ impl LlmDatabase { let spending_this_month = calculate_spending(model, input_tokens_this_month, output_tokens_this_month); + // Update lifetime usage + let lifetime_usage = lifetime_usage::Entity::find() + .filter( + lifetime_usage::Column::UserId + .eq(user_id) + .and(lifetime_usage::Column::ModelId.eq(model.id)), + ) + .one(&*tx) + .await?; + + let lifetime_usage = match lifetime_usage { + Some(usage) => { + lifetime_usage::Entity::update(lifetime_usage::ActiveModel { + id: ActiveValue::unchanged(usage.id), + input_tokens: ActiveValue::set( + usage.input_tokens + input_token_count as i64, + ), + output_tokens: ActiveValue::set( + usage.output_tokens + output_token_count as i64, + ), + ..Default::default() + }) + .exec(&*tx) + .await? + } + None => { + lifetime_usage::ActiveModel { + user_id: ActiveValue::set(user_id), + model_id: ActiveValue::set(model.id), + input_tokens: ActiveValue::set(input_token_count as i64), + output_tokens: ActiveValue::set(output_token_count as i64), + ..Default::default() + } + .insert(&*tx) + .await? + } + }; + + let lifetime_spending = calculate_spending( + model, + lifetime_usage.input_tokens as usize, + lifetime_usage.output_tokens as usize, + ); + Ok(Usage { requests_this_minute, tokens_this_minute, @@ -201,6 +262,7 @@ impl LlmDatabase { input_tokens_this_month, output_tokens_this_month, spending_this_month, + lifetime_spending, }) }) .await @@ -246,7 +308,7 @@ impl LlmDatabase { #[allow(clippy::too_many_arguments)] async fn update_usage_for_measure( &self, - user_id: i32, + user_id: UserId, is_staff: bool, model_id: ModelId, usages: &[usage::Model], diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index 603e7f91a4..2333c20a2e 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -1,3 +1,4 @@ +pub mod lifetime_usage; pub mod model; pub mod provider; pub mod usage; diff --git a/crates/collab/src/llm/db/tables/lifetime_usage.rs b/crates/collab/src/llm/db/tables/lifetime_usage.rs new file mode 100644 index 0000000000..05ad2d5e94 --- /dev/null +++ b/crates/collab/src/llm/db/tables/lifetime_usage.rs @@ -0,0 +1,18 @@ +use crate::{db::UserId, llm::db::ModelId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "lifetime_usages")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub user_id: UserId, + pub model_id: ModelId, + pub input_tokens: i64, + pub output_tokens: i64, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/db/tables/usage.rs b/crates/collab/src/llm/db/tables/usage.rs index 76f7e1b01d..331c94a8a9 100644 --- a/crates/collab/src/llm/db/tables/usage.rs +++ b/crates/collab/src/llm/db/tables/usage.rs @@ -1,4 +1,7 @@ -use crate::llm::db::{ModelId, UsageId, UsageMeasureId}; +use crate::{ + db::UserId, + llm::db::{ModelId, UsageId, UsageMeasureId}, +}; use sea_orm::entity::prelude::*; /// An LLM usage record. @@ -10,7 +13,7 @@ pub struct Model { /// The ID of the Zed user. /// /// Corresponds to the `users` table in the primary collab database. - pub user_id: i32, + pub user_id: UserId, pub model_id: ModelId, pub measure_id: UsageMeasureId, pub timestamp: DateTime, diff --git a/crates/collab/src/llm/db/tests/usage_tests.rs b/crates/collab/src/llm/db/tests/usage_tests.rs index 336c3c9301..905a3dda08 100644 --- a/crates/collab/src/llm/db/tests/usage_tests.rs +++ b/crates/collab/src/llm/db/tests/usage_tests.rs @@ -1,5 +1,9 @@ use crate::{ - llm::db::{queries::providers::ModelParams, queries::usages::Usage, LlmDatabase}, + db::UserId, + llm::db::{ + queries::{providers::ModelParams, usages::Usage}, + LlmDatabase, + }, test_llm_db, }; use chrono::{Duration, Utc}; @@ -26,7 +30,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { .unwrap(); let t0 = Utc::now(); - let user_id = 123; + let user_id = UserId::from_proto(123); let now = t0; db.record_usage(user_id, false, provider, model, 1000, 0, now) @@ -48,6 +52,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { input_tokens_this_month: 3000, output_tokens_this_month: 0, spending_this_month: 0, + lifetime_spending: 0, } ); @@ -62,6 +67,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { input_tokens_this_month: 3000, output_tokens_this_month: 0, spending_this_month: 0, + lifetime_spending: 0, } ); @@ -80,6 +86,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { input_tokens_this_month: 6000, output_tokens_this_month: 0, spending_this_month: 0, + lifetime_spending: 0, } ); @@ -95,6 +102,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { input_tokens_this_month: 6000, output_tokens_this_month: 0, spending_this_month: 0, + lifetime_spending: 0, } ); @@ -112,6 +120,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { input_tokens_this_month: 10000, output_tokens_this_month: 0, spending_this_month: 0, + lifetime_spending: 0, } ); @@ -127,6 +136,7 @@ async fn test_tracking_usage(db: &mut LlmDatabase) { input_tokens_this_month: 9000, output_tokens_this_month: 0, spending_this_month: 0, + lifetime_spending: 0, } ); } diff --git a/crates/collab/src/llm/telemetry.rs b/crates/collab/src/llm/telemetry.rs index 1cfa18e69d..ac90bd265a 100644 --- a/crates/collab/src/llm/telemetry.rs +++ b/crates/collab/src/llm/telemetry.rs @@ -17,6 +17,7 @@ pub struct LlmUsageEventRow { pub input_tokens_this_month: u64, pub output_tokens_this_month: u64, pub spending_this_month: u64, + pub lifetime_spending: u64, } #[derive(Serialize, Debug, clickhouse::Row)]