collab: Push down plan changes to the client (#30447)
This PR makes it so we push down plan updates from the server when the user's subscription changes. Release Notes: - N/A
This commit is contained in:
parent
79ba22673b
commit
daa777440d
2 changed files with 143 additions and 111 deletions
|
@ -1137,6 +1137,12 @@ async fn handle_customer_subscription_event(
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When the user's subscription changes, push down any changes to their plan.
|
||||||
|
rpc_server
|
||||||
|
.update_plan_for_user(billing_customer.user_id)
|
||||||
|
.await
|
||||||
|
.trace_err();
|
||||||
|
|
||||||
// When the user's subscription changes, we want to refresh their LLM tokens
|
// When the user's subscription changes, we want to refresh their LLM tokens
|
||||||
// to either grant/revoke access.
|
// to either grant/revoke access.
|
||||||
rpc_server
|
rpc_server
|
||||||
|
|
|
@ -2,6 +2,7 @@ mod connection_pool;
|
||||||
|
|
||||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||||
use crate::db::billing_subscription::SubscriptionKind;
|
use crate::db::billing_subscription::SubscriptionKind;
|
||||||
|
use crate::llm::db::LlmDatabase;
|
||||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
|
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
|
||||||
use crate::{
|
use crate::{
|
||||||
AppState, Error, Result, auth,
|
AppState, Error, Result, auth,
|
||||||
|
@ -67,7 +68,7 @@ use std::{
|
||||||
time::{Duration, Instant},
|
time::{Duration, Instant},
|
||||||
};
|
};
|
||||||
use time::OffsetDateTime;
|
use time::OffsetDateTime;
|
||||||
use tokio::sync::{MutexGuard, Semaphore, watch};
|
use tokio::sync::{Semaphore, watch};
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
use tracing::{
|
use tracing::{
|
||||||
Instrument,
|
Instrument,
|
||||||
|
@ -166,29 +167,6 @@ impl Session {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
|
|
||||||
if self.is_staff() {
|
|
||||||
return Ok(proto::Plan::ZedPro);
|
|
||||||
}
|
|
||||||
|
|
||||||
let user_id = self.user_id();
|
|
||||||
|
|
||||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
|
||||||
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
|
|
||||||
|
|
||||||
let plan = if let Some(subscription_kind) = subscription_kind {
|
|
||||||
match subscription_kind {
|
|
||||||
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
|
|
||||||
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
|
|
||||||
SubscriptionKind::ZedFree => proto::Plan::Free,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
proto::Plan::Free
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(plan)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn user_id(&self) -> UserId {
|
fn user_id(&self) -> UserId {
|
||||||
match &self.principal {
|
match &self.principal {
|
||||||
Principal::User(user) => user.id,
|
Principal::User(user) => user.id,
|
||||||
|
@ -953,6 +931,32 @@ impl Server {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
|
||||||
|
let user = self
|
||||||
|
.app_state
|
||||||
|
.db
|
||||||
|
.get_user_by_id(user_id)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| anyhow!("user not found"))?;
|
||||||
|
|
||||||
|
let update_user_plan = make_update_user_plan_message(
|
||||||
|
&self.app_state.db,
|
||||||
|
self.app_state.llm_db.clone(),
|
||||||
|
user_id,
|
||||||
|
user.admin,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let pool = self.connection_pool.lock();
|
||||||
|
for connection_id in pool.user_connection_ids(user_id) {
|
||||||
|
self.peer
|
||||||
|
.send(connection_id, update_user_plan.clone())
|
||||||
|
.trace_err();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
||||||
let pool = self.connection_pool.lock();
|
let pool = self.connection_pool.lock();
|
||||||
for connection_id in pool.user_connection_ids(user_id) {
|
for connection_id in pool.user_connection_ids(user_id) {
|
||||||
|
@ -2688,21 +2692,43 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
|
||||||
version.0.minor() < 139
|
version.0.minor() < 139
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
|
||||||
let db = session.db().await;
|
if is_staff {
|
||||||
|
return Ok(proto::Plan::ZedPro);
|
||||||
|
}
|
||||||
|
|
||||||
|
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||||
|
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
|
||||||
|
|
||||||
|
let plan = if let Some(subscription_kind) = subscription_kind {
|
||||||
|
match subscription_kind {
|
||||||
|
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
|
||||||
|
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
|
||||||
|
SubscriptionKind::ZedFree => proto::Plan::Free,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
proto::Plan::Free
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(plan)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn make_update_user_plan_message(
|
||||||
|
db: &Arc<Database>,
|
||||||
|
llm_db: Option<Arc<LlmDatabase>>,
|
||||||
|
user_id: UserId,
|
||||||
|
is_staff: bool,
|
||||||
|
) -> Result<proto::UpdateUserPlan> {
|
||||||
let feature_flags = db.get_user_flags(user_id).await?;
|
let feature_flags = db.get_user_flags(user_id).await?;
|
||||||
let plan = session.current_plan(&db).await?;
|
let plan = current_plan(db, user_id, is_staff).await?;
|
||||||
let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
|
let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
|
||||||
let billing_preferences = db.get_billing_preferences(user_id).await?;
|
let billing_preferences = db.get_billing_preferences(user_id).await?;
|
||||||
|
|
||||||
let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
|
let (subscription_period, usage) = if let Some(llm_db) = llm_db {
|
||||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||||
|
|
||||||
let subscription_period = crate::db::billing_subscription::Model::current_period(
|
let subscription_period =
|
||||||
subscription,
|
crate::db::billing_subscription::Model::current_period(subscription, is_staff);
|
||||||
session.is_staff(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
|
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
|
||||||
llm_db
|
llm_db
|
||||||
|
@ -2717,92 +2743,92 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||||
(None, None)
|
(None, None)
|
||||||
};
|
};
|
||||||
|
|
||||||
session
|
Ok(proto::UpdateUserPlan {
|
||||||
.peer
|
plan: plan.into(),
|
||||||
.send(
|
trial_started_at: billing_customer
|
||||||
session.connection_id,
|
.and_then(|billing_customer| billing_customer.trial_started_at)
|
||||||
proto::UpdateUserPlan {
|
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
|
||||||
plan: plan.into(),
|
is_usage_based_billing_enabled: if is_staff {
|
||||||
trial_started_at: billing_customer
|
Some(true)
|
||||||
.and_then(|billing_customer| billing_customer.trial_started_at)
|
} else {
|
||||||
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
|
billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
|
||||||
is_usage_based_billing_enabled: if session.is_staff() {
|
},
|
||||||
Some(true)
|
subscription_period: subscription_period.map(|(started_at, ended_at)| {
|
||||||
} else {
|
proto::SubscriptionPeriod {
|
||||||
billing_preferences
|
started_at: started_at.timestamp() as u64,
|
||||||
.map(|preferences| preferences.model_request_overages_enabled)
|
ended_at: ended_at.timestamp() as u64,
|
||||||
},
|
}
|
||||||
subscription_period: subscription_period.map(|(started_at, ended_at)| {
|
}),
|
||||||
proto::SubscriptionPeriod {
|
usage: usage.map(|usage| {
|
||||||
started_at: started_at.timestamp() as u64,
|
let plan = match plan {
|
||||||
ended_at: ended_at.timestamp() as u64,
|
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||||
}
|
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||||
}),
|
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||||
usage: usage.map(|usage| {
|
};
|
||||||
let plan = match plan {
|
|
||||||
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
|
let model_requests_limit = match plan.model_requests_limit() {
|
||||||
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||||
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
let limit = if plan == zed_llm_client::Plan::ZedProTrial
|
||||||
|
&& feature_flags
|
||||||
|
.iter()
|
||||||
|
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
|
||||||
|
{
|
||||||
|
1_000
|
||||||
|
} else {
|
||||||
|
limit
|
||||||
};
|
};
|
||||||
|
|
||||||
let model_requests_limit = match plan.model_requests_limit() {
|
zed_llm_client::UsageLimit::Limited(limit)
|
||||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
}
|
||||||
let limit = if plan == zed_llm_client::Plan::ZedProTrial
|
zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
|
||||||
&& feature_flags
|
};
|
||||||
.iter()
|
|
||||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
|
|
||||||
{
|
|
||||||
1_000
|
|
||||||
} else {
|
|
||||||
limit
|
|
||||||
};
|
|
||||||
|
|
||||||
zed_llm_client::UsageLimit::Limited(limit)
|
proto::SubscriptionUsage {
|
||||||
|
model_requests_usage_amount: usage.model_requests as u32,
|
||||||
|
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||||
|
variant: Some(match model_requests_limit {
|
||||||
|
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||||
|
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||||
|
limit: limit as u32,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
zed_llm_client::UsageLimit::Unlimited => {
|
zed_llm_client::UsageLimit::Unlimited => {
|
||||||
zed_llm_client::UsageLimit::Unlimited
|
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||||
}
|
}
|
||||||
};
|
}),
|
||||||
|
|
||||||
proto::SubscriptionUsage {
|
|
||||||
model_requests_usage_amount: usage.model_requests as u32,
|
|
||||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
|
||||||
variant: Some(match model_requests_limit {
|
|
||||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
|
||||||
proto::usage_limit::Variant::Limited(
|
|
||||||
proto::usage_limit::Limited {
|
|
||||||
limit: limit as u32,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
zed_llm_client::UsageLimit::Unlimited => {
|
|
||||||
proto::usage_limit::Variant::Unlimited(
|
|
||||||
proto::usage_limit::Unlimited {},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
edit_predictions_usage_amount: usage.edit_predictions as u32,
|
|
||||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
|
||||||
variant: Some(match plan.edit_predictions_limit() {
|
|
||||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
|
||||||
proto::usage_limit::Variant::Limited(
|
|
||||||
proto::usage_limit::Limited {
|
|
||||||
limit: limit as u32,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
zed_llm_client::UsageLimit::Unlimited => {
|
|
||||||
proto::usage_limit::Variant::Unlimited(
|
|
||||||
proto::usage_limit::Unlimited {},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
},
|
edit_predictions_usage_amount: usage.edit_predictions as u32,
|
||||||
)
|
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||||
|
variant: Some(match plan.edit_predictions_limit() {
|
||||||
|
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||||
|
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||||
|
limit: limit as u32,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
zed_llm_client::UsageLimit::Unlimited => {
|
||||||
|
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||||
|
let db = session.db().await;
|
||||||
|
|
||||||
|
let update_user_plan = make_update_user_plan_message(
|
||||||
|
&db.0,
|
||||||
|
session.app_state.llm_db.clone(),
|
||||||
|
user_id,
|
||||||
|
session.is_staff(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
session
|
||||||
|
.peer
|
||||||
|
.send(session.connection_id, update_user_plan)
|
||||||
.trace_err();
|
.trace_err();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue