collab: Use StripeClient
for creating model usage meter events (#31633)
This PR updates the `StripeBilling::bill_model_request_usage` method to use the `StripeClient` trait. Release Notes: - N/A
This commit is contained in:
parent
a1c645e57e
commit
469824c350
6 changed files with 119 additions and 77 deletions
|
@ -29,6 +29,7 @@ use crate::db::billing_subscription::{
|
||||||
use crate::llm::db::subscription_usage_meter::CompletionMode;
|
use crate::llm::db::subscription_usage_meter::CompletionMode;
|
||||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
|
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
|
||||||
use crate::rpc::{ResultExt as _, Server};
|
use crate::rpc::{ResultExt as _, Server};
|
||||||
|
use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId};
|
||||||
use crate::{AppState, Error, Result};
|
use crate::{AppState, Error, Result};
|
||||||
use crate::{db::UserId, llm::db::LlmDatabase};
|
use crate::{db::UserId, llm::db::LlmDatabase};
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -1545,14 +1546,10 @@ async fn sync_model_request_usage_with_stripe(
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let stripe_customer_id = billing_customer
|
let stripe_customer_id =
|
||||||
.stripe_customer_id
|
StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
|
||||||
.parse::<stripe::CustomerId>()
|
let stripe_subscription_id =
|
||||||
.context("failed to parse Stripe customer ID from database")?;
|
StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
|
||||||
let stripe_subscription_id = billing_subscription
|
|
||||||
.stripe_subscription_id
|
|
||||||
.parse::<stripe::SubscriptionId>()
|
|
||||||
.context("failed to parse Stripe subscription ID from database")?;
|
|
||||||
|
|
||||||
let model = llm_db.model_by_id(usage_meter.model_id)?;
|
let model = llm_db.model_by_id(usage_meter.model_id)?;
|
||||||
|
|
||||||
|
@ -1578,7 +1575,7 @@ async fn sync_model_request_usage_with_stripe(
|
||||||
};
|
};
|
||||||
|
|
||||||
stripe_billing
|
stripe_billing
|
||||||
.subscribe_to_price(&stripe_subscription_id.into(), price)
|
.subscribe_to_price(&stripe_subscription_id, price)
|
||||||
.await?;
|
.await?;
|
||||||
stripe_billing
|
stripe_billing
|
||||||
.bill_model_request_usage(
|
.bill_model_request_usage(
|
||||||
|
|
|
@ -3,7 +3,6 @@ use std::sync::Arc;
|
||||||
use anyhow::{Context as _, anyhow};
|
use anyhow::{Context as _, anyhow};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use stripe::SubscriptionStatus;
|
use stripe::SubscriptionStatus;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
@ -12,8 +11,9 @@ use crate::Result;
|
||||||
use crate::db::billing_subscription::SubscriptionKind;
|
use crate::db::billing_subscription::SubscriptionKind;
|
||||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||||
use crate::stripe_client::{
|
use crate::stripe_client::{
|
||||||
RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
|
RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload,
|
||||||
StripeSubscription, StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
|
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
|
||||||
|
StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
|
||||||
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
|
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
|
||||||
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
|
||||||
};
|
};
|
||||||
|
@ -204,16 +204,15 @@ impl StripeBilling {
|
||||||
|
|
||||||
pub async fn bill_model_request_usage(
|
pub async fn bill_model_request_usage(
|
||||||
&self,
|
&self,
|
||||||
customer_id: &stripe::CustomerId,
|
customer_id: &StripeCustomerId,
|
||||||
event_name: &str,
|
event_name: &str,
|
||||||
requests: i32,
|
requests: i32,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let timestamp = Utc::now().timestamp();
|
let timestamp = Utc::now().timestamp();
|
||||||
let idempotency_key = Uuid::new_v4();
|
let idempotency_key = Uuid::new_v4();
|
||||||
|
|
||||||
StripeMeterEvent::create(
|
self.client
|
||||||
&self.real_client,
|
.create_meter_event(StripeCreateMeterEventParams {
|
||||||
StripeCreateMeterEventParams {
|
|
||||||
identifier: &format!("model_requests/{}", idempotency_key),
|
identifier: &format!("model_requests/{}", idempotency_key),
|
||||||
event_name,
|
event_name,
|
||||||
payload: StripeCreateMeterEventPayload {
|
payload: StripeCreateMeterEventPayload {
|
||||||
|
@ -221,9 +220,8 @@ impl StripeBilling {
|
||||||
stripe_customer_id: customer_id,
|
stripe_customer_id: customer_id,
|
||||||
},
|
},
|
||||||
timestamp: Some(timestamp),
|
timestamp: Some(timestamp),
|
||||||
},
|
})
|
||||||
)
|
.await?;
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -371,52 +369,6 @@ impl StripeBilling {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct StripeMeterEvent {
|
|
||||||
identifier: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StripeMeterEvent {
|
|
||||||
pub async fn create(
|
|
||||||
client: &stripe::Client,
|
|
||||||
params: StripeCreateMeterEventParams<'_>,
|
|
||||||
) -> Result<Self, stripe::StripeError> {
|
|
||||||
let identifier = params.identifier;
|
|
||||||
match client.post_form("/billing/meter_events", params).await {
|
|
||||||
Ok(event) => Ok(event),
|
|
||||||
Err(stripe::StripeError::Stripe(error)) => {
|
|
||||||
if error.http_status == 400
|
|
||||||
&& error
|
|
||||||
.message
|
|
||||||
.as_ref()
|
|
||||||
.map_or(false, |message| message.contains(identifier))
|
|
||||||
{
|
|
||||||
Ok(Self {
|
|
||||||
identifier: identifier.to_string(),
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
Err(stripe::StripeError::Stripe(error))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(error) => Err(error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct StripeCreateMeterEventParams<'a> {
|
|
||||||
identifier: &'a str,
|
|
||||||
event_name: &'a str,
|
|
||||||
payload: StripeCreateMeterEventPayload<'a>,
|
|
||||||
timestamp: Option<i64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct StripeCreateMeterEventPayload<'a> {
|
|
||||||
value: u64,
|
|
||||||
stripe_customer_id: &'a stripe::CustomerId,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn subscription_contains_price(
|
fn subscription_contains_price(
|
||||||
subscription: &StripeSubscription,
|
subscription: &StripeSubscription,
|
||||||
price_id: &StripePriceId,
|
price_id: &StripePriceId,
|
||||||
|
|
|
@ -10,9 +10,9 @@ use async_trait::async_trait;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub use fake_stripe_client::*;
|
pub use fake_stripe_client::*;
|
||||||
pub use real_stripe_client::*;
|
pub use real_stripe_client::*;
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
|
||||||
pub struct StripeCustomerId(pub Arc<str>);
|
pub struct StripeCustomerId(pub Arc<str>);
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -97,6 +97,20 @@ pub struct StripeMeter {
|
||||||
pub event_name: String,
|
pub event_name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct StripeCreateMeterEventParams<'a> {
|
||||||
|
pub identifier: &'a str,
|
||||||
|
pub event_name: &'a str,
|
||||||
|
pub payload: StripeCreateMeterEventPayload<'a>,
|
||||||
|
pub timestamp: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct StripeCreateMeterEventPayload<'a> {
|
||||||
|
pub value: u64,
|
||||||
|
pub stripe_customer_id: &'a StripeCustomerId,
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait StripeClient: Send + Sync {
|
pub trait StripeClient: Send + Sync {
|
||||||
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
|
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
|
||||||
|
@ -117,4 +131,6 @@ pub trait StripeClient: Send + Sync {
|
||||||
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
|
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
|
||||||
|
|
||||||
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
|
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
|
||||||
|
|
||||||
|
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,11 +7,20 @@ use parking_lot::Mutex;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::stripe_client::{
|
use crate::stripe_client::{
|
||||||
CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
|
CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
|
||||||
StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
|
StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
|
||||||
UpdateSubscriptionParams,
|
StripeSubscriptionId, UpdateSubscriptionParams,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StripeCreateMeterEventCall {
|
||||||
|
pub identifier: Arc<str>,
|
||||||
|
pub event_name: Arc<str>,
|
||||||
|
pub value: u64,
|
||||||
|
pub stripe_customer_id: StripeCustomerId,
|
||||||
|
pub timestamp: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct FakeStripeClient {
|
pub struct FakeStripeClient {
|
||||||
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
|
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
|
||||||
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
|
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
|
||||||
|
@ -19,6 +28,7 @@ pub struct FakeStripeClient {
|
||||||
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
|
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
|
||||||
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
|
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
|
||||||
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
|
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
|
||||||
|
pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeStripeClient {
|
impl FakeStripeClient {
|
||||||
|
@ -29,6 +39,7 @@ impl FakeStripeClient {
|
||||||
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
|
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
|
||||||
prices: Arc::new(Mutex::new(HashMap::default())),
|
prices: Arc::new(Mutex::new(HashMap::default())),
|
||||||
meters: Arc::new(Mutex::new(HashMap::default())),
|
meters: Arc::new(Mutex::new(HashMap::default())),
|
||||||
|
create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -94,4 +105,18 @@ impl StripeClient for FakeStripeClient {
|
||||||
|
|
||||||
Ok(meters)
|
Ok(meters)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
|
||||||
|
self.create_meter_event_calls
|
||||||
|
.lock()
|
||||||
|
.push(StripeCreateMeterEventCall {
|
||||||
|
identifier: params.identifier.into(),
|
||||||
|
event_name: params.event_name.into(),
|
||||||
|
value: params.payload.value,
|
||||||
|
stripe_customer_id: params.payload.stripe_customer_id.clone(),
|
||||||
|
timestamp: params.timestamp,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use std::str::FromStr as _;
|
use std::str::FromStr as _;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{Context as _, Result};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use stripe::{
|
use stripe::{
|
||||||
|
@ -12,9 +12,10 @@ use stripe::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::stripe_client::{
|
use crate::stripe_client::{
|
||||||
CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
|
CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
|
||||||
StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
|
StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
|
||||||
StripeSubscriptionItem, StripeSubscriptionItemId, UpdateSubscriptionParams,
|
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
|
||||||
|
UpdateSubscriptionParams,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct RealStripeClient {
|
pub struct RealStripeClient {
|
||||||
|
@ -129,6 +130,26 @@ impl StripeClient for RealStripeClient {
|
||||||
|
|
||||||
Ok(response.data)
|
Ok(response.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
|
||||||
|
let identifier = params.identifier;
|
||||||
|
match self.client.post_form("/billing/meter_events", params).await {
|
||||||
|
Ok(event) => Ok(event),
|
||||||
|
Err(stripe::StripeError::Stripe(error)) => {
|
||||||
|
if error.http_status == 400
|
||||||
|
&& error
|
||||||
|
.message
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |message| message.contains(identifier))
|
||||||
|
{
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!(stripe::StripeError::Stripe(error)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(error) => Err(anyhow!(error)),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<CustomerId> for StripeCustomerId {
|
impl From<CustomerId> for StripeCustomerId {
|
||||||
|
|
|
@ -4,9 +4,9 @@ use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
use crate::stripe_billing::StripeBilling;
|
use crate::stripe_billing::StripeBilling;
|
||||||
use crate::stripe_client::{
|
use crate::stripe_client::{
|
||||||
FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
|
FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
|
||||||
StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
|
StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
|
||||||
UpdateSubscriptionItems,
|
StripeSubscriptionItemId, UpdateSubscriptionItems,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
|
||||||
|
@ -210,3 +210,34 @@ async fn test_subscribe_to_price() {
|
||||||
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
|
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_bill_model_request_usage() {
|
||||||
|
let (stripe_billing, stripe_client) = make_stripe_billing();
|
||||||
|
|
||||||
|
let customer_id = StripeCustomerId("cus_test".into());
|
||||||
|
|
||||||
|
stripe_billing
|
||||||
|
.bill_model_request_usage(&customer_id, "some_model/requests", 73)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let create_meter_event_calls = stripe_client
|
||||||
|
.create_meter_event_calls
|
||||||
|
.lock()
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
assert_eq!(create_meter_event_calls.len(), 1);
|
||||||
|
assert!(
|
||||||
|
create_meter_event_calls[0]
|
||||||
|
.identifier
|
||||||
|
.starts_with("model_requests/")
|
||||||
|
);
|
||||||
|
assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
|
||||||
|
assert_eq!(
|
||||||
|
create_meter_event_calls[0].event_name.as_ref(),
|
||||||
|
"some_model/requests"
|
||||||
|
);
|
||||||
|
assert_eq!(create_meter_event_calls[0].value, 73);
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue