collab: Use StripeClient
to retrieve prices and meters from Stripe (#31624)
This PR updates `StripeBilling` to use the `StripeClient` trait to retrieve prices and meters from Stripe instead of using the `stripe::Client` directly. Release Notes: - N/A
This commit is contained in:
parent
05afe95539
commit
75e69a5ae9
6 changed files with 228 additions and 49 deletions
|
@ -499,8 +499,10 @@ async fn manage_billing_subscription(
|
|||
let flow = match body.intent {
|
||||
ManageSubscriptionIntent::ManageSubscription => None,
|
||||
ManageSubscriptionIntent::UpgradeToPro => {
|
||||
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
|
||||
let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
|
||||
let zed_pro_price_id: stripe::PriceId =
|
||||
stripe_billing.zed_pro_price_id().await?.try_into()?;
|
||||
let zed_free_price_id: stripe::PriceId =
|
||||
stripe_billing.zed_free_price_id().await?.try_into()?;
|
||||
|
||||
let stripe_subscription =
|
||||
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
|
||||
|
|
|
@ -4,14 +4,16 @@ use anyhow::{Context as _, anyhow};
|
|||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use stripe::{PriceId, SubscriptionStatus};
|
||||
use stripe::SubscriptionStatus;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::Result;
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||
use crate::stripe_client::{RealStripeClient, StripeClient, StripeCustomerId};
|
||||
use crate::stripe_client::{
|
||||
RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
|
||||
};
|
||||
|
||||
pub struct StripeBilling {
|
||||
state: RwLock<StripeBillingState>,
|
||||
|
@ -22,8 +24,8 @@ pub struct StripeBilling {
|
|||
#[derive(Default)]
|
||||
struct StripeBillingState {
|
||||
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
||||
prices_by_lookup_key: HashMap<String, stripe::Price>,
|
||||
price_ids_by_meter_id: HashMap<String, StripePriceId>,
|
||||
prices_by_lookup_key: HashMap<String, StripePrice>,
|
||||
}
|
||||
|
||||
impl StripeBilling {
|
||||
|
@ -50,24 +52,16 @@ impl StripeBilling {
|
|||
|
||||
let mut state = self.state.write().await;
|
||||
|
||||
let (meters, prices) = futures::try_join!(
|
||||
StripeMeter::list(&self.real_client),
|
||||
stripe::Price::list(
|
||||
&self.real_client,
|
||||
&stripe::ListPrices {
|
||||
limit: Some(100),
|
||||
..Default::default()
|
||||
}
|
||||
)
|
||||
)?;
|
||||
let (meters, prices) =
|
||||
futures::try_join!(self.client.list_meters(), self.client.list_prices())?;
|
||||
|
||||
for meter in meters.data {
|
||||
for meter in meters {
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter.event_name.clone(), meter);
|
||||
}
|
||||
|
||||
for price in prices.data {
|
||||
for price in prices {
|
||||
if let Some(lookup_key) = price.lookup_key.clone() {
|
||||
state.prices_by_lookup_key.insert(lookup_key, price.clone());
|
||||
}
|
||||
|
@ -84,15 +78,15 @@ impl StripeBilling {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
|
||||
pub async fn zed_pro_price_id(&self) -> Result<StripePriceId> {
|
||||
self.find_price_id_by_lookup_key("zed-pro").await
|
||||
}
|
||||
|
||||
pub async fn zed_free_price_id(&self) -> Result<PriceId> {
|
||||
pub async fn zed_free_price_id(&self) -> Result<StripePriceId> {
|
||||
self.find_price_id_by_lookup_key("zed-free").await
|
||||
}
|
||||
|
||||
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
|
||||
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<StripePriceId> {
|
||||
self.state
|
||||
.read()
|
||||
.await
|
||||
|
@ -102,7 +96,7 @@ impl StripeBilling {
|
|||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
|
||||
}
|
||||
|
||||
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
|
||||
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<StripePrice> {
|
||||
self.state
|
||||
.read()
|
||||
.await
|
||||
|
@ -116,8 +110,10 @@ impl StripeBilling {
|
|||
&self,
|
||||
subscription: &stripe::Subscription,
|
||||
) -> Option<SubscriptionKind> {
|
||||
let zed_pro_price_id = self.zed_pro_price_id().await.ok()?;
|
||||
let zed_free_price_id = self.zed_free_price_id().await.ok()?;
|
||||
let zed_pro_price_id: stripe::PriceId =
|
||||
self.zed_pro_price_id().await.ok()?.try_into().ok()?;
|
||||
let zed_free_price_id: stripe::PriceId =
|
||||
self.zed_free_price_id().await.ok()?.try_into().ok()?;
|
||||
|
||||
subscription.items.data.iter().find_map(|item| {
|
||||
let price = item.price.as_ref()?;
|
||||
|
@ -171,12 +167,13 @@ impl StripeBilling {
|
|||
pub async fn subscribe_to_price(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
price: &stripe::Price,
|
||||
price: &StripePrice,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.real_client, &subscription_id, &[]).await?;
|
||||
|
||||
if subscription_contains_price(&subscription, &price.id) {
|
||||
let price_id = price.id.clone().try_into()?;
|
||||
if subscription_contains_price(&subscription, &price_id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
@ -375,24 +372,6 @@ impl StripeBilling {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
struct StripeMeter {
|
||||
id: String,
|
||||
event_name: String,
|
||||
}
|
||||
|
||||
impl StripeMeter {
|
||||
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
|
||||
#[derive(Serialize)]
|
||||
struct Params {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
limit: Option<u64>,
|
||||
}
|
||||
|
||||
client.get_query("/billing/meters", Params { limit: Some(100) })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct StripeMeterEvent {
|
||||
identifier: String,
|
||||
|
|
|
@ -10,8 +10,9 @@ use async_trait::async_trait;
|
|||
#[cfg(test)]
|
||||
pub use fake_stripe_client::*;
|
||||
pub use real_stripe_client::*;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||
pub struct StripeCustomerId(pub Arc<str>);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -25,9 +26,38 @@ pub struct CreateCustomerParams<'a> {
|
|||
pub email: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
||||
pub struct StripePriceId(pub Arc<str>);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StripePrice {
|
||||
pub id: StripePriceId,
|
||||
pub unit_amount: Option<i64>,
|
||||
pub lookup_key: Option<String>,
|
||||
pub recurring: Option<StripePriceRecurring>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StripePriceRecurring {
|
||||
pub meter: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Deserialize)]
|
||||
pub struct StripeMeterId(pub Arc<str>);
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct StripeMeter {
|
||||
pub id: StripeMeterId,
|
||||
pub event_name: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait StripeClient: Send + Sync {
|
||||
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
|
||||
|
||||
async fn create_customer(&self, params: CreateCustomerParams<'_>) -> Result<StripeCustomer>;
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
|
||||
|
||||
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
|
||||
}
|
||||
|
|
|
@ -6,16 +6,23 @@ use collections::HashMap;
|
|||
use parking_lot::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
|
||||
use crate::stripe_client::{
|
||||
CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
|
||||
StripeMeterId, StripePrice, StripePriceId,
|
||||
};
|
||||
|
||||
pub struct FakeStripeClient {
|
||||
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
|
||||
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
|
||||
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
|
||||
}
|
||||
|
||||
impl FakeStripeClient {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
customers: Arc::new(Mutex::new(HashMap::default())),
|
||||
prices: Arc::new(Mutex::new(HashMap::default())),
|
||||
meters: Arc::new(Mutex::new(HashMap::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -44,4 +51,16 @@ impl StripeClient for FakeStripeClient {
|
|||
|
||||
Ok(customer)
|
||||
}
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
|
||||
let prices = self.prices.lock().values().cloned().collect();
|
||||
|
||||
Ok(prices)
|
||||
}
|
||||
|
||||
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
|
||||
let meters = self.meters.lock().values().cloned().collect();
|
||||
|
||||
Ok(meters)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,13 @@ use std::sync::Arc;
|
|||
|
||||
use anyhow::{Context as _, Result};
|
||||
use async_trait::async_trait;
|
||||
use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers};
|
||||
use serde::Serialize;
|
||||
use stripe::{CreateCustomer, Customer, CustomerId, ListCustomers, Price, PriceId, Recurring};
|
||||
|
||||
use crate::stripe_client::{CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId};
|
||||
use crate::stripe_client::{
|
||||
CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
|
||||
StripePriceId, StripePriceRecurring,
|
||||
};
|
||||
|
||||
pub struct RealStripeClient {
|
||||
client: Arc<stripe::Client>,
|
||||
|
@ -48,6 +52,37 @@ impl StripeClient for RealStripeClient {
|
|||
|
||||
Ok(StripeCustomer::from(customer))
|
||||
}
|
||||
|
||||
async fn list_prices(&self) -> Result<Vec<StripePrice>> {
|
||||
let response = stripe::Price::list(
|
||||
&self.client,
|
||||
&stripe::ListPrices {
|
||||
limit: Some(100),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response.data.into_iter().map(StripePrice::from).collect())
|
||||
}
|
||||
|
||||
async fn list_meters(&self) -> Result<Vec<StripeMeter>> {
|
||||
#[derive(Serialize)]
|
||||
struct Params {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
limit: Option<u64>,
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get_query::<stripe::List<StripeMeter>, _>(
|
||||
"/billing/meters",
|
||||
Params { limit: Some(100) },
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response.data)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CustomerId> for StripeCustomerId {
|
||||
|
@ -72,3 +107,34 @@ impl From<Customer> for StripeCustomer {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PriceId> for StripePriceId {
|
||||
fn from(value: PriceId) -> Self {
|
||||
Self(value.as_str().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<StripePriceId> for PriceId {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: StripePriceId) -> Result<Self, Self::Error> {
|
||||
Self::from_str(value.0.as_ref()).context("failed to parse Stripe price ID")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Price> for StripePrice {
|
||||
fn from(value: Price) -> Self {
|
||||
Self {
|
||||
id: value.id.into(),
|
||||
unit_amount: value.unit_amount,
|
||||
lookup_key: value.lookup_key,
|
||||
recurring: value.recurring.map(StripePriceRecurring::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Recurring> for StripePriceRecurring {
|
||||
fn from(value: Recurring) -> Self {
|
||||
Self { meter: value.meter }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,9 @@ use std::sync::Arc;
|
|||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
use crate::stripe_client::FakeStripeClient;
|
||||
use crate::stripe_client::{
|
||||
FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
|
||||
};
|
||||
|
||||
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||
let stripe_client = Arc::new(FakeStripeClient::new());
|
||||
|
@ -12,6 +14,87 @@ fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
|||
(stripe_billing, stripe_client)
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_initialize() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
||||
// Add test meters
|
||||
let meter1 = StripeMeter {
|
||||
id: StripeMeterId("meter_1".into()),
|
||||
event_name: "event_1".to_string(),
|
||||
};
|
||||
let meter2 = StripeMeter {
|
||||
id: StripeMeterId("meter_2".into()),
|
||||
event_name: "event_2".to_string(),
|
||||
};
|
||||
stripe_client
|
||||
.meters
|
||||
.lock()
|
||||
.insert(meter1.id.clone(), meter1);
|
||||
stripe_client
|
||||
.meters
|
||||
.lock()
|
||||
.insert(meter2.id.clone(), meter2);
|
||||
|
||||
// Add test prices
|
||||
let price1 = StripePrice {
|
||||
id: StripePriceId("price_1".into()),
|
||||
unit_amount: Some(1_000),
|
||||
lookup_key: Some("zed-pro".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
let price2 = StripePrice {
|
||||
id: StripePriceId("price_2".into()),
|
||||
unit_amount: Some(0),
|
||||
lookup_key: Some("zed-free".to_string()),
|
||||
recurring: None,
|
||||
};
|
||||
let price3 = StripePrice {
|
||||
id: StripePriceId("price_3".into()),
|
||||
unit_amount: Some(500),
|
||||
lookup_key: None,
|
||||
recurring: Some(StripePriceRecurring {
|
||||
meter: Some("meter_1".to_string()),
|
||||
}),
|
||||
};
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price1.id.clone(), price1);
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price2.id.clone(), price2);
|
||||
stripe_client
|
||||
.prices
|
||||
.lock()
|
||||
.insert(price3.id.clone(), price3);
|
||||
|
||||
// Initialize the billing system
|
||||
stripe_billing.initialize().await.unwrap();
|
||||
|
||||
// Verify that prices can be found by lookup key
|
||||
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.unwrap();
|
||||
assert_eq!(zed_pro_price_id.to_string(), "price_1");
|
||||
|
||||
let zed_free_price_id = stripe_billing.zed_free_price_id().await.unwrap();
|
||||
assert_eq!(zed_free_price_id.to_string(), "price_2");
|
||||
|
||||
// Verify that a price can be found by lookup key
|
||||
let zed_pro_price = stripe_billing
|
||||
.find_price_by_lookup_key("zed-pro")
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(zed_pro_price.id.to_string(), "price_1");
|
||||
assert_eq!(zed_pro_price.unit_amount, Some(1_000));
|
||||
|
||||
// Verify that finding a non-existent lookup key returns an error
|
||||
let result = stripe_billing
|
||||
.find_price_by_lookup_key("non-existent")
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_find_or_create_customer_by_email() {
|
||||
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue