From cae548a50dc190b7a1d7a16ee48d636842a21c35 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 9 Oct 2024 19:15:38 -0400 Subject: [PATCH] collab: Fix issues with syncing LLM usage to Stripe (#18970) This PR fixes some issues with our previous approach to synching LLM usage over to Stripe. We now have a separate LLM access price in Stripe that is a marker price to allow us to create the initial subscription with that as its subscription item We then dynamically set the LLM usage price during the reconciliation sync based on the usage for the current month. Release Notes: - N/A --------- Co-authored-by: Antonio Co-authored-by: Richard --- crates/collab/src/api/billing.rs | 51 ++++++++++++++++---------- crates/collab/src/lib.rs | 2 + crates/collab/src/main.rs | 3 +- crates/collab/src/tests/test_server.rs | 1 + 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index dca5a772f4..838dea1981 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -197,10 +197,10 @@ async fn create_billing_subscription( .await? .ok_or_else(|| anyhow!("user not found"))?; - let Some((stripe_client, stripe_price_id)) = app + let Some((stripe_client, stripe_access_price_id)) = app .stripe_client .clone() - .zip(app.config.stripe_llm_usage_price_id.clone()) + .zip(app.config.stripe_llm_access_price_id.clone()) else { log::error!("failed to retrieve Stripe client or price ID"); Err(Error::http( @@ -232,8 +232,8 @@ async fn create_billing_subscription( params.customer = Some(customer_id); params.client_reference_id = Some(user.github_login.as_str()); params.line_items = Some(vec![CreateCheckoutSessionLineItems { - price: Some(stripe_price_id.to_string()), - quantity: Some(0), + price: Some(stripe_access_price_id.to_string()), + quantity: Some(1), ..Default::default() }]); let success_url = format!("{}/account", app.config.zed_dot_dev_url()); @@ -787,22 +787,33 @@ async fn update_stripe_subscription( monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT); let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil(); - Subscription::update( - stripe_client, - &subscription_id, - stripe::UpdateSubscription { - items: Some(vec![stripe::UpdateSubscriptionItems { - // TODO: Do we need to send up the `id` if a subscription item - // with this price already exists, or will Stripe take care of - // it? - id: None, - price: Some(stripe_llm_usage_price_id.to_string()), - quantity: Some(new_quantity as u64), - ..Default::default() - }]), + let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?; + + let mut update_params = stripe::UpdateSubscription { + proration_behavior: Some( + stripe::generated::billing::subscription::SubscriptionProrationBehavior::None, + ), + ..Default::default() + }; + + if let Some(existing_item) = current_subscription.items.data.iter().find(|item| { + item.price.as_ref().map_or(false, |price| { + price.id == stripe_llm_usage_price_id.as_ref() + }) + }) { + update_params.items = Some(vec![stripe::UpdateSubscriptionItems { + id: Some(existing_item.id.to_string()), + quantity: Some(new_quantity as u64), ..Default::default() - }, - ) - .await?; + }]); + } else { + update_params.items = Some(vec![stripe::UpdateSubscriptionItems { + price: Some(stripe_llm_usage_price_id.to_string()), + quantity: Some(new_quantity as u64), + ..Default::default() + }]); + } + + Subscription::update(stripe_client, &subscription_id, update_params).await?; Ok(()) } diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index a6141abb88..3896926f43 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -176,6 +176,7 @@ pub struct Config { pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, pub stripe_api_key: Option, + pub stripe_llm_access_price_id: Option>, pub stripe_llm_usage_price_id: Option>, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, @@ -237,6 +238,7 @@ impl Config { migrations_path: None, seed_path: None, stripe_api_key: None, + stripe_llm_access_price_id: None, stripe_llm_usage_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index bbbd4e562c..bd227f17c7 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -155,7 +155,8 @@ async fn main() -> Result<()> { .await .trace_err(); - if let Some(llm_db) = llm_db { + if let Some(mut llm_db) = llm_db { + llm_db.initialize().await?; sync_llm_usage_with_stripe_periodically(state.clone(), llm_db); } diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 55bc279c8e..683a53a2f5 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -677,6 +677,7 @@ impl TestServer { migrations_path: None, seed_path: None, stripe_api_key: None, + stripe_llm_access_price_id: None, stripe_llm_usage_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None,