collab: Use billing_customers.has_overdue_invoices to gate subscription access (#24240)

This PR updates the check that prevents subscribing with overdue
subscriptions to use the `billing_customers.has_overdue_invoices` field
instead.

This will allow us to set the value of `has_overdue_invoices` to `false`
when the invoices have been paid.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-02-04 18:38:00 -05:00 committed by GitHub
parent aa3da35e8e
commit f366b97899
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 23 additions and 167 deletions

View file

@ -249,29 +249,31 @@ async fn create_billing_subscription(
));
}
if app.db.has_overdue_billing_subscriptions(user.id).await? {
return Err(Error::http(
StatusCode::PAYMENT_REQUIRED,
"user has overdue billing subscriptions".into(),
));
let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
if let Some(existing_billing_customer) = &existing_billing_customer {
if existing_billing_customer.has_overdue_invoices {
return Err(Error::http(
StatusCode::PAYMENT_REQUIRED,
"user has overdue invoices".into(),
));
}
}
let customer_id =
if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? {
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
} else {
let customer = Customer::create(
&stripe_client,
CreateCustomer {
email: user.email_address.as_deref(),
..Default::default()
},
)
.await?;
let customer_id = if let Some(existing_customer) = existing_billing_customer {
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
} else {
let customer = Customer::create(
&stripe_client,
CreateCustomer {
email: user.email_address.as_deref(),
..Default::default()
},
)
.await?;
customer.id
};
customer.id
};
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?;

View file

@ -170,40 +170,4 @@ impl Database {
})
.await
}
/// Returns whether the user has any overdue billing subscriptions.
pub async fn has_overdue_billing_subscriptions(&self, user_id: UserId) -> Result<bool> {
Ok(self.count_overdue_billing_subscriptions(user_id).await? > 0)
}
/// Returns the count of the overdue billing subscriptions for the user with the specified ID.
///
/// This includes subscriptions:
/// - Whose status is `past_due`
/// - Whose status is `canceled` and the cancellation reason is `payment_failed`
pub async fn count_overdue_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
self.transaction(|tx| async move {
let past_due = billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::PastDue);
let payment_failed = billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Canceled)
.and(
billing_subscription::Column::StripeCancellationReason
.eq(StripeCancellationReason::PaymentFailed),
);
let count = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(
billing_customer::Column::UserId
.eq(user_id)
.and(past_due.or(payment_failed)),
)
.count(&*tx)
.await?;
Ok(count as usize)
})
.await
}
}

View file

@ -1,6 +1,6 @@
use std::sync::Arc;
use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus};
use crate::db::billing_subscription::StripeSubscriptionStatus;
use crate::db::tests::new_test_user;
use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams};
use crate::test_both_dbs;
@ -88,113 +88,3 @@ async fn test_get_active_billing_subscriptions(db: &Arc<Database>) {
assert_eq!(subscription_count, 0);
}
}
test_both_dbs!(
test_count_overdue_billing_subscriptions,
test_count_overdue_billing_subscriptions_postgres,
test_count_overdue_billing_subscriptions_sqlite
);
async fn test_count_overdue_billing_subscriptions(db: &Arc<Database>) {
// A user with no subscription has no overdue billing subscriptions.
{
let user_id = new_test_user(db, "no-subscription-user@example.com").await;
let subscription_count = db
.count_overdue_billing_subscriptions(user_id)
.await
.unwrap();
assert_eq!(subscription_count, 0);
}
// A user with a past-due subscription has an overdue billing subscription.
{
let user_id = new_test_user(db, "past-due-user@example.com").await;
let customer = db
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_past_due_user".into(),
})
.await
.unwrap();
assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string());
db.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: customer.id,
stripe_subscription_id: "sub_past_due_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::PastDue,
stripe_cancellation_reason: None,
})
.await
.unwrap();
let subscription_count = db
.count_overdue_billing_subscriptions(user_id)
.await
.unwrap();
assert_eq!(subscription_count, 1);
}
// A user with a canceled subscription with a reason of `payment_failed` has an overdue billing subscription.
{
let user_id =
new_test_user(db, "canceled-subscription-payment-failed-user@example.com").await;
let customer = db
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_canceled_subscription_payment_failed_user".into(),
})
.await
.unwrap();
assert_eq!(
customer.stripe_customer_id,
"cus_canceled_subscription_payment_failed_user".to_string()
);
db.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: customer.id,
stripe_subscription_id: "sub_canceled_subscription_payment_failed_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::Canceled,
stripe_cancellation_reason: Some(StripeCancellationReason::PaymentFailed),
})
.await
.unwrap();
let subscription_count = db
.count_overdue_billing_subscriptions(user_id)
.await
.unwrap();
assert_eq!(subscription_count, 1);
}
// A user with a canceled subscription with a reason of `cancellation_requested` has no overdue billing subscriptions.
{
let user_id = new_test_user(db, "canceled-subscription-user@example.com").await;
let customer = db
.create_billing_customer(&CreateBillingCustomerParams {
user_id,
stripe_customer_id: "cus_canceled_subscription_user".into(),
})
.await
.unwrap();
assert_eq!(
customer.stripe_customer_id,
"cus_canceled_subscription_user".to_string()
);
db.create_billing_subscription(&CreateBillingSubscriptionParams {
billing_customer_id: customer.id,
stripe_subscription_id: "sub_canceled_subscription_user".into(),
stripe_subscription_status: StripeSubscriptionStatus::Canceled,
stripe_cancellation_reason: Some(StripeCancellationReason::CancellationRequested),
})
.await
.unwrap();
let subscription_count = db
.count_overdue_billing_subscriptions(user_id)
.await
.unwrap();
assert_eq!(subscription_count, 0);
}
}