diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 8f1433a26f..3b0f5396a7 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -11,7 +11,9 @@ use crate::{ db::{User, UserId}, rpc, }; +use ::rpc::proto; use anyhow::Context as _; +use axum::extract; use axum::{ Extension, Json, Router, body::Body, @@ -23,6 +25,7 @@ use axum::{ routing::{get, post}, }; use axum_extra::response::ErasedJson; +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::sync::{Arc, OnceLock}; use tower::ServiceBuilder; @@ -101,6 +104,7 @@ pub fn routes(rpc_server: Arc) -> Router<(), Body> { .route("/users/look_up", get(look_up_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens)) + .route("/users/:id/update_plan", post(update_plan)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) .merge(billing::router()) .merge(contributors::router()) @@ -347,3 +351,78 @@ async fn refresh_llm_tokens( Ok(Json(RefreshLlmTokensResponse {})) } + +#[derive(Debug, Serialize, Deserialize)] +struct UpdatePlanBody { + pub plan: zed_llm_client::Plan, + pub subscription_period: SubscriptionPeriod, + pub usage: zed_llm_client::CurrentUsage, + pub trial_started_at: Option>, + pub is_usage_based_billing_enabled: bool, + pub is_account_too_young: bool, + pub has_overdue_invoices: bool, +} + +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +struct SubscriptionPeriod { + pub started_at: DateTime, + pub ended_at: DateTime, +} + +#[derive(Serialize)] +struct UpdatePlanResponse {} + +async fn update_plan( + Path(user_id): Path, + Extension(rpc_server): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let plan = match body.plan { + zed_llm_client::Plan::ZedFree => proto::Plan::Free, + zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro, + zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial, + }; + + let update_user_plan = proto::UpdateUserPlan { + plan: plan.into(), + trial_started_at: body + .trial_started_at + .map(|trial_started_at| trial_started_at.timestamp() as u64), + is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled), + usage: Some(proto::SubscriptionUsage { + model_requests_usage_amount: body.usage.model_requests.used, + model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)), + edit_predictions_usage_amount: body.usage.edit_predictions.used, + edit_predictions_usage_limit: Some(usage_limit_to_proto( + body.usage.edit_predictions.limit, + )), + }), + subscription_period: Some(proto::SubscriptionPeriod { + started_at: body.subscription_period.started_at.timestamp() as u64, + ended_at: body.subscription_period.ended_at.timestamp() as u64, + }), + account_too_young: Some(body.is_account_too_young), + has_overdue_invoices: Some(body.has_overdue_invoices), + }; + + rpc_server + .update_plan_for_user(user_id, update_user_plan) + .await?; + + Ok(Json(UpdatePlanResponse {})) +} + +fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit { + proto::UsageLimit { + variant: Some(match limit { + zed_llm_client::UsageLimit::Limited(limit) => { + proto::usage_limit::Variant::Limited(proto::usage_limit::Limited { + limit: limit as u32, + }) + } + zed_llm_client::UsageLimit::Unlimited => { + proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {}) + } + }), + } +} diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index d6e42ad2fb..bd7b99b3eb 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -785,7 +785,7 @@ async fn handle_customer_subscription_event( // When the user's subscription changes, push down any changes to their plan. rpc_server - .update_plan_for_user(billing_customer.user_id) + .update_plan_for_user_legacy(billing_customer.user_id) .await .trace_err(); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 924784109b..0735b08e89 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1002,7 +1002,26 @@ impl Server { Ok(()) } - pub async fn update_plan_for_user(self: &Arc, user_id: UserId) -> Result<()> { + pub async fn update_plan_for_user( + self: &Arc, + user_id: UserId, + update_user_plan: proto::UpdateUserPlan, + ) -> Result<()> { + let pool = self.connection_pool.lock(); + for connection_id in pool.user_connection_ids(user_id) { + self.peer + .send(connection_id, update_user_plan.clone()) + .trace_err(); + } + + Ok(()) + } + + /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan` + /// message on the Collab server. + /// + /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint. + pub async fn update_plan_for_user_legacy(self: &Arc, user_id: UserId) -> Result<()> { let user = self .app_state .db @@ -1018,14 +1037,7 @@ impl Server { ) .await?; - let pool = self.connection_pool.lock(); - for connection_id in pool.user_connection_ids(user_id) { - self.peer - .send(connection_id, update_user_plan.clone()) - .trace_err(); - } - - Ok(()) + self.update_plan_for_user(user_id, update_user_plan).await } pub async fn refresh_llm_tokens_for_user(self: &Arc, user_id: UserId) {