collab: Add POST /users/:id/update_plan
endpoint (#34953)
This PR adds a new `POST /users/:id/update_plan` endpoint to Collab to allow Cloud to push down plan updates to users. Release Notes: - N/A
This commit is contained in:
parent
326ab5fa3f
commit
14171e0721
3 changed files with 101 additions and 10 deletions
|
@ -11,7 +11,9 @@ use crate::{
|
||||||
db::{User, UserId},
|
db::{User, UserId},
|
||||||
rpc,
|
rpc,
|
||||||
};
|
};
|
||||||
|
use ::rpc::proto;
|
||||||
use anyhow::Context as _;
|
use anyhow::Context as _;
|
||||||
|
use axum::extract;
|
||||||
use axum::{
|
use axum::{
|
||||||
Extension, Json, Router,
|
Extension, Json, Router,
|
||||||
body::Body,
|
body::Body,
|
||||||
|
@ -23,6 +25,7 @@ use axum::{
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
};
|
};
|
||||||
use axum_extra::response::ErasedJson;
|
use axum_extra::response::ErasedJson;
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::{Arc, OnceLock};
|
use std::sync::{Arc, OnceLock};
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
|
@ -101,6 +104,7 @@ pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||||
.route("/users/look_up", get(look_up_user))
|
.route("/users/look_up", get(look_up_user))
|
||||||
.route("/users/:id/access_tokens", post(create_access_token))
|
.route("/users/:id/access_tokens", post(create_access_token))
|
||||||
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
|
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
|
||||||
|
.route("/users/:id/update_plan", post(update_plan))
|
||||||
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
|
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
|
||||||
.merge(billing::router())
|
.merge(billing::router())
|
||||||
.merge(contributors::router())
|
.merge(contributors::router())
|
||||||
|
@ -347,3 +351,78 @@ async fn refresh_llm_tokens(
|
||||||
|
|
||||||
Ok(Json(RefreshLlmTokensResponse {}))
|
Ok(Json(RefreshLlmTokensResponse {}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct UpdatePlanBody {
|
||||||
|
pub plan: zed_llm_client::Plan,
|
||||||
|
pub subscription_period: SubscriptionPeriod,
|
||||||
|
pub usage: zed_llm_client::CurrentUsage,
|
||||||
|
pub trial_started_at: Option<DateTime<Utc>>,
|
||||||
|
pub is_usage_based_billing_enabled: bool,
|
||||||
|
pub is_account_too_young: bool,
|
||||||
|
pub has_overdue_invoices: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||||
|
struct SubscriptionPeriod {
|
||||||
|
pub started_at: DateTime<Utc>,
|
||||||
|
pub ended_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct UpdatePlanResponse {}
|
||||||
|
|
||||||
|
async fn update_plan(
|
||||||
|
Path(user_id): Path<UserId>,
|
||||||
|
Extension(rpc_server): Extension<Arc<rpc::Server>>,
|
||||||
|
extract::Json(body): extract::Json<UpdatePlanBody>,
|
||||||
|
) -> Result<Json<UpdatePlanResponse>> {
|
||||||
|
let plan = match body.plan {
|
||||||
|
zed_llm_client::Plan::ZedFree => proto::Plan::Free,
|
||||||
|
zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
|
||||||
|
zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
|
||||||
|
};
|
||||||
|
|
||||||
|
let update_user_plan = proto::UpdateUserPlan {
|
||||||
|
plan: plan.into(),
|
||||||
|
trial_started_at: body
|
||||||
|
.trial_started_at
|
||||||
|
.map(|trial_started_at| trial_started_at.timestamp() as u64),
|
||||||
|
is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
|
||||||
|
usage: Some(proto::SubscriptionUsage {
|
||||||
|
model_requests_usage_amount: body.usage.model_requests.used,
|
||||||
|
model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
|
||||||
|
edit_predictions_usage_amount: body.usage.edit_predictions.used,
|
||||||
|
edit_predictions_usage_limit: Some(usage_limit_to_proto(
|
||||||
|
body.usage.edit_predictions.limit,
|
||||||
|
)),
|
||||||
|
}),
|
||||||
|
subscription_period: Some(proto::SubscriptionPeriod {
|
||||||
|
started_at: body.subscription_period.started_at.timestamp() as u64,
|
||||||
|
ended_at: body.subscription_period.ended_at.timestamp() as u64,
|
||||||
|
}),
|
||||||
|
account_too_young: Some(body.is_account_too_young),
|
||||||
|
has_overdue_invoices: Some(body.has_overdue_invoices),
|
||||||
|
};
|
||||||
|
|
||||||
|
rpc_server
|
||||||
|
.update_plan_for_user(user_id, update_user_plan)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Json(UpdatePlanResponse {}))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
|
||||||
|
proto::UsageLimit {
|
||||||
|
variant: Some(match limit {
|
||||||
|
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||||
|
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||||
|
limit: limit as u32,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
zed_llm_client::UsageLimit::Unlimited => {
|
||||||
|
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -785,7 +785,7 @@ async fn handle_customer_subscription_event(
|
||||||
|
|
||||||
// When the user's subscription changes, push down any changes to their plan.
|
// When the user's subscription changes, push down any changes to their plan.
|
||||||
rpc_server
|
rpc_server
|
||||||
.update_plan_for_user(billing_customer.user_id)
|
.update_plan_for_user_legacy(billing_customer.user_id)
|
||||||
.await
|
.await
|
||||||
.trace_err();
|
.trace_err();
|
||||||
|
|
||||||
|
|
|
@ -1002,7 +1002,26 @@ impl Server {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
|
pub async fn update_plan_for_user(
|
||||||
|
self: &Arc<Self>,
|
||||||
|
user_id: UserId,
|
||||||
|
update_user_plan: proto::UpdateUserPlan,
|
||||||
|
) -> Result<()> {
|
||||||
|
let pool = self.connection_pool.lock();
|
||||||
|
for connection_id in pool.user_connection_ids(user_id) {
|
||||||
|
self.peer
|
||||||
|
.send(connection_id, update_user_plan.clone())
|
||||||
|
.trace_err();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
|
||||||
|
/// message on the Collab server.
|
||||||
|
///
|
||||||
|
/// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
|
||||||
|
pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
|
||||||
let user = self
|
let user = self
|
||||||
.app_state
|
.app_state
|
||||||
.db
|
.db
|
||||||
|
@ -1018,14 +1037,7 @@ impl Server {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let pool = self.connection_pool.lock();
|
self.update_plan_for_user(user_id, update_user_plan).await
|
||||||
for connection_id in pool.user_connection_ids(user_id) {
|
|
||||||
self.peer
|
|
||||||
.send(connection_id, update_user_plan.clone())
|
|
||||||
.trace_err();
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue