collab: Sync model request overages to Stripe (#29583)
This PR adds syncing of model request overages to Stripe. Release Notes: - N/A
This commit is contained in:
parent
3a212e72a4
commit
5092f0f18b
9 changed files with 318 additions and 16 deletions
|
@ -1,12 +1,13 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{Cents, Result, llm};
|
||||
use anyhow::Context as _;
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use chrono::{Datelike, Utc};
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use stripe::PriceId;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct StripeBilling {
|
||||
state: RwLock<StripeBillingState>,
|
||||
|
@ -17,9 +18,10 @@ pub struct StripeBilling {
|
|||
struct StripeBillingState {
|
||||
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
||||
prices_by_lookup_key: HashMap<String, stripe::Price>,
|
||||
}
|
||||
|
||||
pub struct StripeModel {
|
||||
pub struct StripeModelTokenPrices {
|
||||
input_tokens_price: StripeBillingPrice,
|
||||
input_cache_creation_tokens_price: StripeBillingPrice,
|
||||
input_cache_read_tokens_price: StripeBillingPrice,
|
||||
|
@ -62,6 +64,10 @@ impl StripeBilling {
|
|||
}
|
||||
|
||||
for price in prices.data {
|
||||
if let Some(lookup_key) = price.lookup_key.clone() {
|
||||
state.prices_by_lookup_key.insert(lookup_key, price.clone());
|
||||
}
|
||||
|
||||
if let Some(recurring) = price.recurring {
|
||||
if let Some(meter) = recurring.meter {
|
||||
state.price_ids_by_meter_id.insert(meter, price.id);
|
||||
|
@ -74,36 +80,49 @@ impl StripeBilling {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
|
||||
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
|
||||
self.state
|
||||
.read()
|
||||
.await
|
||||
.prices_by_lookup_key
|
||||
.get(lookup_key)
|
||||
.cloned()
|
||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
|
||||
}
|
||||
|
||||
pub async fn register_model_for_token_based_usage(
|
||||
&self,
|
||||
model: &llm::db::model::Model,
|
||||
) -> Result<StripeModelTokenPrices> {
|
||||
let input_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
.get_or_insert_token_price(
|
||||
&format!("model_{}/input_tokens", model.id),
|
||||
&format!("{} (Input Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_creation_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
.get_or_insert_token_price(
|
||||
&format!("model_{}/input_cache_creation_tokens", model.id),
|
||||
&format!("{} (Input Cache Creation Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_read_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
.get_or_insert_token_price(
|
||||
&format!("model_{}/input_cache_read_tokens", model.id),
|
||||
&format!("{} (Input Cache Read Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let output_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
.get_or_insert_token_price(
|
||||
&format!("model_{}/output_tokens", model.id),
|
||||
&format!("{} (Output Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_output_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
Ok(StripeModel {
|
||||
Ok(StripeModelTokenPrices {
|
||||
input_tokens_price,
|
||||
input_cache_creation_tokens_price,
|
||||
input_cache_read_tokens_price,
|
||||
|
@ -111,7 +130,7 @@ impl StripeBilling {
|
|||
})
|
||||
}
|
||||
|
||||
async fn get_or_insert_price(
|
||||
async fn get_or_insert_token_price(
|
||||
&self,
|
||||
meter_event_name: &str,
|
||||
price_description: &str,
|
||||
|
@ -207,10 +226,43 @@ impl StripeBilling {
|
|||
})
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_price(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
price_id: &stripe::PriceId,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
|
||||
if subscription_contains_price(&subscription, price_id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
stripe::Subscription::update(
|
||||
&self.client,
|
||||
subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(vec![stripe::UpdateSubscriptionItems {
|
||||
price: Some(price_id.to_string()),
|
||||
..Default::default()
|
||||
}]),
|
||||
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
|
||||
end_behavior: stripe::UpdateSubscriptionTrialSettingsEndBehavior {
|
||||
missing_payment_method: stripe::UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
},
|
||||
}),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_model(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
model: &StripeModel,
|
||||
model: &StripeModelTokenPrices,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
|
@ -271,7 +323,7 @@ impl StripeBilling {
|
|||
pub async fn bill_model_token_usage(
|
||||
&self,
|
||||
customer_id: &stripe::CustomerId,
|
||||
model: &StripeModel,
|
||||
model: &StripeModelTokenPrices,
|
||||
event: &llm::db::billing_event::Model,
|
||||
) -> Result<()> {
|
||||
let timestamp = Utc::now().timestamp();
|
||||
|
@ -343,11 +395,37 @@ impl StripeBilling {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bill_model_request_usage(
|
||||
&self,
|
||||
customer_id: &stripe::CustomerId,
|
||||
event_name: &str,
|
||||
requests: i32,
|
||||
) -> Result<()> {
|
||||
let timestamp = Utc::now().timestamp();
|
||||
let idempotency_key = Uuid::new_v4();
|
||||
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("model_requests/{}", idempotency_key),
|
||||
event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: requests as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn checkout(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
model: &StripeModel,
|
||||
model: &StripeModelTokenPrices,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let first_of_next_month = Utc::now()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue