diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index bbb8ac8428..36843ced56 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -54,6 +54,7 @@ pub fn router() -> Router { post(manage_billing_subscription), ) .route("/billing/monthly_spend", get(get_monthly_spend)) + .route("/billing/usage", get(get_current_usage)) } #[derive(Debug, Deserialize)] @@ -947,6 +948,93 @@ async fn get_monthly_spend( })) } +#[derive(Debug, Deserialize)] +struct GetCurrentUsageParams { + github_user_id: i32, +} + +#[derive(Debug, Serialize)] +struct UsageCounts { + pub used: i32, + pub limit: Option, + pub remaining: Option, +} + +#[derive(Debug, Serialize)] +struct GetCurrentUsageResponse { + pub model_requests: UsageCounts, + pub edit_predictions: UsageCounts, +} + +async fn get_current_usage( + Extension(app): Extension>, + Query(params): Query, +) -> Result> { + let user = app + .db + .get_user_by_github_user_id(params.github_user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let Some(llm_db) = app.llm_db.clone() else { + return Err(Error::http( + StatusCode::NOT_IMPLEMENTED, + "LLM database not available".into(), + )); + }; + + let empty_usage = GetCurrentUsageResponse { + model_requests: UsageCounts { + used: 0, + limit: Some(0), + remaining: Some(0), + }, + edit_predictions: UsageCounts { + used: 0, + limit: Some(0), + remaining: Some(0), + }, + }; + + let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else { + return Ok(Json(empty_usage)); + }; + + let subscription_period = maybe!({ + let period_start_at = subscription.current_period_start_at()?; + let period_end_at = subscription.current_period_end_at()?; + + Some((period_start_at, period_end_at)) + }); + + let Some((period_start_at, period_end_at)) = subscription_period else { + return Ok(Json(empty_usage)); + }; + + let usage = llm_db + .get_subscription_usage_for_period(user.id, period_start_at, period_end_at) + .await?; + let Some(usage) = usage else { + return Ok(Json(empty_usage)); + }; + + let model_requests_limit = Some(500); + let edit_prediction_limit = Some(2000); + + Ok(Json(GetCurrentUsageResponse { + model_requests: UsageCounts { + used: usage.model_requests, + limit: model_requests_limit, + remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)), + }, + edit_predictions: UsageCounts { + used: usage.edit_predictions, + limit: edit_prediction_limit, + remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)), + }, + })) +} + impl From for StripeSubscriptionStatus { fn from(value: SubscriptionStatus) -> Self { match value { diff --git a/crates/collab/src/db/tables/billing_subscription.rs b/crates/collab/src/db/tables/billing_subscription.rs index d834a2d3ac..4fb06a6985 100644 --- a/crates/collab/src/db/tables/billing_subscription.rs +++ b/crates/collab/src/db/tables/billing_subscription.rs @@ -19,6 +19,18 @@ pub struct Model { pub created_at: DateTime, } +impl Model { + pub fn current_period_start_at(&self) -> Option { + let period_start = self.stripe_current_period_start?; + chrono::DateTime::from_timestamp(period_start, 0) + } + + pub fn current_period_end_at(&self) -> Option { + let period_end = self.stripe_current_period_end?; + chrono::DateTime::from_timestamp(period_end, 0) + } +} + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { #[sea_orm( diff --git a/crates/collab/src/llm/db/queries.rs b/crates/collab/src/llm/db/queries.rs index 4a4a10fb51..6f9aab4a68 100644 --- a/crates/collab/src/llm/db/queries.rs +++ b/crates/collab/src/llm/db/queries.rs @@ -2,4 +2,5 @@ use super::*; pub mod billing_events; pub mod providers; +pub mod subscription_usages; pub mod usages; diff --git a/crates/collab/src/llm/db/queries/subscription_usages.rs b/crates/collab/src/llm/db/queries/subscription_usages.rs new file mode 100644 index 0000000000..2a04b4aba6 --- /dev/null +++ b/crates/collab/src/llm/db/queries/subscription_usages.rs @@ -0,0 +1,22 @@ +use crate::db::UserId; + +use super::*; + +impl LlmDatabase { + pub async fn get_subscription_usage_for_period( + &self, + user_id: UserId, + period_start_at: DateTimeUtc, + period_end_at: DateTimeUtc, + ) -> Result> { + self.transaction(|tx| async move { + Ok(subscription_usage::Entity::find() + .filter(subscription_usage::Column::UserId.eq(user_id)) + .filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at)) + .filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at)) + .one(&*tx) + .await?) + }) + .await + } +} diff --git a/crates/collab/src/llm/db/tables.rs b/crates/collab/src/llm/db/tables.rs index 5f2d357a87..0e99e01144 100644 --- a/crates/collab/src/llm/db/tables.rs +++ b/crates/collab/src/llm/db/tables.rs @@ -2,5 +2,6 @@ pub mod billing_event; pub mod model; pub mod monthly_usage; pub mod provider; +pub mod subscription_usage; pub mod usage; pub mod usage_measure; diff --git a/crates/collab/src/llm/db/tables/subscription_usage.rs b/crates/collab/src/llm/db/tables/subscription_usage.rs new file mode 100644 index 0000000000..33311d22f6 --- /dev/null +++ b/crates/collab/src/llm/db/tables/subscription_usage.rs @@ -0,0 +1,20 @@ +use crate::db::UserId; +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel)] +#[sea_orm(table_name = "subscription_usages")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub user_id: UserId, + pub period_start_at: PrimitiveDateTime, + pub period_end_at: PrimitiveDateTime, + pub model_requests: i32, + pub edit_predictions: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index e52e155d38..0979ff13b4 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -3,7 +3,7 @@ use crate::db::{billing_subscription, user}; use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT}; use crate::{Config, db::billing_preference}; use anyhow::{Result, anyhow}; -use chrono::{DateTime, NaiveDateTime, Utc}; +use chrono::{NaiveDateTime, Utc}; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation}; use serde::{Deserialize, Serialize}; use std::time::Duration; @@ -84,13 +84,10 @@ impl LlmTokenClaims { plan, subscription_period: maybe!({ let subscription = subscription?; - let period_start = subscription.stripe_current_period_start?; - let period_start = DateTime::from_timestamp(period_start, 0)?; + let period_start_at = subscription.current_period_start_at()?; + let period_end_at = subscription.current_period_end_at()?; - let period_end = subscription.stripe_current_period_end?; - let period_end = DateTime::from_timestamp(period_end, 0)?; - - Some((period_start.naive_utc(), period_end.naive_utc())) + Some((period_start_at.naive_utc(), period_end_at.naive_utc())) }), };