collab: Refresh the user's LLM token when their subscription changes (#19281)
This PR makes it so collab will trigger a refresh for a user's LLM token whenever their subscription changes. This allows us to proactively push down changes to their subscription. In order to facilitate this, the Stripe event processing has been moved from the `api` service to the `collab` service in order to access the RPC server. Release Notes: - N/A
This commit is contained in:
parent
9d944d0662
commit
598939d186
2 changed files with 16 additions and 5 deletions
|
@ -20,7 +20,7 @@ use stripe::{
|
|||
use util::ResultExt;
|
||||
|
||||
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
|
||||
use crate::rpc::ResultExt as _;
|
||||
use crate::rpc::{ResultExt as _, Server};
|
||||
use crate::{
|
||||
db::{
|
||||
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
|
||||
|
@ -404,7 +404,7 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
|
|||
|
||||
/// Polls the Stripe events API periodically to reconcile the records in our
|
||||
/// database with the data in Stripe.
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
|
||||
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
return;
|
||||
|
@ -415,7 +415,9 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
|
|||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
poll_stripe_events(&app, &stripe_client).await.log_err();
|
||||
poll_stripe_events(&app, &rpc_server, &stripe_client)
|
||||
.await
|
||||
.log_err();
|
||||
|
||||
executor.sleep(POLL_EVENTS_INTERVAL).await;
|
||||
}
|
||||
|
@ -425,6 +427,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
|
|||
|
||||
async fn poll_stripe_events(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &stripe::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
fn event_type_to_string(event_type: EventType) -> String {
|
||||
|
@ -541,7 +544,7 @@ async fn poll_stripe_events(
|
|||
| EventType::CustomerSubscriptionPaused
|
||||
| EventType::CustomerSubscriptionResumed
|
||||
| EventType::CustomerSubscriptionDeleted => {
|
||||
handle_customer_subscription_event(app, stripe_client, event).await
|
||||
handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
|
||||
}
|
||||
_ => Ok(()),
|
||||
};
|
||||
|
@ -609,6 +612,7 @@ async fn handle_customer_event(
|
|||
|
||||
async fn handle_customer_subscription_event(
|
||||
app: &Arc<AppState>,
|
||||
rpc_server: &Arc<Server>,
|
||||
stripe_client: &stripe::Client,
|
||||
event: stripe::Event,
|
||||
) -> anyhow::Result<()> {
|
||||
|
@ -654,6 +658,12 @@ async fn handle_customer_subscription_event(
|
|||
.await?;
|
||||
}
|
||||
|
||||
// When the user's subscription changes, we want to refresh their LLM tokens
|
||||
// to either grant/revoke access.
|
||||
rpc_server
|
||||
.refresh_llm_tokens_for_user(billing_customer.user_id)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -132,6 +132,8 @@ async fn main() -> Result<()> {
|
|||
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
|
||||
rpc_server.start().await?;
|
||||
|
||||
poll_stripe_events_periodically(state.clone(), rpc_server.clone());
|
||||
|
||||
app = app
|
||||
.merge(collab::api::routes(rpc_server.clone()))
|
||||
.merge(collab::rpc::routes(rpc_server.clone()));
|
||||
|
@ -140,7 +142,6 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
if mode.is_api() {
|
||||
poll_stripe_events_periodically(state.clone());
|
||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
spawn_user_backfiller(state.clone());
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue