diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 01ace385de..6a498dcf83 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -1137,6 +1137,12 @@ async fn handle_customer_subscription_event( .await?; } + // When the user's subscription changes, push down any changes to their plan. + rpc_server + .update_plan_for_user(billing_customer.user_id) + .await + .trace_err(); + // When the user's subscription changes, we want to refresh their LLM tokens // to either grant/revoke access. rpc_server diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 0b4b353f92..06b011f0f9 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,6 +2,7 @@ mod connection_pool; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::db::billing_subscription::SubscriptionKind; +use crate::llm::db::LlmDatabase; use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims}; use crate::{ AppState, Error, Result, auth, @@ -67,7 +68,7 @@ use std::{ time::{Duration, Instant}, }; use time::OffsetDateTime; -use tokio::sync::{MutexGuard, Semaphore, watch}; +use tokio::sync::{Semaphore, watch}; use tower::ServiceBuilder; use tracing::{ Instrument, @@ -166,29 +167,6 @@ impl Session { } } - pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result { - if self.is_staff() { - return Ok(proto::Plan::ZedPro); - } - - let user_id = self.user_id(); - - let subscription = db.get_active_billing_subscription(user_id).await?; - let subscription_kind = subscription.and_then(|subscription| subscription.kind); - - let plan = if let Some(subscription_kind) = subscription_kind { - match subscription_kind { - SubscriptionKind::ZedPro => proto::Plan::ZedPro, - SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial, - SubscriptionKind::ZedFree => proto::Plan::Free, - } - } else { - proto::Plan::Free - }; - - Ok(plan) - } - fn user_id(&self) -> UserId { match &self.principal { Principal::User(user) => user.id, @@ -953,6 +931,32 @@ impl Server { Ok(()) } + pub async fn update_plan_for_user(self: &Arc, user_id: UserId) -> Result<()> { + let user = self + .app_state + .db + .get_user_by_id(user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let update_user_plan = make_update_user_plan_message( + &self.app_state.db, + self.app_state.llm_db.clone(), + user_id, + user.admin, + ) + .await?; + + let pool = self.connection_pool.lock(); + for connection_id in pool.user_connection_ids(user_id) { + self.peer + .send(connection_id, update_user_plan.clone()) + .trace_err(); + } + + Ok(()) + } + pub async fn refresh_llm_tokens_for_user(self: &Arc, user_id: UserId) { let pool = self.connection_pool.lock(); for connection_id in pool.user_connection_ids(user_id) { @@ -2688,21 +2692,43 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool { version.0.minor() < 139 } -async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> { - let db = session.db().await; +async fn current_plan(db: &Arc, user_id: UserId, is_staff: bool) -> Result { + if is_staff { + return Ok(proto::Plan::ZedPro); + } + let subscription = db.get_active_billing_subscription(user_id).await?; + let subscription_kind = subscription.and_then(|subscription| subscription.kind); + + let plan = if let Some(subscription_kind) = subscription_kind { + match subscription_kind { + SubscriptionKind::ZedPro => proto::Plan::ZedPro, + SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial, + SubscriptionKind::ZedFree => proto::Plan::Free, + } + } else { + proto::Plan::Free + }; + + Ok(plan) +} + +async fn make_update_user_plan_message( + db: &Arc, + llm_db: Option>, + user_id: UserId, + is_staff: bool, +) -> Result { let feature_flags = db.get_user_flags(user_id).await?; - let plan = session.current_plan(&db).await?; + let plan = current_plan(db, user_id, is_staff).await?; let billing_customer = db.get_billing_customer_by_user_id(user_id).await?; let billing_preferences = db.get_billing_preferences(user_id).await?; - let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() { + let (subscription_period, usage) = if let Some(llm_db) = llm_db { let subscription = db.get_active_billing_subscription(user_id).await?; - let subscription_period = crate::db::billing_subscription::Model::current_period( - subscription, - session.is_staff(), - ); + let subscription_period = + crate::db::billing_subscription::Model::current_period(subscription, is_staff); let usage = if let Some((period_start_at, period_end_at)) = subscription_period { llm_db @@ -2717,92 +2743,92 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> { (None, None) }; - session - .peer - .send( - session.connection_id, - proto::UpdateUserPlan { - plan: plan.into(), - trial_started_at: billing_customer - .and_then(|billing_customer| billing_customer.trial_started_at) - .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64), - is_usage_based_billing_enabled: if session.is_staff() { - Some(true) - } else { - billing_preferences - .map(|preferences| preferences.model_request_overages_enabled) - }, - subscription_period: subscription_period.map(|(started_at, ended_at)| { - proto::SubscriptionPeriod { - started_at: started_at.timestamp() as u64, - ended_at: ended_at.timestamp() as u64, - } - }), - usage: usage.map(|usage| { - let plan = match plan { - proto::Plan::Free => zed_llm_client::Plan::ZedFree, - proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, - proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + Ok(proto::UpdateUserPlan { + plan: plan.into(), + trial_started_at: billing_customer + .and_then(|billing_customer| billing_customer.trial_started_at) + .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64), + is_usage_based_billing_enabled: if is_staff { + Some(true) + } else { + billing_preferences.map(|preferences| preferences.model_request_overages_enabled) + }, + subscription_period: subscription_period.map(|(started_at, ended_at)| { + proto::SubscriptionPeriod { + started_at: started_at.timestamp() as u64, + ended_at: ended_at.timestamp() as u64, + } + }), + usage: usage.map(|usage| { + let plan = match plan { + proto::Plan::Free => zed_llm_client::Plan::ZedFree, + proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, + proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, + }; + + let model_requests_limit = match plan.model_requests_limit() { + zed_llm_client::UsageLimit::Limited(limit) => { + let limit = if plan == zed_llm_client::Plan::ZedProTrial + && feature_flags + .iter() + .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) + { + 1_000 + } else { + limit }; - let model_requests_limit = match plan.model_requests_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - let limit = if plan == zed_llm_client::Plan::ZedProTrial - && feature_flags - .iter() - .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG) - { - 1_000 - } else { - limit - }; + zed_llm_client::UsageLimit::Limited(limit) + } + zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited, + }; - zed_llm_client::UsageLimit::Limited(limit) + proto::SubscriptionUsage { + model_requests_usage_amount: usage.model_requests as u32, + model_requests_usage_limit: Some(proto::UsageLimit { + variant: Some(match model_requests_limit { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit as u32, + }) } zed_llm_client::UsageLimit::Unlimited => { - zed_llm_client::UsageLimit::Unlimited + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) } - }; - - proto::SubscriptionUsage { - model_requests_usage_amount: usage.model_requests as u32, - model_requests_usage_limit: Some(proto::UsageLimit { - variant: Some(match model_requests_limit { - zed_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited( - proto::usage_limit::Limited { - limit: limit as u32, - }, - ) - } - zed_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited( - proto::usage_limit::Unlimited {}, - ) - } - }), - }), - edit_predictions_usage_amount: usage.edit_predictions as u32, - edit_predictions_usage_limit: Some(proto::UsageLimit { - variant: Some(match plan.edit_predictions_limit() { - zed_llm_client::UsageLimit::Limited(limit) => { - proto::usage_limit::Variant::Limited( - proto::usage_limit::Limited { - limit: limit as u32, - }, - ) - } - zed_llm_client::UsageLimit::Unlimited => { - proto::usage_limit::Variant::Unlimited( - proto::usage_limit::Unlimited {}, - ) - } - }), - }), - } + }), }), - }, - ) + edit_predictions_usage_amount: usage.edit_predictions as u32, + edit_predictions_usage_limit: Some(proto::UsageLimit { + variant: Some(match plan.edit_predictions_limit() { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit as u32, + }) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) + } + }), + }), + } + }), + }) +} + +async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> { + let db = session.db().await; + + let update_user_plan = make_update_user_plan_message( + &db.0, + session.app_state.llm_db.clone(), + user_id, + session.is_staff(), + ) + .await?; + + session + .peer + .send(session.connection_id, update_user_plan) .trace_err(); Ok(())