collab: Make the StripeBilling
object long-lived (#19090)
This PR makes the `StripeBilling` object long-lived so that we can make better use of the cached data on it. We now hold it on the `AppState` and spawn a background task to initialize the cache on startup. Release Notes: - N/A Co-authored-by: Richard <richard@zed.dev>
This commit is contained in:
parent
550064f80f
commit
f280b29859
5 changed files with 129 additions and 92 deletions
|
@ -210,6 +210,13 @@ async fn create_billing_subscription(
|
|||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::error!("failed to retrieve LLM database");
|
||||
Err(Error::http(
|
||||
|
@ -236,7 +243,6 @@ async fn create_billing_subscription(
|
|||
};
|
||||
|
||||
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
|
||||
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
|
||||
let stripe_model = stripe_billing.register_model(default_model).await?;
|
||||
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
|
||||
let checkout_session_url = stripe_billing
|
||||
|
@ -716,8 +722,8 @@ async fn find_or_create_billing_customer(
|
|||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::warn!("failed to retrieve Stripe client");
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::warn!("failed to retrieve Stripe billing object");
|
||||
return;
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
|
@ -730,7 +736,7 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
|||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_with_stripe(&app, &llm_db, &stripe_client)
|
||||
sync_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
.await
|
||||
.trace_err();
|
||||
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
|
||||
|
@ -742,10 +748,8 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
|||
async fn sync_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_client: &Arc<stripe::Client>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
|
||||
|
||||
let events = llm_db.get_billing_events().await?;
|
||||
let user_ids = events
|
||||
.iter()
|
||||
|
|
|
@ -31,6 +31,8 @@ use serde::Deserialize;
|
|||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
pub enum Error {
|
||||
|
@ -274,6 +276,7 @@ pub struct AppState {
|
|||
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub rate_limiter: Arc<RateLimiter>,
|
||||
pub executor: Executor,
|
||||
pub clickhouse_client: Option<::clickhouse::Client>,
|
||||
|
@ -317,12 +320,16 @@ impl AppState {
|
|||
};
|
||||
|
||||
let db = Arc::new(db);
|
||||
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
|
||||
let this = Self {
|
||||
db: db.clone(),
|
||||
llm_db,
|
||||
live_kit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
stripe_client: build_stripe_client(&config).map(Arc::new).log_err(),
|
||||
stripe_billing: stripe_client
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
stripe_client,
|
||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||
executor,
|
||||
clickhouse_client: config
|
||||
|
|
|
@ -111,6 +111,13 @@ async fn main() -> Result<()> {
|
|||
|
||||
let state = AppState::new(config, Executor::Production).await?;
|
||||
|
||||
if let Some(stripe_billing) = state.stripe_billing.clone() {
|
||||
let executor = state.executor.clone();
|
||||
executor.spawn_detached(async move {
|
||||
stripe_billing.initialize().await.trace_err();
|
||||
});
|
||||
}
|
||||
|
||||
if mode.is_collab() {
|
||||
state.db.purge_old_embeddings().await.trace_err();
|
||||
RateLimiter::save_periodically(
|
||||
|
|
|
@ -5,10 +5,11 @@ use anyhow::Context;
|
|||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub struct StripeBilling {
|
||||
meters_by_event_name: HashMap<String, StripeMeter>,
|
||||
price_ids_by_meter_id: HashMap<String, stripe::PriceId>,
|
||||
meters_by_event_name: RwLock<HashMap<String, StripeMeter>>,
|
||||
price_ids_by_meter_id: RwLock<HashMap<String, stripe::PriceId>>,
|
||||
client: Arc<stripe::Client>,
|
||||
}
|
||||
|
||||
|
@ -25,32 +26,43 @@ struct StripeBillingPrice {
|
|||
}
|
||||
|
||||
impl StripeBilling {
|
||||
pub async fn new(client: Arc<stripe::Client>) -> Result<Self> {
|
||||
let mut meters_by_event_name = HashMap::default();
|
||||
for meter in StripeMeter::list(&client).await?.data {
|
||||
meters_by_event_name.insert(meter.event_name.clone(), meter);
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
meters_by_event_name: RwLock::new(HashMap::default()),
|
||||
price_ids_by_meter_id: RwLock::new(HashMap::default()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn initialize(&self) -> Result<()> {
|
||||
log::info!("initializing StripeBilling");
|
||||
|
||||
{
|
||||
let meters = StripeMeter::list(&self.client).await?.data;
|
||||
let mut meters_by_event_name = self.meters_by_event_name.write().await;
|
||||
for meter in meters {
|
||||
meters_by_event_name.insert(meter.event_name.clone(), meter);
|
||||
}
|
||||
}
|
||||
|
||||
let mut price_ids_by_meter_id = HashMap::default();
|
||||
for price in stripe::Price::list(&client, &stripe::ListPrices::default())
|
||||
.await?
|
||||
.data
|
||||
{
|
||||
if let Some(recurring) = price.recurring {
|
||||
if let Some(meter) = recurring.meter {
|
||||
price_ids_by_meter_id.insert(meter, price.id);
|
||||
let prices = stripe::Price::list(&self.client, &stripe::ListPrices::default())
|
||||
.await?
|
||||
.data;
|
||||
let mut price_ids_by_meter_id = self.price_ids_by_meter_id.write().await;
|
||||
for price in prices {
|
||||
if let Some(recurring) = price.recurring {
|
||||
if let Some(meter) = recurring.meter {
|
||||
price_ids_by_meter_id.insert(meter, price.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
meters_by_event_name,
|
||||
price_ids_by_meter_id,
|
||||
client,
|
||||
})
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn register_model(&mut self, model: &llm::db::model::Model) -> Result<StripeModel> {
|
||||
pub async fn register_model(&self, model: &llm::db::model::Model) -> Result<StripeModel> {
|
||||
let input_tokens_price = self
|
||||
.get_or_insert_price(
|
||||
&format!("model_{}/input_tokens", model.id),
|
||||
|
@ -88,78 +100,84 @@ impl StripeBilling {
|
|||
}
|
||||
|
||||
async fn get_or_insert_price(
|
||||
&mut self,
|
||||
&self,
|
||||
meter_event_name: &str,
|
||||
price_description: &str,
|
||||
price_per_million_tokens: Cents,
|
||||
) -> Result<StripeBillingPrice> {
|
||||
let meter = if let Some(meter) = self.meters_by_event_name.get(meter_event_name) {
|
||||
meter.clone()
|
||||
} else {
|
||||
let meter = StripeMeter::create(
|
||||
&self.client,
|
||||
StripeCreateMeterParams {
|
||||
default_aggregation: DefaultAggregation { formula: "sum" },
|
||||
display_name: price_description.to_string(),
|
||||
event_name: meter_event_name,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
self.meters_by_event_name
|
||||
.insert(meter_event_name.to_string(), meter.clone());
|
||||
meter
|
||||
};
|
||||
let meter =
|
||||
if let Some(meter) = self.meters_by_event_name.read().await.get(meter_event_name) {
|
||||
meter.clone()
|
||||
} else {
|
||||
let meter = StripeMeter::create(
|
||||
&self.client,
|
||||
StripeCreateMeterParams {
|
||||
default_aggregation: DefaultAggregation { formula: "sum" },
|
||||
display_name: price_description.to_string(),
|
||||
event_name: meter_event_name,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
self.meters_by_event_name
|
||||
.write()
|
||||
.await
|
||||
.insert(meter_event_name.to_string(), meter.clone());
|
||||
meter
|
||||
};
|
||||
|
||||
let price_id = if let Some(price_id) = self.price_ids_by_meter_id.get(&meter.id) {
|
||||
price_id.clone()
|
||||
} else {
|
||||
let price = stripe::Price::create(
|
||||
&self.client,
|
||||
stripe::CreatePrice {
|
||||
active: Some(true),
|
||||
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
||||
currency: stripe::Currency::USD,
|
||||
currency_options: None,
|
||||
custom_unit_amount: None,
|
||||
expand: &[],
|
||||
lookup_key: None,
|
||||
metadata: None,
|
||||
nickname: None,
|
||||
product: None,
|
||||
product_data: Some(stripe::CreatePriceProductData {
|
||||
id: None,
|
||||
let price_id =
|
||||
if let Some(price_id) = self.price_ids_by_meter_id.read().await.get(&meter.id) {
|
||||
price_id.clone()
|
||||
} else {
|
||||
let price = stripe::Price::create(
|
||||
&self.client,
|
||||
stripe::CreatePrice {
|
||||
active: Some(true),
|
||||
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
||||
currency: stripe::Currency::USD,
|
||||
currency_options: None,
|
||||
custom_unit_amount: None,
|
||||
expand: &[],
|
||||
lookup_key: None,
|
||||
metadata: None,
|
||||
name: price_description.to_string(),
|
||||
statement_descriptor: None,
|
||||
tax_code: None,
|
||||
unit_label: None,
|
||||
}),
|
||||
recurring: Some(stripe::CreatePriceRecurring {
|
||||
aggregate_usage: None,
|
||||
interval: stripe::CreatePriceRecurringInterval::Month,
|
||||
interval_count: None,
|
||||
trial_period_days: None,
|
||||
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
||||
meter: Some(meter.id.clone()),
|
||||
}),
|
||||
tax_behavior: None,
|
||||
tiers: None,
|
||||
tiers_mode: None,
|
||||
transfer_lookup_key: None,
|
||||
transform_quantity: None,
|
||||
unit_amount: None,
|
||||
unit_amount_decimal: Some(&format!(
|
||||
"{:.12}",
|
||||
price_per_million_tokens.0 as f64 / 1_000_000f64
|
||||
)),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
self.price_ids_by_meter_id
|
||||
.insert(meter.id, price.id.clone());
|
||||
price.id
|
||||
};
|
||||
nickname: None,
|
||||
product: None,
|
||||
product_data: Some(stripe::CreatePriceProductData {
|
||||
id: None,
|
||||
active: Some(true),
|
||||
metadata: None,
|
||||
name: price_description.to_string(),
|
||||
statement_descriptor: None,
|
||||
tax_code: None,
|
||||
unit_label: None,
|
||||
}),
|
||||
recurring: Some(stripe::CreatePriceRecurring {
|
||||
aggregate_usage: None,
|
||||
interval: stripe::CreatePriceRecurringInterval::Month,
|
||||
interval_count: None,
|
||||
trial_period_days: None,
|
||||
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
||||
meter: Some(meter.id.clone()),
|
||||
}),
|
||||
tax_behavior: None,
|
||||
tiers: None,
|
||||
tiers_mode: None,
|
||||
transfer_lookup_key: None,
|
||||
transform_quantity: None,
|
||||
unit_amount: None,
|
||||
unit_amount_decimal: Some(&format!(
|
||||
"{:.12}",
|
||||
price_per_million_tokens.0 as f64 / 1_000_000f64
|
||||
)),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
self.price_ids_by_meter_id
|
||||
.write()
|
||||
.await
|
||||
.insert(meter.id, price.id.clone());
|
||||
price.id
|
||||
};
|
||||
|
||||
Ok(StripeBillingPrice {
|
||||
id: price_id,
|
||||
|
|
|
@ -639,6 +639,7 @@ impl TestServer {
|
|||
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
stripe_client: None,
|
||||
stripe_billing: None,
|
||||
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
|
||||
executor,
|
||||
clickhouse_client: None,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue