diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 624dfa2fc4..d04fcb9eee 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -362,12 +362,7 @@ async fn create_billing_subscription( let checkout_session_url = match body.product { Some(ProductCode::ZedPro) => { stripe_billing - .checkout_with_price( - app.config.zed_pro_price_id()?, - customer_id, - &user.github_login, - &success_url, - ) + .checkout_with_zed_pro(customer_id, &user.github_login, &success_url) .await? } Some(ProductCode::ZedProTrial) => { @@ -384,7 +379,6 @@ async fn create_billing_subscription( stripe_billing .checkout_with_zed_pro_trial( - app.config.zed_pro_price_id()?, customer_id, &user.github_login, feature_flags, @@ -458,6 +452,14 @@ async fn manage_billing_subscription( ))? }; + let Some(stripe_billing) = app.stripe_billing.clone() else { + log::error!("failed to retrieve Stripe billing object"); + Err(Error::http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + let customer = app .db .get_billing_customer_by_user_id(user.id) @@ -508,8 +510,8 @@ async fn manage_billing_subscription( let flow = match body.intent { ManageSubscriptionIntent::ManageSubscription => None, ManageSubscriptionIntent::UpgradeToPro => { - let zed_pro_price_id = app.config.zed_pro_price_id()?; - let zed_free_price_id = app.config.zed_free_price_id()?; + let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?; + let zed_free_price_id = stripe_billing.zed_free_price_id().await?; let stripe_subscription = Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?; @@ -856,9 +858,11 @@ async fn handle_customer_subscription_event( log::info!("handling Stripe {} event: {}", event.type_, event.id); - let subscription_kind = maybe!({ - let zed_pro_price_id = app.config.zed_pro_price_id().ok()?; - let zed_free_price_id = app.config.zed_free_price_id().ok()?; + let subscription_kind = maybe!(async { + let stripe_billing = app.stripe_billing.clone()?; + + let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?; + let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?; subscription.items.data.iter().find_map(|item| { let price = item.price.as_ref()?; @@ -875,7 +879,8 @@ async fn handle_customer_subscription_event( None } }) - }); + }) + .await; let billing_customer = find_or_create_billing_customer(app, stripe_client, subscription.customer) @@ -1398,13 +1403,13 @@ async fn sync_model_request_usage_with_stripe( .await?; let claude_3_5_sonnet = stripe_billing - .find_price_by_lookup_key("claude-3-5-sonnet-requests") + .find_price_id_by_lookup_key("claude-3-5-sonnet-requests") .await?; let claude_3_7_sonnet = stripe_billing - .find_price_by_lookup_key("claude-3-7-sonnet-requests") + .find_price_id_by_lookup_key("claude-3-7-sonnet-requests") .await?; let claude_3_7_sonnet_max = stripe_billing - .find_price_by_lookup_key("claude-3-7-sonnet-requests-max") + .find_price_id_by_lookup_key("claude-3-7-sonnet-requests-max") .await?; for (usage_meter, usage) in usage_meters { @@ -1430,11 +1435,11 @@ async fn sync_model_request_usage_with_stripe( let model = llm_db.model_by_id(usage_meter.model_id)?; let (price_id, meter_event_name) = match model.name.as_str() { - "claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"), + "claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"), "claude-3-7-sonnet" => match usage_meter.mode { - CompletionMode::Normal => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"), + CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"), CompletionMode::Max => { - (&claude_3_7_sonnet_max.id, "claude_3_7_sonnet/requests/max") + (&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max") } }, model_name => { diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 1d95cbaab1..0b16b5cf6b 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -180,9 +180,6 @@ pub struct Config { pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, pub stripe_api_key: Option, - pub stripe_zed_pro_price_id: Option, - pub stripe_zed_pro_trial_price_id: Option, - pub stripe_zed_free_price_id: Option, pub supermaven_admin_api_key: Option>, pub user_backfiller_github_access_token: Option>, } @@ -201,22 +198,6 @@ impl Config { } } - pub fn zed_pro_price_id(&self) -> anyhow::Result { - Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref()) - } - - pub fn zed_free_price_id(&self) -> anyhow::Result { - Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref()) - } - - fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result { - use std::str::FromStr as _; - - let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?; - - Ok(stripe::PriceId::from_str(price_id)?) - } - #[cfg(test)] pub fn test() -> Self { Self { @@ -254,9 +235,6 @@ impl Config { migrations_path: None, seed_path: None, stripe_api_key: None, - stripe_zed_pro_price_id: None, - stripe_zed_pro_trial_price_id: None, - stripe_zed_free_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None, diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 04876708e5..4de9c001c0 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -81,13 +81,21 @@ impl StripeBilling { Ok(()) } - pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result { + pub async fn zed_pro_price_id(&self) -> Result { + self.find_price_id_by_lookup_key("zed-pro").await + } + + pub async fn zed_free_price_id(&self) -> Result { + self.find_price_id_by_lookup_key("zed-free").await + } + + pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result { self.state .read() .await .prices_by_lookup_key .get(lookup_key) - .cloned() + .map(|price| price.id.clone()) .ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}"))) } @@ -463,19 +471,20 @@ impl StripeBilling { Ok(session.url.context("no checkout session URL")?) } - pub async fn checkout_with_price( + pub async fn checkout_with_zed_pro( &self, - price_id: PriceId, customer_id: stripe::CustomerId, github_login: &str, success_url: &str, ) -> Result { + let zed_pro_price_id = self.zed_pro_price_id().await?; + let mut params = stripe::CreateCheckoutSession::new(); params.mode = Some(stripe::CheckoutSessionMode::Subscription); params.customer = Some(customer_id); params.client_reference_id = Some(github_login); params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems { - price: Some(price_id.to_string()), + price: Some(zed_pro_price_id.to_string()), quantity: Some(1), ..Default::default() }]); @@ -487,12 +496,13 @@ impl StripeBilling { pub async fn checkout_with_zed_pro_trial( &self, - zed_pro_price_id: PriceId, customer_id: stripe::CustomerId, github_login: &str, feature_flags: Vec, success_url: &str, ) -> Result { + let zed_pro_price_id = self.zed_pro_price_id().await?; + let eligible_for_extended_trial = feature_flags .iter() .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG); diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index ca94312e0f..847ae8bbfa 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -554,9 +554,6 @@ impl TestServer { migrations_path: None, seed_path: None, stripe_api_key: None, - stripe_zed_pro_price_id: None, - stripe_zed_pro_trial_price_id: None, - stripe_zed_free_price_id: None, supermaven_admin_api_key: None, user_backfiller_github_access_token: None, kinesis_region: None,