From e15d59c44597282f0dafb462639b2bfc8815cda5 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 29 Jul 2024 17:31:36 -0400 Subject: [PATCH] collab: Add endpoint for initiating a billing subscription (#15452) This PR adds a new `POST /billing/subscriptions` endpoint that can be used to initiate a billing subscription. The endpoint will use the provided `github_user_id` to look up a user, generate a Stripe Checkout session, and then return the URL. The caller would then redirect the user to the URL to initiate the checkout flow. Here's an example of how to call it: ```sh curl -X POST "http://localhost:8080/billing/subscriptions" \ -H "Authorization: " \ -H "Content-Type: application/json" \ -d '{"github_user_id": 12345}' ``` Release Notes: - N/A --- Cargo.lock | 96 +++++++++++++++++++ Cargo.toml | 1 + crates/collab/Cargo.toml | 4 + crates/collab/src/api.rs | 2 + crates/collab/src/api/billing.rs | 88 +++++++++++++++++ .../src/db/queries/billing_subscriptions.rs | 20 ++++ crates/collab/src/db/queries/users.rs | 11 +++ crates/collab/src/lib.rs | 33 +++++++ crates/collab/src/tests/test_server.rs | 3 + 9 files changed, 258 insertions(+) create mode 100644 crates/collab/src/api/billing.rs diff --git a/Cargo.lock b/Cargo.lock index c9bd051aeb..a8baa9788b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -834,6 +834,26 @@ dependencies = [ "syn 2.0.59", ] +[[package]] +name = "async-stripe" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2f14b5943a52cf051bbbbb68538e93a69d1e291934174121e769f4b181113f5" +dependencies = [ + "futures-util", + "http-types", + "hyper", + "hyper-rustls", + "serde", + "serde_json", + "serde_path_to_error", + "serde_qs 0.10.1", + "smart-default", + "smol_str", + "thiserror", + "tokio", +] + [[package]] name = "async-tar" version = "0.4.2" @@ -1462,6 +1482,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.7" @@ -2425,6 +2451,7 @@ dependencies = [ "anthropic", "anyhow", "assistant", + "async-stripe", "async-trait", "async-tungstenite", "audio", @@ -5254,6 +5281,27 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" +[[package]] +name = "http-types" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad" +dependencies = [ + "anyhow", + "async-channel 1.9.0", + "base64 0.13.1", + "futures-lite 1.13.0", + "http 0.2.9", + "infer", + "pin-project-lite", + "rand 0.7.3", + "serde", + "serde_json", + "serde_qs 0.8.5", + "serde_urlencoded", + "url", +] + [[package]] name = "http_client" version = "0.1.0" @@ -5512,6 +5560,12 @@ version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +[[package]] +name = "infer" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" + [[package]] name = "inherent" version = "1.0.10" @@ -9564,6 +9618,28 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_qs" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6" +dependencies = [ + "percent-encoding", + "serde", + "thiserror", +] + +[[package]] +name = "serde_qs" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa" +dependencies = [ + "percent-encoding", + "serde", + "thiserror", +] + [[package]] name = "serde_repr" version = "0.1.16" @@ -9880,6 +9956,17 @@ dependencies = [ "serde", ] +[[package]] +name = "smart-default" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "smol" version = "1.3.0" @@ -9897,6 +9984,15 @@ dependencies = [ "futures-lite 1.13.0", ] +[[package]] +name = "smol_str" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9" +dependencies = [ + "serde", +] + [[package]] name = "snippet" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5c407ed0ef..a5be5bc027 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -309,6 +309,7 @@ async-dispatcher = "0.1" async-fs = "1.6" async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" } async-recursion = "1.0.0" +async-stripe = { version = "0.37", default-features = false, features = ["runtime-tokio-hyper-rustls", "billing", "checkout"] } async-tar = "0.4.2" async-trait = "0.1" async-tungstenite = "0.23" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index b321c4b95f..289212a6b1 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -20,6 +20,7 @@ test-support = ["sqlite"] [dependencies] anthropic.workspace = true anyhow.workspace = true +async-stripe.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } aws-sdk-s3 = { version = "1.15.0" } @@ -116,3 +117,6 @@ util.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } headless.workspace = true + +[package.metadata.cargo-machete] +ignored = ["async-stripe"] diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index f5902114a1..35db23d58d 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,3 +1,4 @@ +pub mod billing; pub mod contributors; pub mod events; pub mod extensions; @@ -31,6 +32,7 @@ pub fn routes(rpc_server: Option>, state: Arc) -> Rou .route("/user", get(get_authenticated_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) + .merge(billing::router()) .merge(contributors::router()) .layer( ServiceBuilder::new() diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs new file mode 100644 index 0000000000..c454f4884b --- /dev/null +++ b/crates/collab/src/api/billing.rs @@ -0,0 +1,88 @@ +use std::str::FromStr; +use std::sync::Arc; + +use anyhow::anyhow; +use axum::{extract, routing::post, Extension, Json, Router}; +use collections::HashSet; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId}; + +use crate::{AppState, Error, Result}; + +pub fn router() -> Router { + Router::new().route("/billing/subscriptions", post(create_billing_subscription)) +} + +#[derive(Debug, Deserialize)] +struct CreateBillingSubscriptionBody { + github_user_id: i32, +} + +#[derive(Debug, Serialize)] +struct CreateBillingSubscriptionResponse { + checkout_session_url: String, +} + +/// Initiates a Stripe Checkout session for creating a billing subscription. +async fn create_billing_subscription( + Extension(app): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let user = app + .db + .get_user_by_github_user_id(body.github_user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let Some((stripe_client, stripe_price_id)) = app + .stripe_client + .clone() + .zip(app.config.stripe_price_id.clone()) + else { + log::error!("failed to retrieve Stripe client or price ID"); + Err(Error::Http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + + let existing_customer_id = { + let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?; + let distinct_customer_ids = existing_subscriptions + .iter() + .map(|subscription| subscription.stripe_customer_id.as_str()) + .collect::>(); + // Sanity: Make sure we can determine a single Stripe customer ID for the user. + if distinct_customer_ids.len() > 1 { + Err(anyhow!("user has multiple existing customer IDs"))?; + } + + distinct_customer_ids + .into_iter() + .next() + .map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err))) + .transpose() + }?; + + let checkout_session = { + let mut params = CreateCheckoutSession::new(); + params.mode = Some(stripe::CheckoutSessionMode::Subscription); + params.customer = existing_customer_id; + params.client_reference_id = Some(user.github_login.as_str()); + params.line_items = Some(vec![CreateCheckoutSessionLineItems { + price: Some(stripe_price_id.to_string()), + quantity: Some(1), + ..Default::default() + }]); + params.success_url = Some("https://zed.dev/billing/success"); + + CheckoutSession::create(&stripe_client, params).await? + }; + + Ok(Json(CreateBillingSubscriptionResponse { + checkout_session_url: checkout_session + .url + .ok_or_else(|| anyhow!("no checkout session URL"))?, + })) +} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index ce782e5eb3..fcacf7ee22 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -32,6 +32,26 @@ impl Database { .await } + /// Returns all of the billing subscriptions for the user with the specified ID. + /// + /// Note that this returns the subscriptions regardless of their status. + /// If you're wanting to check if a use has an active billing subscription, + /// use `get_active_billing_subscriptions` instead. + pub async fn get_billing_subscriptions( + &self, + user_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + let subscriptions = billing_subscription::Entity::find() + .filter(billing_subscription::Column::UserId.eq(user_id)) + .all(&*tx) + .await?; + + Ok(subscriptions) + }) + .await + } + /// Returns all of the active billing subscriptions for the user with the specified ID. pub async fn get_active_billing_subscriptions( &self, diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs index ff5c4d5c17..fde6aa0b12 100644 --- a/crates/collab/src/db/queries/users.rs +++ b/crates/collab/src/db/queries/users.rs @@ -61,6 +61,17 @@ impl Database { .await } + /// Returns a user by GitHub user ID. There are no access checks here, so this should only be used internally. + pub async fn get_user_by_github_user_id(&self, github_user_id: i32) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .filter(user::Column::GithubUserId.eq(github_user_id)) + .one(&*tx) + .await?) + }) + .await + } + /// Returns a user by GitHub login. There are no access checks here, so this should only be used internally. pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { self.transaction(|tx| async move { diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 2673ca3fb8..89705428aa 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -26,6 +26,7 @@ pub enum Error { Http(StatusCode, String), Database(sea_orm::error::DbErr), Internal(anyhow::Error), + Stripe(stripe::StripeError), } impl From for Error { @@ -40,6 +41,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: stripe::StripeError) -> Self { + Self::Stripe(error) + } +} + impl From for Error { fn from(error: axum::Error) -> Self { Self::Internal(error.into()) @@ -81,6 +88,14 @@ impl IntoResponse for Error { ); (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } + Error::Stripe(error) => { + log::error!( + "HTTP error {}: {:?}", + StatusCode::INTERNAL_SERVER_ERROR, + &error + ); + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() + } } } } @@ -91,6 +106,7 @@ impl std::fmt::Debug for Error { Error::Http(code, message) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), + Error::Stripe(error) => error.fmt(f), } } } @@ -101,6 +117,7 @@ impl std::fmt::Display for Error { Error::Http(code, message) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), + Error::Stripe(error) => error.fmt(f), } } } @@ -137,6 +154,8 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, + pub stripe_api_key: Option, + pub stripe_price_id: Option>, pub supermaven_admin_api_key: Option>, } @@ -150,6 +169,7 @@ pub struct AppState { pub db: Arc, pub live_kit_client: Option>, pub blob_store_client: Option, + pub stripe_client: Option>, pub rate_limiter: Arc, pub executor: Executor, pub clickhouse_client: Option, @@ -183,6 +203,10 @@ impl AppState { db: db.clone(), live_kit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), + stripe_client: build_stripe_client(&config) + .await + .map(|client| Arc::new(client)) + .log_err(), rate_limiter: Arc::new(RateLimiter::new(db)), executor, clickhouse_client: config @@ -195,6 +219,15 @@ impl AppState { } } +async fn build_stripe_client(config: &Config) -> anyhow::Result { + let api_key = config + .stripe_api_key + .as_ref() + .ok_or_else(|| anyhow!("missing stripe_api_key"))?; + + Ok(stripe::Client::new(api_key)) +} + async fn build_blob_store_client(config: &Config) -> anyhow::Result { let keys = aws_sdk_s3::config::Credentials::new( config diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index dea29b697f..7a3bc92a5f 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -637,6 +637,7 @@ impl TestServer { db: test_db.db().clone(), live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())), blob_store_client: None, + stripe_client: None, rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())), executor, clickhouse_client: None, @@ -669,6 +670,8 @@ impl TestServer { auto_join_channel_id: None, migrations_path: None, seed_path: None, + stripe_api_key: None, + stripe_price_id: None, supermaven_admin_api_key: None, }, })