diff --git a/Cargo.lock b/Cargo.lock index c9bd051aeb..a8baa9788b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -834,6 +834,26 @@ dependencies = [ "syn 2.0.59", ] +[[package]] +name = "async-stripe" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2f14b5943a52cf051bbbbb68538e93a69d1e291934174121e769f4b181113f5" +dependencies = [ + "futures-util", + "http-types", + "hyper", + "hyper-rustls", + "serde", + "serde_json", + "serde_path_to_error", + "serde_qs 0.10.1", + "smart-default", + "smol_str", + "thiserror", + "tokio", +] + [[package]] name = "async-tar" version = "0.4.2" @@ -1462,6 +1482,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce" +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.7" @@ -2425,6 +2451,7 @@ dependencies = [ "anthropic", "anyhow", "assistant", + "async-stripe", "async-trait", "async-tungstenite", "audio", @@ -5254,6 +5281,27 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" +[[package]] +name = "http-types" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad" +dependencies = [ + "anyhow", + "async-channel 1.9.0", + "base64 0.13.1", + "futures-lite 1.13.0", + "http 0.2.9", + "infer", + "pin-project-lite", + "rand 0.7.3", + "serde", + "serde_json", + "serde_qs 0.8.5", + "serde_urlencoded", + "url", +] + [[package]] name = "http_client" version = "0.1.0" @@ -5512,6 +5560,12 @@ version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +[[package]] +name = "infer" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" + [[package]] name = "inherent" version = "1.0.10" @@ -9564,6 +9618,28 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_qs" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6" +dependencies = [ + "percent-encoding", + "serde", + "thiserror", +] + +[[package]] +name = "serde_qs" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa" +dependencies = [ + "percent-encoding", + "serde", + "thiserror", +] + [[package]] name = "serde_repr" version = "0.1.16" @@ -9880,6 +9956,17 @@ dependencies = [ "serde", ] +[[package]] +name = "smart-default" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "smol" version = "1.3.0" @@ -9897,6 +9984,15 @@ dependencies = [ "futures-lite 1.13.0", ] +[[package]] +name = "smol_str" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9" +dependencies = [ + "serde", +] + [[package]] name = "snippet" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5c407ed0ef..a5be5bc027 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -309,6 +309,7 @@ async-dispatcher = "0.1" async-fs = "1.6" async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "82d00a04211cf4e1236029aa03e6b6ce2a74c553" } async-recursion = "1.0.0" +async-stripe = { version = "0.37", default-features = false, features = ["runtime-tokio-hyper-rustls", "billing", "checkout"] } async-tar = "0.4.2" async-trait = "0.1" async-tungstenite = "0.23" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index b321c4b95f..289212a6b1 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -20,6 +20,7 @@ test-support = ["sqlite"] [dependencies] anthropic.workspace = true anyhow.workspace = true +async-stripe.workspace = true async-tungstenite.workspace = true aws-config = { version = "1.1.5" } aws-sdk-s3 = { version = "1.15.0" } @@ -116,3 +117,6 @@ util.workspace = true workspace = { workspace = true, features = ["test-support"] } worktree = { workspace = true, features = ["test-support"] } headless.workspace = true + +[package.metadata.cargo-machete] +ignored = ["async-stripe"] diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index f5902114a1..35db23d58d 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,3 +1,4 @@ +pub mod billing; pub mod contributors; pub mod events; pub mod extensions; @@ -31,6 +32,7 @@ pub fn routes(rpc_server: Option>, state: Arc) -> Rou .route("/user", get(get_authenticated_user)) .route("/users/:id/access_tokens", post(create_access_token)) .route("/rpc_server_snapshot", get(get_rpc_server_snapshot)) + .merge(billing::router()) .merge(contributors::router()) .layer( ServiceBuilder::new() diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs new file mode 100644 index 0000000000..c454f4884b --- /dev/null +++ b/crates/collab/src/api/billing.rs @@ -0,0 +1,88 @@ +use std::str::FromStr; +use std::sync::Arc; + +use anyhow::anyhow; +use axum::{extract, routing::post, Extension, Json, Router}; +use collections::HashSet; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use stripe::{CheckoutSession, CreateCheckoutSession, CreateCheckoutSessionLineItems, CustomerId}; + +use crate::{AppState, Error, Result}; + +pub fn router() -> Router { + Router::new().route("/billing/subscriptions", post(create_billing_subscription)) +} + +#[derive(Debug, Deserialize)] +struct CreateBillingSubscriptionBody { + github_user_id: i32, +} + +#[derive(Debug, Serialize)] +struct CreateBillingSubscriptionResponse { + checkout_session_url: String, +} + +/// Initiates a Stripe Checkout session for creating a billing subscription. +async fn create_billing_subscription( + Extension(app): Extension>, + extract::Json(body): extract::Json, +) -> Result> { + let user = app + .db + .get_user_by_github_user_id(body.github_user_id) + .await? + .ok_or_else(|| anyhow!("user not found"))?; + + let Some((stripe_client, stripe_price_id)) = app + .stripe_client + .clone() + .zip(app.config.stripe_price_id.clone()) + else { + log::error!("failed to retrieve Stripe client or price ID"); + Err(Error::Http( + StatusCode::NOT_IMPLEMENTED, + "not supported".into(), + ))? + }; + + let existing_customer_id = { + let existing_subscriptions = app.db.get_billing_subscriptions(user.id).await?; + let distinct_customer_ids = existing_subscriptions + .iter() + .map(|subscription| subscription.stripe_customer_id.as_str()) + .collect::>(); + // Sanity: Make sure we can determine a single Stripe customer ID for the user. + if distinct_customer_ids.len() > 1 { + Err(anyhow!("user has multiple existing customer IDs"))?; + } + + distinct_customer_ids + .into_iter() + .next() + .map(|id| CustomerId::from_str(id).map_err(|err| anyhow!(err))) + .transpose() + }?; + + let checkout_session = { + let mut params = CreateCheckoutSession::new(); + params.mode = Some(stripe::CheckoutSessionMode::Subscription); + params.customer = existing_customer_id; + params.client_reference_id = Some(user.github_login.as_str()); + params.line_items = Some(vec![CreateCheckoutSessionLineItems { + price: Some(stripe_price_id.to_string()), + quantity: Some(1), + ..Default::default() + }]); + params.success_url = Some("https://zed.dev/billing/success"); + + CheckoutSession::create(&stripe_client, params).await? + }; + + Ok(Json(CreateBillingSubscriptionResponse { + checkout_session_url: checkout_session + .url + .ok_or_else(|| anyhow!("no checkout session URL"))?, + })) +} diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index ce782e5eb3..fcacf7ee22 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -32,6 +32,26 @@ impl Database { .await } + /// Returns all of the billing subscriptions for the user with the specified ID. + /// + /// Note that this returns the subscriptions regardless of their status. + /// If you're wanting to check if a use has an active billing subscription, + /// use `get_active_billing_subscriptions` instead. + pub async fn get_billing_subscriptions( + &self, + user_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + let subscriptions = billing_subscription::Entity::find() + .filter(billing_subscription::Column::UserId.eq(user_id)) + .all(&*tx) + .await?; + + Ok(subscriptions) + }) + .await + } + /// Returns all of the active billing subscriptions for the user with the specified ID. pub async fn get_active_billing_subscriptions( &self, diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs index ff5c4d5c17..fde6aa0b12 100644 --- a/crates/collab/src/db/queries/users.rs +++ b/crates/collab/src/db/queries/users.rs @@ -61,6 +61,17 @@ impl Database { .await } + /// Returns a user by GitHub user ID. There are no access checks here, so this should only be used internally. + pub async fn get_user_by_github_user_id(&self, github_user_id: i32) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .filter(user::Column::GithubUserId.eq(github_user_id)) + .one(&*tx) + .await?) + }) + .await + } + /// Returns a user by GitHub login. There are no access checks here, so this should only be used internally. pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { self.transaction(|tx| async move { diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 2673ca3fb8..89705428aa 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -26,6 +26,7 @@ pub enum Error { Http(StatusCode, String), Database(sea_orm::error::DbErr), Internal(anyhow::Error), + Stripe(stripe::StripeError), } impl From for Error { @@ -40,6 +41,12 @@ impl From for Error { } } +impl From for Error { + fn from(error: stripe::StripeError) -> Self { + Self::Stripe(error) + } +} + impl From for Error { fn from(error: axum::Error) -> Self { Self::Internal(error.into()) @@ -81,6 +88,14 @@ impl IntoResponse for Error { ); (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } + Error::Stripe(error) => { + log::error!( + "HTTP error {}: {:?}", + StatusCode::INTERNAL_SERVER_ERROR, + &error + ); + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() + } } } } @@ -91,6 +106,7 @@ impl std::fmt::Debug for Error { Error::Http(code, message) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), + Error::Stripe(error) => error.fmt(f), } } } @@ -101,6 +117,7 @@ impl std::fmt::Display for Error { Error::Http(code, message) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), + Error::Stripe(error) => error.fmt(f), } } } @@ -137,6 +154,8 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, + pub stripe_api_key: Option, + pub stripe_price_id: Option>, pub supermaven_admin_api_key: Option>, } @@ -150,6 +169,7 @@ pub struct AppState { pub db: Arc, pub live_kit_client: Option>, pub blob_store_client: Option, + pub stripe_client: Option>, pub rate_limiter: Arc, pub executor: Executor, pub clickhouse_client: Option, @@ -183,6 +203,10 @@ impl AppState { db: db.clone(), live_kit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), + stripe_client: build_stripe_client(&config) + .await + .map(|client| Arc::new(client)) + .log_err(), rate_limiter: Arc::new(RateLimiter::new(db)), executor, clickhouse_client: config @@ -195,6 +219,15 @@ impl AppState { } } +async fn build_stripe_client(config: &Config) -> anyhow::Result { + let api_key = config + .stripe_api_key + .as_ref() + .ok_or_else(|| anyhow!("missing stripe_api_key"))?; + + Ok(stripe::Client::new(api_key)) +} + async fn build_blob_store_client(config: &Config) -> anyhow::Result { let keys = aws_sdk_s3::config::Credentials::new( config diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index dea29b697f..7a3bc92a5f 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -637,6 +637,7 @@ impl TestServer { db: test_db.db().clone(), live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())), blob_store_client: None, + stripe_client: None, rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())), executor, clickhouse_client: None, @@ -669,6 +670,8 @@ impl TestServer { auto_join_channel_id: None, migrations_path: None, seed_path: None, + stripe_api_key: None, + stripe_price_id: None, supermaven_admin_api_key: None, }, })