collab: Use StripeClient to retrieve prices and meters from Stripe (#31624)

This PR updates `StripeBilling` to use the `StripeClient` trait to
retrieve prices and meters from Stripe instead of using the
`stripe::Client` directly.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-05-28 15:51:06 -04:00 committed by GitHub
parent 05afe95539
commit 75e69a5ae9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 228 additions and 49 deletions

View file

@ -499,8 +499,10 @@ async fn manage_billing_subscription(
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
let zed_pro_price_id: stripe::PriceId =
stripe_billing.zed_pro_price_id().await?.try_into()?;
let zed_free_price_id: stripe::PriceId =
stripe_billing.zed_free_price_id().await?.try_into()?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;

View file

@ -4,14 +4,16 @@ use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::{PriceId, SubscriptionStatus};
use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
use crate::stripe_client::{
RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
};
pub struct StripeBilling {
state: RwLock<StripeBillingState>,
@ -22,8 +24,8 @@ pub struct StripeBilling {
#[derive(Default)]
struct StripeBillingState {
meters_by_event_name: HashMap<String, StripeMeter>,
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
prices_by_lookup_key: HashMap<String, stripe::Price>,
price_ids_by_meter_id: HashMap<String, StripePriceId>,
prices_by_lookup_key: HashMap<String, StripePrice>,
}
impl StripeBilling {
@ -50,24 +52,16 @@ impl StripeBilling {
let mut state = self.state.write().await;
let (meters, prices) = futures::try_join!(
StripeMeter::list(&self.real_client),
stripe::Price::list(
&self.real_client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
}
)
)?;
let (meters, prices) =
futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
for meter in meters.data {
for meter in meters {
state
.meters_by_event_name
.insert(meter.event_name.clone(), meter);
}
for price in prices.data {
for price in prices {
if let Some(lookup_key) = price.lookup_key.clone() {
state.prices_by_lookup_key.insert(lookup_key, price.clone());
}
@ -84,15 +78,15 @@ impl StripeBilling {
Ok(())
}
pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-pro").await
}
pub async fn zed_free_price_id(&self) -> Result<PriceId> {
pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
self.find_price_id_by_lookup_key("zed-free").await
}
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
self.state
.read()
.await
@ -102,7 +96,7 @@ impl StripeBilling {
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
self.state
.read()
.await
@ -116,8 +110,10 @@ impl StripeBilling {
&self,
subscription: &stripe::Subscription,
) -> Option<SubscriptionKind> {
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
let zed_pro_price_id: stripe::PriceId =
self.zed_pro_price_id().await.ok()?.try_into().ok()?;
let zed_free_price_id: stripe::PriceId =
self.zed_free_price_id().await.ok()?.try_into().ok()?;
subscription.items.data.iter().find_map(|item| {
let price = item.price.as_ref()?;
@ -171,12 +167,13 @@ impl StripeBilling {
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
price: &stripe::Price,
price: &StripePrice,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, &price.id) {
let price_id = price.id.clone().try_into()?;
if subscription_contains_price(&subscription, &price_id) {
return Ok(());
}
@ -375,24 +372,6 @@ impl StripeBilling {
}
}
#[derive(Clone, Deserialize)]
struct StripeMeter {
id: String,
event_name: String,
}
impl StripeMeter {
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
client.get_query("/billing/meters", Params { limit: Some(100) })
}
}
#[derive(Deserialize)]
struct StripeMeterEvent {
identifier: String,

View file

@ -10,8 +10,9 @@ use async_trait::async_trait;
#[cfg(test)]
pub use fake_stripe_client::*;
pub use real_stripe_client::*;
use serde::Deserialize;
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripeCustomerId(pub Arc<str>);
#[derive(Debug, Clone)]
@ -25,9 +26,38 @@ pub struct CreateCustomerParams<'a> {
pub email: Option<&'a str>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
pub struct StripePriceId(pub Arc<str>);
#[derive(Debug, Clone)]
pub struct StripePrice {
pub id: StripePriceId,
pub unit_amount: Option<i64>,
pub lookup_key: Option<String>,
pub recurring: Option<StripePriceRecurring>,
}
#[derive(Debug, Clone)]
pub struct StripePriceRecurring {
pub meter: Option<String>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
pub struct StripeMeterId(pub Arc<str>);
#[derive(Debug, Clone, Deserialize)]
pub struct StripeMeter {
pub id: StripeMeterId,
pub event_name: String,
}
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
}

View file

@ -6,16 +6,23 @@ use collections::HashMap;
use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
use crate::stripe_client::{
CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
StripeMeterId, StripePrice, StripePriceId,
};
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
}
impl FakeStripeClient {
pub fn new() -> Self {
Self {
customers: Arc::new(Mutex::new(HashMap::default())),
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
}
}
}
@ -44,4 +51,16 @@ impl StripeClient for FakeStripeClient {
Ok(customer)
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let prices = self.prices.lock().values().cloned().collect();
Ok(prices)
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
let meters = self.meters.lock().values().cloned().collect();
Ok(meters)
}
}

View file

@ -3,9 +3,13 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use async_trait::async_trait;
use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers};
use serde::Serialize;
use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring};
use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
use crate::stripe_client::{
CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
StripePriceId, StripePriceRecurring,
};
pub struct RealStripeClient {
client: Arc<stripe::Client>,
@ -48,6 +52,37 @@ impl StripeClient for RealStripeClient {
Ok(StripeCustomer::from(customer))
}
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
let response = stripe::Price::list(
&self.client,
&stripe::ListPrices {
limit: Some(100),
..Default::default()
},
)
.await?;
Ok(response.data.into_iter().map(StripePrice::from).collect())
}
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
#[derive(Serialize)]
struct Params {
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u64>,
}
let response = self
.client
.get_query::<stripe::List<StripeMeter>, _>(
"/billing/meters",
Params { limit: Some(100) },
)
.await?;
Ok(response.data)
}
}
impl From<CustomerId> for StripeCustomerId {
@ -72,3 +107,34 @@ impl From<Customer> for StripeCustomer {
}
}
}
impl From<PriceId> for StripePriceId {
fn from(value: PriceId) -> Self {
Self(value.as_str().into())
}
}
impl TryFrom<StripePriceId> for PriceId {
type Error = anyhow::Error;
fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
}
}
impl From<Price> for StripePrice {
fn from(value: Price) -> Self {
Self {
id: value.id.into(),
unit_amount: value.unit_amount,
lookup_key: value.lookup_key,
recurring: value.recurring.map(StripePriceRecurring::from),
}
}
}
impl From<Recurring> for StripePriceRecurring {
fn from(value: Recurring) -> Self {
Self { meter: value.meter }
}
}

View file

@ -3,7 +3,9 @@ use std::sync::Arc;
use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::FakeStripeClient;
use crate::stripe_client::{
FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
let stripe_client = Arc::new(FakeStripeClient::new());
@ -12,6 +14,87 @@ fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
(stripe_billing, stripe_client)
}
#[gpui::test]
async fn test_initialize() {
let (stripe_billing, stripe_client) = make_stripe_billing();
// Add test meters
let meter1 = StripeMeter {
id: StripeMeterId("meter_1".into()),
event_name: "event_1".to_string(),
};
let meter2 = StripeMeter {
id: StripeMeterId("meter_2".into()),
event_name: "event_2".to_string(),
};
stripe_client
.meters
.lock()
.insert(meter1.id.clone(), meter1);
stripe_client
.meters
.lock()
.insert(meter2.id.clone(), meter2);
// Add test prices
let price1 = StripePrice {
id: StripePriceId("price_1".into()),
unit_amount: Some(1_000),
lookup_key: Some("zed-pro".to_string()),
recurring: None,
};
let price2 = StripePrice {
id: StripePriceId("price_2".into()),
unit_amount: Some(0),
lookup_key: Some("zed-free".to_string()),
recurring: None,
};
let price3 = StripePrice {
id: StripePriceId("price_3".into()),
unit_amount: Some(500),
lookup_key: None,
recurring: Some(StripePriceRecurring {
meter: Some("meter_1".to_string()),
}),
};
stripe_client
.prices
.lock()
.insert(price1.id.clone(), price1);
stripe_client
.prices
.lock()
.insert(price2.id.clone(), price2);
stripe_client
.prices
.lock()
.insert(price3.id.clone(), price3);
// Initialize the billing system
stripe_billing.initialize().await.unwrap();
// Verify that prices can be found by lookup key
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
assert_eq!(zed_pro_price_id.to_string(), "price_1");
let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
assert_eq!(zed_free_price_id.to_string(), "price_2");
// Verify that a price can be found by lookup key
let zed_pro_price = stripe_billing
.find_price_by_lookup_key("zed-pro")
.await
.unwrap();
assert_eq!(zed_pro_price.id.to_string(), "price_1");
assert_eq!(zed_pro_price.unit_amount, Some(1_000));
// Verify that finding a non-existent lookup key returns an error
let result = stripe_billing
.find_price_by_lookup_key("non-existent")
.await;
assert!(result.is_err());
}
#[gpui::test]
async fn test_find_or_create_customer_by_email() {
let (stripe_billing, stripe_client) = make_stripe_billing();