collab: Refresh the user's LLM token when their subscription changes (#19281)

This PR makes it so collab will trigger a refresh for a user's LLM token
whenever their subscription changes.

This allows us to proactively push down changes to their subscription.

In order to facilitate this, the Stripe event processing has been moved
from the `api` service to the `collab` service in order to access the
RPC server.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-10-16 10:58:28 -04:00 committed by GitHub
parent 9d944d0662
commit 598939d186
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 5 deletions

View file

@ -20,7 +20,7 @@ use stripe::{
use util::ResultExt;
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
use crate::rpc::ResultExt as _;
use crate::rpc::{ResultExt as _, Server};
use crate::{
db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
@ -404,7 +404,7 @@ const NUMBER_OF_ALREADY_PROCESSED_PAGES_BEFORE_WE_STOP: usize = 4;
/// Polls the Stripe events API periodically to reconcile the records in our
/// database with the data in Stripe.
pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
pub fn poll_stripe_events_periodically(app: Arc<AppState>, rpc_server: Arc<Server>) {
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
@ -415,7 +415,9 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
let executor = executor.clone();
async move {
loop {
poll_stripe_events(&app, &stripe_client).await.log_err();
poll_stripe_events(&app, &rpc_server, &stripe_client)
.await
.log_err();
executor.sleep(POLL_EVENTS_INTERVAL).await;
}
@ -425,6 +427,7 @@ pub fn poll_stripe_events_periodically(app: Arc<AppState>) {
async fn poll_stripe_events(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
fn event_type_to_string(event_type: EventType) -> String {
@ -541,7 +544,7 @@ async fn poll_stripe_events(
| EventType::CustomerSubscriptionPaused
| EventType::CustomerSubscriptionResumed
| EventType::CustomerSubscriptionDeleted => {
handle_customer_subscription_event(app, stripe_client, event).await
handle_customer_subscription_event(app, rpc_server, stripe_client, event).await
}
_ => Ok(()),
};
@ -609,6 +612,7 @@ async fn handle_customer_event(
async fn handle_customer_subscription_event(
app: &Arc<AppState>,
rpc_server: &Arc<Server>,
stripe_client: &stripe::Client,
event: stripe::Event,
) -> anyhow::Result<()> {
@ -654,6 +658,12 @@ async fn handle_customer_subscription_event(
.await?;
}
// When the user's subscription changes, we want to refresh their LLM tokens
// to either grant/revoke access.
rpc_server
.refresh_llm_tokens_for_user(billing_customer.user_id)
.await;
Ok(())
}

View file

@ -132,6 +132,8 @@ async fn main() -> Result<()> {
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
poll_stripe_events_periodically(state.clone(), rpc_server.clone());
app = app
.merge(collab::api::routes(rpc_server.clone()))
.merge(collab::rpc::routes(rpc_server.clone()));
@ -140,7 +142,6 @@ async fn main() -> Result<()> {
}
if mode.is_api() {
poll_stripe_events_periodically(state.clone());
fetch_extensions_from_blob_store_periodically(state.clone());
spawn_user_backfiller(state.clone());