collab: Add Zed Pro checkout flow (#28776)

This PR adds support for initiating a checkout flow for Zed Pro.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-15 11:45:51 -04:00 committed by GitHub
parent afabcd1547
commit 90dec1d451
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 70 additions and 14 deletions

View file

@ -198,9 +198,16 @@ async fn list_billing_subscriptions(
})) }))
} }
#[derive(Debug, Clone, Copy, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProductCode {
ZedPro,
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct CreateBillingSubscriptionBody { struct CreateBillingSubscriptionBody {
github_user_id: i32, github_user_id: i32,
product: Option<ProductCode>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -274,15 +281,30 @@ async fn create_billing_subscription(
customer.id customer.id
}; };
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?; let checkout_session_url = match body.product {
let stripe_model = stripe_billing.register_model(default_model).await?; Some(ProductCode::ZedPro) => {
let success_url = format!( let success_url = format!(
"{}/account?checkout_complete=1", "{}/account?checkout_complete=1",
app.config.zed_dot_dev_url() app.config.zed_dot_dev_url()
); );
let checkout_session_url = stripe_billing stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url) .checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
.await?; .await?
}
None => {
let default_model =
llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!(
"{}/account?checkout_complete=1",
app.config.zed_dot_dev_url()
);
stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?
}
};
Ok(Json(CreateBillingSubscriptionResponse { Ok(Json(CreateBillingSubscriptionResponse {
checkout_session_url, checkout_session_url,
})) }))

View file

@ -182,6 +182,7 @@ pub struct Config {
pub slack_panics_webhook: Option<String>, pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>, pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>, pub stripe_api_key: Option<String>,
pub stripe_zed_pro_price_id: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>, pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>, pub user_backfiller_github_access_token: Option<Arc<str>>,
} }
@ -237,6 +238,7 @@ impl Config {
migrations_path: None, migrations_path: None,
seed_path: None, seed_path: None,
stripe_api_key: None, stripe_api_key: None,
stripe_zed_pro_price_id: None,
supermaven_admin_api_key: None, supermaven_admin_api_key: None,
user_backfiller_github_access_token: None, user_backfiller_github_access_token: None,
kinesis_region: None, kinesis_region: None,
@ -322,9 +324,12 @@ impl AppState {
llm_db, llm_db,
livekit_client, livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(), blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_billing: stripe_client stripe_billing: stripe_client.clone().map(|stripe_client| {
.clone() Arc::new(StripeBilling::new(
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))), stripe_client,
config.stripe_zed_pro_price_id.clone(),
))
}),
stripe_client, stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)), rate_limiter: Arc::new(RateLimiter::new(db)),
executor, executor,

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{Cents, Result, llm}; use crate::{Cents, Result, llm};
use anyhow::Context as _; use anyhow::{Context as _, anyhow};
use chrono::{Datelike, Utc}; use chrono::{Datelike, Utc};
use collections::HashMap; use collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -10,6 +10,7 @@ use tokio::sync::RwLock;
pub struct StripeBilling { pub struct StripeBilling {
state: RwLock<StripeBillingState>, state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>, client: Arc<stripe::Client>,
zed_pro_price_id: Option<String>,
} }
#[derive(Default)] #[derive(Default)]
@ -31,10 +32,11 @@ struct StripeBillingPrice {
} }
impl StripeBilling { impl StripeBilling {
pub fn new(client: Arc<stripe::Client>) -> Self { pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
Self { Self {
client, client,
state: RwLock::default(), state: RwLock::default(),
zed_pro_price_id,
} }
} }
@ -382,6 +384,32 @@ impl StripeBilling {
let session = stripe::CheckoutSession::create(&self.client, params).await?; let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?) Ok(session.url.context("no checkout session URL")?)
} }
pub async fn checkout_with_zed_pro(
&self,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self
.zed_pro_price_id
.as_ref()
.ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.clone()),
quantity: Some(1),
..Default::default()
}]);
params.success_url = Some(success_url);
let session = stripe::CheckoutSession::create(&self.client, params).await?;
Ok(session.url.context("no checkout session URL")?)
}
} }
#[derive(Serialize)] #[derive(Serialize)]

View file

@ -557,6 +557,7 @@ impl TestServer {
migrations_path: None, migrations_path: None,
seed_path: None, seed_path: None,
stripe_api_key: None, stripe_api_key: None,
stripe_zed_pro_price_id: None,
supermaven_admin_api_key: None, supermaven_admin_api_key: None,
user_backfiller_github_access_token: None, user_backfiller_github_access_token: None,
kinesis_region: None, kinesis_region: None,