collab: Include subscription_period in LLM token claims (#28819)

This PR updates the LLM token claims to include the user's active
subscription period.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-15 19:25:41 -04:00 committed by GitHub
parent 102ea6ac79
commit 68ec1d724c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 45 additions and 6 deletions

View file

@ -108,6 +108,28 @@ impl Database {
.await
}
pub async fn get_active_billing_subscription(
&self,
user_id: UserId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
.filter(
Condition::all()
.add(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.add(billing_subscription::Column::Kind.is_not_null()),
)
.one(&*tx)
.await?)
})
.await
}
/// Returns all of the billing subscriptions for the user with the specified ID.
///
/// Note that this returns the subscriptions regardless of their status.
@ -145,6 +167,7 @@ impl Database {
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.filter(billing_subscription::Column::Kind.is_null())
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;

View file

@ -1,13 +1,14 @@
use crate::Cents;
use crate::db::user;
use crate::db::{billing_subscription, user};
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::{Config, db::billing_preference};
use anyhow::{Result, anyhow};
use chrono::{NaiveDateTime, Utc};
use chrono::{DateTime, NaiveDateTime, Utc};
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use util::maybe;
use uuid::Uuid;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
@ -29,6 +30,8 @@ pub struct LlmTokenClaims {
pub max_monthly_spend_in_cents: u32,
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
pub plan: rpc::proto::Plan,
#[serde(default)]
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
@ -39,8 +42,9 @@ impl LlmTokenClaims {
is_staff: bool,
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
has_llm_subscription: bool,
has_legacy_llm_subscription: bool,
plan: rpc::proto::Plan,
subscription: Option<billing_subscription::Model>,
system_id: Option<String>,
config: &Config,
) -> Result<String> {
@ -69,7 +73,7 @@ impl LlmTokenClaims {
has_predict_edits_feature_flag: feature_flags
.iter()
.any(|flag| flag == "predict-edits"),
has_llm_subscription,
has_llm_subscription: has_legacy_llm_subscription,
max_monthly_spend_in_cents: billing_preferences
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
preferences.max_monthly_llm_usage_spending_in_cents as u32
@ -78,6 +82,16 @@ impl LlmTokenClaims {
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
plan,
subscription_period: maybe!({
let subscription = subscription?;
let period_start = subscription.stripe_current_period_start?;
let period_start = DateTime::from_timestamp(period_start, 0)?;
let period_end = subscription.stripe_current_period_end?;
let period_end = DateTime::from_timestamp(period_end, 0)?;
Some((period_start.naive_utc(), period_end.naive_utc()))
}),
};
Ok(jsonwebtoken::encode(

View file

@ -4135,7 +4135,8 @@ async fn get_llm_api_token(
Err(anyhow!("terms of service not accepted"))?
}
let has_llm_subscription = session.has_llm_subscription(&db).await?;
let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
let billing_subscription = db.get_active_billing_subscription(user.id).await?;
let billing_preferences = db.get_billing_preferences(user.id).await?;
let token = LlmTokenClaims::create(
@ -4143,8 +4144,9 @@ async fn get_llm_api_token(
session.is_staff(),
billing_preferences,
&flags,
has_llm_subscription,
has_legacy_llm_subscription,
session.current_plan(&db).await?,
billing_subscription,
session.system_id.clone(),
&session.app_state.config,
)?;