diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index af3b70ed75..46eb26e99d 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -25,18 +25,12 @@ pub struct LlmTokenClaims { pub is_staff: bool, pub has_llm_closed_beta_feature_flag: bool, pub bypass_account_age_check: bool, - #[serde(default)] pub use_llm_request_queue: bool, pub plan: Plan, - #[serde(default)] pub has_extended_trial: bool, - #[serde(default)] - pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>, - #[serde(default)] + pub subscription_period: (NaiveDateTime, NaiveDateTime), pub enable_model_request_overages: bool, - #[serde(default)] pub model_request_overages_spend_limit_in_cents: u32, - #[serde(default)] pub can_use_web_search_tool: bool, } @@ -57,6 +51,23 @@ impl LlmTokenClaims { .as_ref() .ok_or_else(|| anyhow!("no LLM API secret"))?; + let plan = if is_staff { + Plan::ZedPro + } else { + subscription + .as_ref() + .and_then(|subscription| subscription.kind) + .map_or(Plan::Free, |kind| match kind { + SubscriptionKind::ZedFree => Plan::Free, + SubscriptionKind::ZedPro => Plan::ZedPro, + SubscriptionKind::ZedProTrial => Plan::ZedProTrial, + }) + }; + let subscription_period = + billing_subscription::Model::current_period(subscription, is_staff) + .map(|(start, end)| (start.naive_utc(), end.naive_utc())) + .ok_or_else(|| anyhow!("missing subscription period"))?; + let now = Utc::now(); let claims = Self { iat: now.timestamp() as u64, @@ -76,26 +87,11 @@ impl LlmTokenClaims { .any(|flag| flag == "bypass-account-age-check"), can_use_web_search_tool: true, use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"), - plan: if is_staff { - Plan::ZedPro - } else { - subscription - .as_ref() - .and_then(|subscription| subscription.kind) - .map_or(Plan::Free, |kind| match kind { - SubscriptionKind::ZedFree => Plan::Free, - SubscriptionKind::ZedPro => Plan::ZedPro, - SubscriptionKind::ZedProTrial => Plan::ZedProTrial, - }) - }, + plan, has_extended_trial: feature_flags .iter() .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG), - subscription_period: billing_subscription::Model::current_period( - subscription, - is_staff, - ) - .map(|(start, end)| (start.naive_utc(), end.naive_utc())), + subscription_period, enable_model_request_overages: billing_preferences .as_ref() .map_or(false, |preferences| {