From 469824c350ab5d574df42349dfa94c36e2b5513a Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 28 May 2025 18:19:43 -0400 Subject: [PATCH] collab: Use `StripeClient` for creating model usage meter events (#31633) This PR updates the `StripeBilling::bill_model_request_usage` method to use the `StripeClient` trait. Release Notes: - N/A --- crates/collab/src/api/billing.rs | 15 ++--- crates/collab/src/stripe_billing.rs | 64 +++---------------- crates/collab/src/stripe_client.rs | 20 +++++- .../src/stripe_client/fake_stripe_client.rs | 31 ++++++++- .../src/stripe_client/real_stripe_client.rs | 29 +++++++-- .../collab/src/tests/stripe_billing_tests.rs | 37 ++++++++++- 6 files changed, 119 insertions(+), 77 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 607576bb04..c438e67e53 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -29,6 +29,7 @@ use crate::db::billing_subscription::{ use crate::llm::db::subscription_usage_meter::CompletionMode; use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND}; use crate::rpc::{ResultExt as _, Server}; +use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId}; use crate::{AppState, Error, Result}; use crate::{db::UserId, llm::db::LlmDatabase}; use crate::{ @@ -1545,14 +1546,10 @@ async fn sync_model_request_usage_with_stripe( ); }; - let stripe_customer_id = billing_customer - .stripe_customer_id - .parse::() - .context("failed to parse Stripe customer ID from database")?; - let stripe_subscription_id = billing_subscription - .stripe_subscription_id - .parse::() - .context("failed to parse Stripe subscription ID from database")?; + let stripe_customer_id = + StripeCustomerId(billing_customer.stripe_customer_id.clone().into()); + let stripe_subscription_id = + StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into()); let model = llm_db.model_by_id(usage_meter.model_id)?; @@ -1578,7 +1575,7 @@ async fn sync_model_request_usage_with_stripe( }; stripe_billing - .subscribe_to_price(&stripe_subscription_id.into(), price) + .subscribe_to_price(&stripe_subscription_id, price) .await?; stripe_billing .bill_model_request_usage( diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index d3f062042b..ec5a534148 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use anyhow::{Context as _, anyhow}; use chrono::Utc; use collections::HashMap; -use serde::{Deserialize, Serialize}; use stripe::SubscriptionStatus; use tokio::sync::RwLock; use uuid::Uuid; @@ -12,8 +11,9 @@ use crate::Result; use crate::db::billing_subscription::SubscriptionKind; use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG; use crate::stripe_client::{ - RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId, - StripeSubscription, StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams, + RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload, + StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription, + StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams, UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior, UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod, }; @@ -204,16 +204,15 @@ impl StripeBilling { pub async fn bill_model_request_usage( &self, - customer_id: &stripe::CustomerId, + customer_id: &StripeCustomerId, event_name: &str, requests: i32, ) -> Result<()> { let timestamp = Utc::now().timestamp(); let idempotency_key = Uuid::new_v4(); - StripeMeterEvent::create( - &self.real_client, - StripeCreateMeterEventParams { + self.client + .create_meter_event(StripeCreateMeterEventParams { identifier: &format!("model_requests/{}", idempotency_key), event_name, payload: StripeCreateMeterEventPayload { @@ -221,9 +220,8 @@ impl StripeBilling { stripe_customer_id: customer_id, }, timestamp: Some(timestamp), - }, - ) - .await?; + }) + .await?; Ok(()) } @@ -371,52 +369,6 @@ impl StripeBilling { } } -#[derive(Deserialize)] -struct StripeMeterEvent { - identifier: String, -} - -impl StripeMeterEvent { - pub async fn create( - client: &stripe::Client, - params: StripeCreateMeterEventParams<'_>, - ) -> Result { - let identifier = params.identifier; - match client.post_form("/billing/meter_events", params).await { - Ok(event) => Ok(event), - Err(stripe::StripeError::Stripe(error)) => { - if error.http_status == 400 - && error - .message - .as_ref() - .map_or(false, |message| message.contains(identifier)) - { - Ok(Self { - identifier: identifier.to_string(), - }) - } else { - Err(stripe::StripeError::Stripe(error)) - } - } - Err(error) => Err(error), - } - } -} - -#[derive(Serialize)] -struct StripeCreateMeterEventParams<'a> { - identifier: &'a str, - event_name: &'a str, - payload: StripeCreateMeterEventPayload<'a>, - timestamp: Option, -} - -#[derive(Serialize)] -struct StripeCreateMeterEventPayload<'a> { - value: u64, - stripe_customer_id: &'a stripe::CustomerId, -} - fn subscription_contains_price( subscription: &StripeSubscription, price_id: &StripePriceId, diff --git a/crates/collab/src/stripe_client.rs b/crates/collab/src/stripe_client.rs index 8ecf0b2fe5..f15e373a9e 100644 --- a/crates/collab/src/stripe_client.rs +++ b/crates/collab/src/stripe_client.rs @@ -10,9 +10,9 @@ use async_trait::async_trait; #[cfg(test)] pub use fake_stripe_client::*; pub use real_stripe_client::*; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; -#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)] +#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)] pub struct StripeCustomerId(pub Arc); #[derive(Debug, Clone)] @@ -97,6 +97,20 @@ pub struct StripeMeter { pub event_name: String, } +#[derive(Debug, Serialize)] +pub struct StripeCreateMeterEventParams<'a> { + pub identifier: &'a str, + pub event_name: &'a str, + pub payload: StripeCreateMeterEventPayload<'a>, + pub timestamp: Option, +} + +#[derive(Debug, Serialize)] +pub struct StripeCreateMeterEventPayload<'a> { + pub value: u64, + pub stripe_customer_id: &'a StripeCustomerId, +} + #[async_trait] pub trait StripeClient: Send + Sync { async fn list_customers_by_email(&self, email: &str) -> Result>; @@ -117,4 +131,6 @@ pub trait StripeClient: Send + Sync { async fn list_prices(&self) -> Result>; async fn list_meters(&self) -> Result>; + + async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>; } diff --git a/crates/collab/src/stripe_client/fake_stripe_client.rs b/crates/collab/src/stripe_client/fake_stripe_client.rs index 3c3be84da1..ddcdaacc3d 100644 --- a/crates/collab/src/stripe_client/fake_stripe_client.rs +++ b/crates/collab/src/stripe_client/fake_stripe_client.rs @@ -7,11 +7,20 @@ use parking_lot::Mutex; use uuid::Uuid; use crate::stripe_client::{ - CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, - StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId, - UpdateSubscriptionParams, + CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer, + StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription, + StripeSubscriptionId, UpdateSubscriptionParams, }; +#[derive(Debug, Clone)] +pub struct StripeCreateMeterEventCall { + pub identifier: Arc, + pub event_name: Arc, + pub value: u64, + pub stripe_customer_id: StripeCustomerId, + pub timestamp: Option, +} + pub struct FakeStripeClient { pub customers: Arc>>, pub subscriptions: Arc>>, @@ -19,6 +28,7 @@ pub struct FakeStripeClient { Arc>>, pub prices: Arc>>, pub meters: Arc>>, + pub create_meter_event_calls: Arc>>, } impl FakeStripeClient { @@ -29,6 +39,7 @@ impl FakeStripeClient { update_subscription_calls: Arc::new(Mutex::new(Vec::new())), prices: Arc::new(Mutex::new(HashMap::default())), meters: Arc::new(Mutex::new(HashMap::default())), + create_meter_event_calls: Arc::new(Mutex::new(Vec::new())), } } } @@ -94,4 +105,18 @@ impl StripeClient for FakeStripeClient { Ok(meters) } + + async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { + self.create_meter_event_calls + .lock() + .push(StripeCreateMeterEventCall { + identifier: params.identifier.into(), + event_name: params.event_name.into(), + value: params.payload.value, + stripe_customer_id: params.payload.stripe_customer_id.clone(), + timestamp: params.timestamp, + }); + + Ok(()) + } } diff --git a/crates/collab/src/stripe_client/real_stripe_client.rs b/crates/collab/src/stripe_client/real_stripe_client.rs index 62f436d617..fa0b08790d 100644 --- a/crates/collab/src/stripe_client/real_stripe_client.rs +++ b/crates/collab/src/stripe_client/real_stripe_client.rs @@ -1,7 +1,7 @@ use std::str::FromStr as _; use std::sync::Arc; -use anyhow::{Context as _, Result}; +use anyhow::{Context as _, Result, anyhow}; use async_trait::async_trait; use serde::Serialize; use stripe::{ @@ -12,9 +12,10 @@ use stripe::{ }; use crate::stripe_client::{ - CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice, - StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId, - StripeSubscriptionItem, StripeSubscriptionItemId, UpdateSubscriptionParams, + CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer, + StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring, + StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, + UpdateSubscriptionParams, }; pub struct RealStripeClient { @@ -129,6 +130,26 @@ impl StripeClient for RealStripeClient { Ok(response.data) } + + async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> { + let identifier = params.identifier; + match self.client.post_form("/billing/meter_events", params).await { + Ok(event) => Ok(event), + Err(stripe::StripeError::Stripe(error)) => { + if error.http_status == 400 + && error + .message + .as_ref() + .map_or(false, |message| message.contains(identifier)) + { + Ok(()) + } else { + Err(anyhow!(stripe::StripeError::Stripe(error))) + } + } + Err(error) => Err(anyhow!(error)), + } + } } impl From for StripeCustomerId { diff --git a/crates/collab/src/tests/stripe_billing_tests.rs b/crates/collab/src/tests/stripe_billing_tests.rs index b12fa722f3..6a8bab90fe 100644 --- a/crates/collab/src/tests/stripe_billing_tests.rs +++ b/crates/collab/src/tests/stripe_billing_tests.rs @@ -4,9 +4,9 @@ use pretty_assertions::assert_eq; use crate::stripe_billing::StripeBilling; use crate::stripe_client::{ - FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring, - StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId, - UpdateSubscriptionItems, + FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, + StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, + StripeSubscriptionItemId, UpdateSubscriptionItems, }; fn make_stripe_billing() -> (StripeBilling, Arc) { @@ -210,3 +210,34 @@ async fn test_subscribe_to_price() { assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1); } } + +#[gpui::test] +async fn test_bill_model_request_usage() { + let (stripe_billing, stripe_client) = make_stripe_billing(); + + let customer_id = StripeCustomerId("cus_test".into()); + + stripe_billing + .bill_model_request_usage(&customer_id, "some_model/requests", 73) + .await + .unwrap(); + + let create_meter_event_calls = stripe_client + .create_meter_event_calls + .lock() + .iter() + .cloned() + .collect::>(); + assert_eq!(create_meter_event_calls.len(), 1); + assert!( + create_meter_event_calls[0] + .identifier + .starts_with("model_requests/") + ); + assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id); + assert_eq!( + create_meter_event_calls[0].event_name.as_ref(), + "some_model/requests" + ); + assert_eq!(create_meter_event_calls[0].value, 73); +}