collab: Push down plan changes to the client (#30447)

This PR makes it so we push down plan updates from the server when the
user's subscription changes.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-09 20:08:48 -04:00 committed by GitHub
parent 79ba22673b
commit daa777440d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 143 additions and 111 deletions

View file

@ -1137,6 +1137,12 @@ async fn handle_customer_subscription_event(
.await?; .await?;
} }
// When the user's subscription changes, push down any changes to their plan.
rpc_server
.update_plan_for_user(billing_customer.user_id)
.await
.trace_err();
// When the user's subscription changes, we want to refresh their LLM tokens // When the user's subscription changes, we want to refresh their LLM tokens
// to either grant/revoke access. // to either grant/revoke access.
rpc_server rpc_server

View file

@ -2,6 +2,7 @@ mod connection_pool;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind; use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims}; use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
use crate::{ use crate::{
AppState, Error, Result, auth, AppState, Error, Result, auth,
@ -67,7 +68,7 @@ use std::{
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use time::OffsetDateTime; use time::OffsetDateTime;
use tokio::sync::{MutexGuard, Semaphore, watch}; use tokio::sync::{Semaphore, watch};
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tracing::{ use tracing::{
Instrument, Instrument,
@ -166,29 +167,6 @@ impl Session {
} }
} }
pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
if self.is_staff() {
return Ok(proto::Plan::ZedPro);
}
let user_id = self.user_id();
let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
let plan = if let Some(subscription_kind) = subscription_kind {
match subscription_kind {
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
SubscriptionKind::ZedFree => proto::Plan::Free,
}
} else {
proto::Plan::Free
};
Ok(plan)
}
fn user_id(&self) -> UserId { fn user_id(&self) -> UserId {
match &self.principal { match &self.principal {
Principal::User(user) => user.id, Principal::User(user) => user.id,
@ -953,6 +931,32 @@ impl Server {
Ok(()) Ok(())
} }
pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
let user = self
.app_state
.db
.get_user_by_id(user_id)
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let update_user_plan = make_update_user_plan_message(
&self.app_state.db,
self.app_state.llm_db.clone(),
user_id,
user.admin,
)
.await?;
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
self.peer
.send(connection_id, update_user_plan.clone())
.trace_err();
}
Ok(())
}
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) { pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
let pool = self.connection_pool.lock(); let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) { for connection_id in pool.user_connection_ids(user_id) {
@ -2688,21 +2692,43 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
version.0.minor() < 139 version.0.minor() < 139
} }
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> { async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
let db = session.db().await; if is_staff {
return Ok(proto::Plan::ZedPro);
}
let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
let plan = if let Some(subscription_kind) = subscription_kind {
match subscription_kind {
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
SubscriptionKind::ZedFree => proto::Plan::Free,
}
} else {
proto::Plan::Free
};
Ok(plan)
}
async fn make_update_user_plan_message(
db: &Arc<Database>,
llm_db: Option<Arc<LlmDatabase>>,
user_id: UserId,
is_staff: bool,
) -> Result<proto::UpdateUserPlan> {
let feature_flags = db.get_user_flags(user_id).await?; let feature_flags = db.get_user_flags(user_id).await?;
let plan = session.current_plan(&db).await?; let plan = current_plan(db, user_id, is_staff).await?;
let billing_customer = db.get_billing_customer_by_user_id(user_id).await?; let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
let billing_preferences = db.get_billing_preferences(user_id).await?; let billing_preferences = db.get_billing_preferences(user_id).await?;
let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() { let (subscription_period, usage) = if let Some(llm_db) = llm_db {
let subscription = db.get_active_billing_subscription(user_id).await?; let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_period = crate::db::billing_subscription::Model::current_period( let subscription_period =
subscription, crate::db::billing_subscription::Model::current_period(subscription, is_staff);
session.is_staff(),
);
let usage = if let Some((period_start_at, period_end_at)) = subscription_period { let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
llm_db llm_db
@ -2717,92 +2743,92 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
(None, None) (None, None)
}; };
session Ok(proto::UpdateUserPlan {
.peer plan: plan.into(),
.send( trial_started_at: billing_customer
session.connection_id, .and_then(|billing_customer| billing_customer.trial_started_at)
proto::UpdateUserPlan { .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
plan: plan.into(), is_usage_based_billing_enabled: if is_staff {
trial_started_at: billing_customer Some(true)
.and_then(|billing_customer| billing_customer.trial_started_at) } else {
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64), billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
is_usage_based_billing_enabled: if session.is_staff() { },
Some(true) subscription_period: subscription_period.map(|(started_at, ended_at)| {
} else { proto::SubscriptionPeriod {
billing_preferences started_at: started_at.timestamp() as u64,
.map(|preferences| preferences.model_request_overages_enabled) ended_at: ended_at.timestamp() as u64,
}, }
subscription_period: subscription_period.map(|(started_at, ended_at)| { }),
proto::SubscriptionPeriod { usage: usage.map(|usage| {
started_at: started_at.timestamp() as u64, let plan = match plan {
ended_at: ended_at.timestamp() as u64, proto::Plan::Free => zed_llm_client::Plan::ZedFree,
} proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
}), proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
usage: usage.map(|usage| { };
let plan = match plan {
proto::Plan::Free => zed_llm_client::Plan::ZedFree, let model_requests_limit = match plan.model_requests_limit() {
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro, zed_llm_client::UsageLimit::Limited(limit) => {
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial, let limit = if plan == zed_llm_client::Plan::ZedProTrial
&& feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
{
1_000
} else {
limit
}; };
let model_requests_limit = match plan.model_requests_limit() { zed_llm_client::UsageLimit::Limited(limit)
zed_llm_client::UsageLimit::Limited(limit) => { }
let limit = if plan == zed_llm_client::Plan::ZedProTrial zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
&& feature_flags };
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
{
1_000
} else {
limit
};
zed_llm_client::UsageLimit::Limited(limit) proto::SubscriptionUsage {
model_requests_usage_amount: usage.model_requests as u32,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
} }
zed_llm_client::UsageLimit::Unlimited => { zed_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
} }
}; }),
proto::SubscriptionUsage {
model_requests_usage_amount: usage.model_requests as u32,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(
proto::usage_limit::Limited {
limit: limit as u32,
},
)
}
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(
proto::usage_limit::Unlimited {},
)
}
}),
}),
edit_predictions_usage_amount: usage.edit_predictions as u32,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(
proto::usage_limit::Limited {
limit: limit as u32,
},
)
}
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(
proto::usage_limit::Unlimited {},
)
}
}),
}),
}
}), }),
}, edit_predictions_usage_amount: usage.edit_predictions as u32,
) edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
}
}),
})
}
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
let db = session.db().await;
let update_user_plan = make_update_user_plan_message(
&db.0,
session.app_state.llm_db.clone(),
user_id,
session.is_staff(),
)
.await?;
session
.peer
.send(session.connection_id, update_user_plan)
.trace_err(); .trace_err();
Ok(()) Ok(())