collab: Add endpoint for initiating a billing subscription (#15452)

This PR adds a new `POST /billing/subscriptions` endpoint that can be
used to initiate a billing subscription.

The endpoint will use the provided `github_user_id` to look up a user,
generate a Stripe Checkout session, and then return the URL.

The caller would then redirect the user to the URL to initiate the
checkout flow.

Here's an example of how to call it:

```sh
curl -X POST "http://localhost:8080/billing/subscriptions" \
     -H "Authorization: <ADMIN_TOKEN>" \
     -H "Content-Type: application/json" \
     -d '{"github_user_id": 12345}'
```

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-07-29 17:31:36 -04:00 committed by GitHub
parent 8bb34fd93e
commit e15d59c445
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 258 additions and 0 deletions

96
Cargo.lock generated
View file

@ -834,6 +834,26 @@ dependencies = [
"syn 2.0.59",
]
[[package]]
name = "async-stripe"
version = "0.37.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2f14b5943a52cf051bbbbb68538e93a69d1e291934174121e769f4b181113f5"
dependencies = [
"futures-util",
"http-types",
"hyper",
"hyper-rustls",
"serde",
"serde_json",
"serde_path_to_error",
"serde_qs 0.10.1",
"smart-default",
"smol_str",
"thiserror",
"tokio",
]
[[package]]
name = "async-tar"
version = "0.4.2"
@ -1462,6 +1482,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce"
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.21.7"
@ -2425,6 +2451,7 @@ dependencies = [
"anthropic",
"anyhow",
"assistant",
"async-stripe",
"async-trait",
"async-tungstenite",
"audio",
@ -5254,6 +5281,27 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f"
[[package]]
name = "http-types"
version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad"
dependencies = [
"anyhow",
"async-channel 1.9.0",
"base64 0.13.1",
"futures-lite 1.13.0",
"http 0.2.9",
"infer",
"pin-project-lite",
"rand 0.7.3",
"serde",
"serde_json",
"serde_qs 0.8.5",
"serde_urlencoded",
"url",
]
[[package]]
name = "http_client"
version = "0.1.0"
@ -5512,6 +5560,12 @@ version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "infer"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac"
[[package]]
name = "inherent"
version = "1.0.10"
@ -9564,6 +9618,28 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_qs"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6"
dependencies = [
"percent-encoding",
"serde",
"thiserror",
]
[[package]]
name = "serde_qs"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa"
dependencies = [
"percent-encoding",
"serde",
"thiserror",
]
[[package]]
name = "serde_repr"
version = "0.1.16"
@ -9880,6 +9956,17 @@ dependencies = [
"serde",
]
[[package]]
name = "smart-default"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "smol"
version = "1.3.0"
@ -9897,6 +9984,15 @@ dependencies = [
"futures-lite 1.13.0",
]
[[package]]
name = "smol_str"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9"
dependencies = [
"serde",
]
[[package]]
name = "snippet"
version = "0.1.0"

View file

@ -309,6 +309,7 @@ async-dispatcher = "0.1"
async-fs = "1.6"
async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" }
async-recursion = "1.0.0"
async-stripe = { version = "0.37", default-features = false, features = ["runtime-tokio-hyper-rustls", "billing", "checkout"] }
async-tar = "0.4.2"
async-trait = "0.1"
async-tungstenite = "0.23"

View file

@ -20,6 +20,7 @@ test-support = ["sqlite"]
[dependencies]
anthropic.workspace = true
anyhow.workspace = true
async-stripe.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
aws-sdk-s3 = { version = "1.15.0" }
@ -116,3 +117,6 @@ util.workspace = true
workspace = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
headless.workspace = true
[package.metadata.cargo-machete]
ignored = ["async-stripe"]

View file

@ -1,3 +1,4 @@
pub mod billing;
pub mod contributors;
pub mod events;
pub mod extensions;
@ -31,6 +32,7 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
.route("/user", get(get_authenticated_user))
.route("/users/:id/access_tokens", post(create_access_token))
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
.merge(billing::router())
.merge(contributors::router())
.layer(
ServiceBuilder::new()

View file

@ -0,0 +1,88 @@
use std::str::FromStr;
use std::sync::Arc;
use anyhow::anyhow;
use axum::{extract, routing::post, Extension, Json, Router};
use collections::HashSet;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId};
use crate::{AppState, Error, Result};
pub fn router() -> Router {
Router::new().route("/billing/subscriptions", post(create_billing_subscription))
}
#[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody {
github_user_id: i32,
}
#[derive(Debug, Serialize)]
struct CreateBillingSubscriptionResponse {
checkout_session_url: String,
}
/// Initiates a Stripe Checkout session for creating a billing subscription.
async fn create_billing_subscription(
Extension(app): Extension<Arc<AppState>>,
extract::Json(body): extract::Json<CreateBillingSubscriptionBody>,
) -> Result<Json<CreateBillingSubscriptionResponse>> {
let user = app
.db
.get_user_by_github_user_id(body.github_user_id)
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let Some((stripe_client, stripe_price_id)) = app
.stripe_client
.clone()
.zip(app.config.stripe_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
Err(Error::Http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let existing_customer_id = {
let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?;
let distinct_customer_ids = existing_subscriptions
.iter()
.map(|subscription| subscription.stripe_customer_id.as_str())
.collect::<HashSet<_>>();
// Sanity: Make sure we can determine a single Stripe customer ID for the user.
if distinct_customer_ids.len() > 1 {
Err(anyhow!("user has multiple existing customer IDs"))?;
}
distinct_customer_ids
.into_iter()
.next()
.map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err)))
.transpose()
}?;
let checkout_session = {
let mut params = CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = existing_customer_id;
params.client_reference_id = Some(user.github_login.as_str());
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
price: Some(stripe_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some("https://zed.dev/billing/success");
CheckoutSession::create(&stripe_client, params).await?
};
Ok(Json(CreateBillingSubscriptionResponse {
checkout_session_url: checkout_session
.url
.ok_or_else(|| anyhow!("no checkout session URL"))?,
}))
}

View file

@ -32,6 +32,26 @@ impl Database {
.await
}
/// Returns all of the billing subscriptions for the user with the specified ID.
///
/// Note that this returns the subscriptions regardless of their status.
/// If you're wanting to check if a use has an active billing subscription,
/// use `get_active_billing_subscriptions` instead.
pub async fn get_billing_subscriptions(
&self,
user_id: UserId,
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.filter(billing_subscription::Column::UserId.eq(user_id))
.all(&*tx)
.await?;
Ok(subscriptions)
})
.await
}
/// Returns all of the active billing subscriptions for the user with the specified ID.
pub async fn get_active_billing_subscriptions(
&self,

View file

@ -61,6 +61,17 @@ impl Database {
.await
}
/// Returns a user by GitHub user ID. There are no access checks here, so this should only be used internally.
pub async fn get_user_by_github_user_id(&self, github_user_id: i32) -> Result<Option<User>> {
self.transaction(|tx| async move {
Ok(user::Entity::find()
.filter(user::Column::GithubUserId.eq(github_user_id))
.one(&*tx)
.await?)
})
.await
}
/// Returns a user by GitHub login. There are no access checks here, so this should only be used internally.
pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
self.transaction(|tx| async move {

View file

@ -26,6 +26,7 @@ pub enum Error {
Http(StatusCode, String),
Database(sea_orm::error::DbErr),
Internal(anyhow::Error),
Stripe(stripe::StripeError),
}
impl From<anyhow::Error> for Error {
@ -40,6 +41,12 @@ impl From<sea_orm::error::DbErr> for Error {
}
}
impl From<stripe::StripeError> for Error {
fn from(error: stripe::StripeError) -> Self {
Self::Stripe(error)
}
}
impl From<axum::Error> for Error {
fn from(error: axum::Error) -> Self {
Self::Internal(error.into())
@ -81,6 +88,14 @@ impl IntoResponse for Error {
);
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
}
Error::Stripe(error) => {
log::error!(
"HTTP error {}: {:?}",
StatusCode::INTERNAL_SERVER_ERROR,
&error
);
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
}
}
}
}
@ -91,6 +106,7 @@ impl std::fmt::Debug for Error {
Error::Http(code, message) => (code, message).fmt(f),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
}
}
}
@ -101,6 +117,7 @@ impl std::fmt::Display for Error {
Error::Http(code, message) => write!(f, "{code}: {message}"),
Error::Database(error) => error.fmt(f),
Error::Internal(error) => error.fmt(f),
Error::Stripe(error) => error.fmt(f),
}
}
}
@ -137,6 +154,8 @@ pub struct Config {
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub stripe_price_id: Option<Arc<str>>,
pub supermaven_admin_api_key: Option<Arc<str>>,
}
@ -150,6 +169,7 @@ pub struct AppState {
pub db: Arc<Database>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>,
pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
pub clickhouse_client: Option<clickhouse::Client>,
@ -183,6 +203,10 @@ impl AppState {
db: db.clone(),
live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_client: build_stripe_client(&config)
.await
.map(|client| Arc::new(client))
.log_err(),
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
clickhouse_client: config
@ -195,6 +219,15 @@ impl AppState {
}
}
async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
let api_key = config
.stripe_api_key
.as_ref()
.ok_or_else(|| anyhow!("missing stripe_api_key"))?;
Ok(stripe::Client::new(api_key))
}
async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
let keys = aws_sdk_s3::config::Credentials::new(
config

View file

@ -637,6 +637,7 @@ impl TestServer {
db: test_db.db().clone(),
live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
blob_store_client: None,
stripe_client: None,
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
executor,
clickhouse_client: None,
@ -669,6 +670,8 @@ impl TestServer {
auto_join_channel_id: None,
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_price_id: None,
supermaven_admin_api_key: None,
},
})