diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index 6b891c1949..eeeb6fde4b 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -1,4 +1,5 @@ use crate::db::{BillingCustomerId, BillingSubscriptionId}; +use chrono::{Datelike as _, NaiveDate, Utc}; use sea_orm::entity::prelude::*; use serde::Serialize; @@ -29,6 +30,38 @@ impl Model { let period_end = self.stripe_current_period_end?; chrono::DateTime::from_timestamp(period_end, 0) } + + pub fn current_period( + subscription: Option, + is_staff: bool, + ) -> Option<(DateTimeUtc, DateTimeUtc)> { + if is_staff { + let now = Utc::now(); + let year = now.year(); + let month = now.month(); + + let first_day_of_this_month = + NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?; + + let next_month = if month == 12 { 1 } else { month + 1 }; + let next_month_year = if month == 12 { year + 1 } else { year }; + let first_day_of_next_month = + NaiveDate::from_ymd_opt(next_month_year, next_month, 1)?.and_hms_opt(23, 59, 59)?; + + let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1); + + Some(( + first_day_of_this_month.and_utc(), + last_day_of_this_month.and_utc(), + )) + } else { + let subscription = subscription?; + let period_start_at = subscription.current_period_start_at()?; + let period_end_at = subscription.current_period_end_at()?; + + Some((period_start_at, period_end_at)) + } + } } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index acfe3ca95f..52c2acc584 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -6,12 +6,11 @@ use crate::llm::{ }; use crate::{Config, db::billing_preference}; use anyhow::{Result, anyhow}; -use chrono::{Datelike, NaiveDate, NaiveDateTime, Utc}; +use chrono::{NaiveDateTime, Utc}; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::time::Duration; use thiserror::Error; -use util::maybe; use uuid::Uuid; use zed_llm_client::Plan; @@ -111,34 +110,11 @@ impl LlmTokenClaims { has_extended_trial: feature_flags .iter() .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG), - subscription_period: if is_staff { - maybe!({ - let now = Utc::now(); - let year = now.year(); - let month = now.month(); - - let first_day_of_this_month = - NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?; - - let next_month = if month == 12 { 1 } else { month + 1 }; - let next_month_year = if month == 12 { year + 1 } else { year }; - let first_day_of_next_month = - NaiveDate::from_ymd_opt(next_month_year, next_month, 1)? - .and_hms_opt(23, 59, 59)?; - - let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1); - - Some((first_day_of_this_month, last_day_of_this_month)) - }) - } else { - maybe!({ - let subscription = subscription?; - let period_start_at = subscription.current_period_start_at()?; - let period_end_at = subscription.current_period_end_at()?; - - Some((period_start_at.naive_utc(), period_end_at.naive_utc())) - }) - }, + subscription_period: billing_subscription::Model::current_period( + subscription, + is_staff, + ) + .map(|(start, end)| (start.naive_utc(), end.naive_utc())), enable_model_request_overages: billing_preferences .as_ref() .map_or(false, |preferences| { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 68d92f4be6..37cab39861 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -37,7 +37,6 @@ use core::fmt::{self, Debug, Formatter}; use reqwest_client::ReqwestClient; use rpc::proto::split_repository_update; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; -use util::maybe; use futures::{ FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture, @@ -2713,13 +2712,10 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> { let usage = if let Some(llm_db) = session.app_state.llm_db.clone() { let subscription = db.get_active_billing_subscription(user_id).await?; - let subscription_period = maybe!({ - let subscription = subscription?; - let period_start_at = subscription.current_period_start_at()?; - let period_end_at = subscription.current_period_end_at()?; - - Some((period_start_at, period_end_at)) - }); + let subscription_period = crate::db::billing_subscription::Model::current_period( + subscription, + session.is_staff(), + ); if let Some((period_start_at, period_end_at)) = subscription_period { llm_db