From 68ec1d724c11f76c0b1d1a508245115ad498c5c7 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 15 Apr 2025 19:25:41 -0400 Subject: [PATCH] collab: Include `subscription_period` in LLM token claims (#28819) This PR updates the LLM token claims to include the user's active subscription period. Release Notes: - N/A --- .../src/db/queries/billing_subscriptions.rs | 23 +++++++++++++++++++ crates/collab/src/llm/token.rs | 22 ++++++++++++++---- crates/collab/src/rpc.rs | 6 +++-- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 1eef99beb7..0d49aa9048 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -108,6 +108,28 @@ impl Database { .await } + pub async fn get_active_billing_subscription( + &self, + user_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + Ok(billing_subscription::Entity::find() + .inner_join(billing_customer::Entity) + .filter(billing_customer::Column::UserId.eq(user_id)) + .filter( + Condition::all() + .add( + billing_subscription::Column::StripeSubscriptionStatus + .eq(StripeSubscriptionStatus::Active), + ) + .add(billing_subscription::Column::Kind.is_not_null()), + ) + .one(&*tx) + .await?) + }) + .await + } + /// Returns all of the billing subscriptions for the user with the specified ID. /// /// Note that this returns the subscriptions regardless of their status. @@ -145,6 +167,7 @@ impl Database { billing_subscription::Column::StripeSubscriptionStatus .eq(StripeSubscriptionStatus::Active), ) + .filter(billing_subscription::Column::Kind.is_null()) .order_by_asc(billing_subscription::Column::Id) .stream(&*tx) .await?; diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index 68005a2d4c..e52e155d38 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -1,13 +1,14 @@ use crate::Cents; -use crate::db::user; +use crate::db::{billing_subscription, user}; use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT}; use crate::{Config, db::billing_preference}; use anyhow::{Result, anyhow}; -use chrono::{NaiveDateTime, Utc}; +use chrono::{DateTime, NaiveDateTime, Utc}; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::time::Duration; use thiserror::Error; +use util::maybe; use uuid::Uuid; #[derive(Clone, Debug, Default, Serialize, Deserialize)] @@ -29,6 +30,8 @@ pub struct LlmTokenClaims { pub max_monthly_spend_in_cents: u32, pub custom_llm_monthly_allowance_in_cents: Option, pub plan: rpc::proto::Plan, + #[serde(default)] + pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>, } const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60); @@ -39,8 +42,9 @@ impl LlmTokenClaims { is_staff: bool, billing_preferences: Option, feature_flags: &Vec, - has_llm_subscription: bool, + has_legacy_llm_subscription: bool, plan: rpc::proto::Plan, + subscription: Option, system_id: Option, config: &Config, ) -> Result { @@ -69,7 +73,7 @@ impl LlmTokenClaims { has_predict_edits_feature_flag: feature_flags .iter() .any(|flag| flag == "predict-edits"), - has_llm_subscription, + has_llm_subscription: has_legacy_llm_subscription, max_monthly_spend_in_cents: billing_preferences .map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| { preferences.max_monthly_llm_usage_spending_in_cents as u32 @@ -78,6 +82,16 @@ impl LlmTokenClaims { .custom_llm_monthly_allowance_in_cents .map(|allowance| allowance as u32), plan, + subscription_period: maybe!({ + let subscription = subscription?; + let period_start = subscription.stripe_current_period_start?; + let period_start = DateTime::from_timestamp(period_start, 0)?; + + let period_end = subscription.stripe_current_period_end?; + let period_end = DateTime::from_timestamp(period_end, 0)?; + + Some((period_start.naive_utc(), period_end.naive_utc())) + }), }; Ok(jsonwebtoken::encode( diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index bbc70c839e..87f69fbfa3 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -4135,7 +4135,8 @@ async fn get_llm_api_token( Err(anyhow!("terms of service not accepted"))? } - let has_llm_subscription = session.has_llm_subscription(&db).await?; + let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?; + let billing_subscription = db.get_active_billing_subscription(user.id).await?; let billing_preferences = db.get_billing_preferences(user.id).await?; let token = LlmTokenClaims::create( @@ -4143,8 +4144,9 @@ async fn get_llm_api_token( session.is_staff(), billing_preferences, &flags, - has_llm_subscription, + has_legacy_llm_subscription, session.current_plan(&db).await?, + billing_subscription, session.system_id.clone(), &session.app_state.config, )?;