collab: Add POST /users/:id/update_plan endpoint (#34953)

This PR adds a new `POST /users/:id/update_plan` endpoint to Collab to
allow Cloud to push down plan updates to users.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-07-23 10:30:08 -04:00 committed by GitHub
parent 326ab5fa3f
commit 14171e0721
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 10 deletions

View file

@ -11,7 +11,9 @@ use crate::{
db::{User, UserId}, db::{User, UserId},
rpc, rpc,
}; };
use ::rpc::proto;
use anyhow::Context as _; use anyhow::Context as _;
use axum::extract;
use axum::{ use axum::{
Extension, Json, Router, Extension, Json, Router,
body::Body, body::Body,
@ -23,6 +25,7 @@ use axum::{
routing::{get, post}, routing::{get, post},
}; };
use axum_extra::response::ErasedJson; use axum_extra::response::ErasedJson;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
use tower::ServiceBuilder; use tower::ServiceBuilder;
@ -101,6 +104,7 @@ pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
.route("/users/look_up", get(look_up_user)) .route("/users/look_up", get(look_up_user))
.route("/users/:id/access_tokens", post(create_access_token)) .route("/users/:id/access_tokens", post(create_access_token))
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
.route("/users/:id/update_plan", post(update_plan))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
.merge(billing::router()) .merge(billing::router())
.merge(contributors::router()) .merge(contributors::router())
@ -347,3 +351,78 @@ async fn refresh_llm_tokens(
Ok(Json(RefreshLlmTokensResponse {})) Ok(Json(RefreshLlmTokensResponse {}))
} }
#[derive(Debug, Serialize, Deserialize)]
struct UpdatePlanBody {
pub plan: zed_llm_client::Plan,
pub subscription_period: SubscriptionPeriod,
pub usage: zed_llm_client::CurrentUsage,
pub trial_started_at: Option<DateTime<Utc>>,
pub is_usage_based_billing_enabled: bool,
pub is_account_too_young: bool,
pub has_overdue_invoices: bool,
}
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
struct SubscriptionPeriod {
pub started_at: DateTime<Utc>,
pub ended_at: DateTime<Utc>,
}
#[derive(Serialize)]
struct UpdatePlanResponse {}
async fn update_plan(
Path(user_id): Path<UserId>,
Extension(rpc_server): Extension<Arc<rpc::Server>>,
extract::Json(body): extract::Json<UpdatePlanBody>,
) -> Result<Json<UpdatePlanResponse>> {
let plan = match body.plan {
zed_llm_client::Plan::ZedFree => proto::Plan::Free,
zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
};
let update_user_plan = proto::UpdateUserPlan {
plan: plan.into(),
trial_started_at: body
.trial_started_at
.map(|trial_started_at| trial_started_at.timestamp() as u64),
is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
usage: Some(proto::SubscriptionUsage {
model_requests_usage_amount: body.usage.model_requests.used,
model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
edit_predictions_usage_amount: body.usage.edit_predictions.used,
edit_predictions_usage_limit: Some(usage_limit_to_proto(
body.usage.edit_predictions.limit,
)),
}),
subscription_period: Some(proto::SubscriptionPeriod {
started_at: body.subscription_period.started_at.timestamp() as u64,
ended_at: body.subscription_period.ended_at.timestamp() as u64,
}),
account_too_young: Some(body.is_account_too_young),
has_overdue_invoices: Some(body.has_overdue_invoices),
};
rpc_server
.update_plan_for_user(user_id, update_user_plan)
.await?;
Ok(Json(UpdatePlanResponse {}))
}
fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
proto::UsageLimit {
variant: Some(match limit {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}
}

View file

@ -785,7 +785,7 @@ async fn handle_customer_subscription_event(
// When the user's subscription changes, push down any changes to their plan. // When the user's subscription changes, push down any changes to their plan.
rpc_server rpc_server
.update_plan_for_user(billing_customer.user_id) .update_plan_for_user_legacy(billing_customer.user_id)
.await .await
.trace_err(); .trace_err();

View file

@ -1002,7 +1002,26 @@ impl Server {
Ok(()) Ok(())
} }
pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> { pub async fn update_plan_for_user(
self: &Arc<Self>,
user_id: UserId,
update_user_plan: proto::UpdateUserPlan,
) -> Result<()> {
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
self.peer
.send(connection_id, update_user_plan.clone())
.trace_err();
}
Ok(())
}
/// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
/// message on the Collab server.
///
/// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
let user = self let user = self
.app_state .app_state
.db .db
@ -1018,14 +1037,7 @@ impl Server {
) )
.await?; .await?;
let pool = self.connection_pool.lock(); self.update_plan_for_user(user_id, update_user_plan).await
for connection_id in pool.user_connection_ids(user_id) {
self.peer
.send(connection_id, update_user_plan.clone())
.trace_err();
}
Ok(())
} }
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) { pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {