diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 5764aceea5..7d8bd8eb1b 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -422,6 +422,15 @@ CREATE TABLE dev_server_projects ( paths TEXT NOT NULL ); +CREATE TABLE IF NOT EXISTS billing_preferences ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + user_id INTEGER NOT NULL REFERENCES users(id), + max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL +); + +CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id); + CREATE TABLE IF NOT EXISTS billing_customers ( id INTEGER PRIMARY KEY AUTOINCREMENT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, diff --git a/crates/collab/migrations/20241009190639_add_billing_preferences.sql b/crates/collab/migrations/20241009190639_add_billing_preferences.sql new file mode 100644 index 0000000000..9aa5a1a303 --- /dev/null +++ b/crates/collab/migrations/20241009190639_add_billing_preferences.sql @@ -0,0 +1,8 @@ +create table if not exists billing_preferences ( + id serial primary key, + created_at timestamp without time zone not null default now(), + user_id integer not null references users(id) on delete cascade, + max_monthly_llm_usage_spending_in_cents integer not null +); + +create unique index "uix_billing_preferences_on_user_id" on billing_preferences (user_id); diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index d2e43714a8..dca5a772f4 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -26,15 +26,19 @@ use crate::db::billing_subscription::{self, StripeSubscriptionStatus}; use crate::db::{ billing_customer, BillingSubscriptionId, CreateBillingCustomerParams, CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams, - UpdateBillingSubscriptionParams, + UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams, }; use crate::llm::db::LlmDatabase; -use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT; +use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT}; use crate::rpc::ResultExt as _; use crate::{AppState, Error, Result}; pub fn router() -> Router { Router::new() + .route( + "/billing/preferences", + get(get_billing_preferences).put(update_billing_preferences), + ) .route( "/billing/subscriptions", get(list_billing_subscriptions).post(create_billing_subscription), @@ -45,6 +49,82 @@ pub fn router() -> Router { ) } +#[derive(Debug, Deserialize)] +struct GetBillingPreferencesParams { + github_user_id: i32, +} + +#[derive(Debug, Serialize)] +struct BillingPreferencesResponse { + max_monthly_llm_usage_spending_in_cents: i32, +} + +async fn get_billing_preferences( + Extension(app): Extension>, + Query(params): Query, +) -> Result> { + let user = app + .db + .get_user_by_github_user_id(params.github_user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let preferences = app.db.get_billing_preferences(user.id).await?; + + Ok(Json(BillingPreferencesResponse { + max_monthly_llm_usage_spending_in_cents: preferences + .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| { + preferences.max_monthly_llm_usage_spending_in_cents + }), + })) +} + +#[derive(Debug, Deserialize)] +struct UpdateBillingPreferencesBody { + github_user_id: i32, + max_monthly_llm_usage_spending_in_cents: i32, +} + +async fn update_billing_preferences( + Extension(app): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let user = app + .db + .get_user_by_github_user_id(body.github_user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let billing_preferences = + if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? { + app.db + .update_billing_preferences( + user.id, + &UpdateBillingPreferencesParams { + max_monthly_llm_usage_spending_in_cents: ActiveValue::set( + body.max_monthly_llm_usage_spending_in_cents, + ), + }, + ) + .await? + } else { + app.db + .create_billing_preferences( + user.id, + &crate::db::CreateBillingPreferencesParams { + max_monthly_llm_usage_spending_in_cents: body + .max_monthly_llm_usage_spending_in_cents, + }, + ) + .await? + }; + + Ok(Json(BillingPreferencesResponse { + max_monthly_llm_usage_spending_in_cents: billing_preferences + .max_monthly_llm_usage_spending_in_cents, + })) +} + #[derive(Debug, Deserialize)] struct ListBillingSubscriptionsParams { github_user_id: i32, diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index f717566824..e966548493 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -42,6 +42,9 @@ pub use tests::TestDb; pub use ids::*; pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams}; +pub use queries::billing_preferences::{ + CreateBillingPreferencesParams, UpdateBillingPreferencesParams, +}; pub use queries::billing_subscriptions::{ CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams, }; diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index 9bf767329d..3a5bcff558 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -72,6 +72,7 @@ macro_rules! id_type { id_type!(AccessTokenId); id_type!(BillingCustomerId); id_type!(BillingSubscriptionId); +id_type!(BillingPreferencesId); id_type!(BufferId); id_type!(ChannelBufferCollaboratorId); id_type!(ChannelChatParticipantId); diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 459f66d89a..9c277790f9 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -2,6 +2,7 @@ use super::*; pub mod access_tokens; pub mod billing_customers; +pub mod billing_preferences; pub mod billing_subscriptions; pub mod buffers; pub mod channels; diff --git a/crates/collab/src/db/queries/billing_preferences.rs b/crates/collab/src/db/queries/billing_preferences.rs new file mode 100644 index 0000000000..fa35ffc068 --- /dev/null +++ b/crates/collab/src/db/queries/billing_preferences.rs @@ -0,0 +1,75 @@ +use super::*; + +#[derive(Debug)] +pub struct CreateBillingPreferencesParams { + pub max_monthly_llm_usage_spending_in_cents: i32, +} + +#[derive(Debug, Default)] +pub struct UpdateBillingPreferencesParams { + pub max_monthly_llm_usage_spending_in_cents: ActiveValue, +} + +impl Database { + /// Returns the billing preferences for the given user, if they exist. + pub async fn get_billing_preferences( + &self, + user_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + Ok(billing_preference::Entity::find() + .filter(billing_preference::Column::UserId.eq(user_id)) + .one(&*tx) + .await?) + }) + .await + } + + /// Creates new billing preferences for the given user. + pub async fn create_billing_preferences( + &self, + user_id: UserId, + params: &CreateBillingPreferencesParams, + ) -> Result { + self.transaction(|tx| async move { + let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel { + user_id: ActiveValue::set(user_id), + max_monthly_llm_usage_spending_in_cents: ActiveValue::set( + params.max_monthly_llm_usage_spending_in_cents, + ), + ..Default::default() + }) + .exec_with_returning(&*tx) + .await?; + + Ok(preferences) + }) + .await + } + + /// Updates the billing preferences for the given user. + pub async fn update_billing_preferences( + &self, + user_id: UserId, + params: &UpdateBillingPreferencesParams, + ) -> Result { + self.transaction(|tx| async move { + let preferences = billing_preference::Entity::update_many() + .set(billing_preference::ActiveModel { + max_monthly_llm_usage_spending_in_cents: params + .max_monthly_llm_usage_spending_in_cents + .clone(), + ..Default::default() + }) + .filter(billing_preference::Column::UserId.eq(user_id)) + .exec_with_returning(&*tx) + .await?; + + Ok(preferences + .into_iter() + .next() + .ok_or_else(|| anyhow!("billing preferences not found"))?) + }) + .await + } +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index 07d070b569..01d3835dc1 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -1,5 +1,6 @@ pub mod access_token; pub mod billing_customer; +pub mod billing_preference; pub mod billing_subscription; pub mod buffer; pub mod buffer_operation; diff --git a/crates/collab/src/db/tables/billing_preference.rs b/crates/collab/src/db/tables/billing_preference.rs new file mode 100644 index 0000000000..0ad92c25d6 --- /dev/null +++ b/crates/collab/src/db/tables/billing_preference.rs @@ -0,0 +1,30 @@ +use crate::db::{BillingPreferencesId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "billing_preferences")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: BillingPreferencesId, + pub created_at: DateTime, + pub user_id: UserId, + pub max_monthly_llm_usage_spending_in_cents: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index c733beb596..efe28a81c7 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -442,6 +442,12 @@ fn normalize_model_name(known_models: Vec, name: String) -> String { /// before they have to pay. pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(5); +/// The default value to use for maximum spend per month if the user did not +/// explicitly set a maximum spend. +/// +/// Used to prevent surprise bills. +pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10); + /// The maximum lifetime spending an individual user can reach before being cut off. const LIFETIME_SPENDING_LIMIT: Cents = Cents::from_dollars(1_000);