collab: Add support for subscribing to Zed Pro trials (#28812)

This PR adds support for subscribing to Zed Pro trials (and then
upgrading from a trial to Zed Pro).

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-15 16:49:16 -04:00 committed by GitHub
parent 5619a3e618
commit dad6067e18
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 143 additions and 42 deletions

View file

@ -15,10 +15,12 @@ use stripe::{
BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession, BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion, CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect, CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject, CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus, EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
}; };
use util::ResultExt; use util::{ResultExt, maybe};
use crate::api::events::SnowflakeRow; use crate::api::events::SnowflakeRow;
use crate::db::billing_subscription::{ use crate::db::billing_subscription::{
@ -159,6 +161,7 @@ struct BillingSubscriptionJson {
id: BillingSubscriptionId, id: BillingSubscriptionId,
name: String, name: String,
status: StripeSubscriptionStatus, status: StripeSubscriptionStatus,
trial_end_at: Option<String>,
cancel_at: Option<String>, cancel_at: Option<String>,
/// Whether this subscription can be canceled. /// Whether this subscription can be canceled.
is_cancelable: bool, is_cancelable: bool,
@ -188,9 +191,21 @@ async fn list_billing_subscriptions(
id: subscription.id, id: subscription.id,
name: match subscription.kind { name: match subscription.kind {
Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(), Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
None => "Zed LLM Usage".to_string(), None => "Zed LLM Usage".to_string(),
}, },
status: subscription.stripe_subscription_status, status: subscription.stripe_subscription_status,
trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
maybe!({
let end_at = subscription.stripe_current_period_end?;
let end_at = DateTime::from_timestamp(end_at, 0)?;
Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
})
} else {
None
},
cancel_at: subscription.stripe_cancel_at.map(|cancel_at| { cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
cancel_at cancel_at
.and_utc() .and_utc()
@ -207,6 +222,7 @@ async fn list_billing_subscriptions(
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
enum ProductCode { enum ProductCode {
ZedPro, ZedPro,
ZedProTrial,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -286,24 +302,36 @@ async fn create_billing_subscription(
customer.id customer.id
}; };
let success_url = format!(
"{}/account?checkout_complete=1",
app.config.zed_dot_dev_url()
);
let checkout_session_url = match body.product { let checkout_session_url = match body.product {
Some(ProductCode::ZedPro) => { Some(ProductCode::ZedPro) => {
let success_url = format!(
"{}/account?checkout_complete=1",
app.config.zed_dot_dev_url()
);
stripe_billing stripe_billing
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url) .checkout_with_price(
app.config.zed_pro_price_id()?,
customer_id,
&user.github_login,
&success_url,
)
.await?
}
Some(ProductCode::ZedProTrial) => {
stripe_billing
.checkout_with_price(
app.config.zed_pro_trial_price_id()?,
customer_id,
&user.github_login,
&success_url,
)
.await? .await?
} }
None => { None => {
let default_model = let default_model =
llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?; llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?; let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!(
"{}/account?checkout_complete=1",
app.config.zed_dot_dev_url()
);
stripe_billing stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url) .checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await? .await?
@ -322,6 +350,8 @@ enum ManageSubscriptionIntent {
/// ///
/// This will open the Stripe billing portal without putting the user in a specific flow. /// This will open the Stripe billing portal without putting the user in a specific flow.
ManageSubscription, ManageSubscription,
/// The user intends to upgrade to Zed Pro.
UpgradeToPro,
/// The user intends to cancel their subscription. /// The user intends to cancel their subscription.
Cancel, Cancel,
/// The user intends to stop the cancellation of their subscription. /// The user intends to stop the cancellation of their subscription.
@ -373,11 +403,10 @@ async fn manage_billing_subscription(
.get_billing_subscription_by_id(body.subscription_id) .get_billing_subscription_by_id(body.subscription_id)
.await? .await?
.ok_or_else(|| anyhow!("subscription not found"))?; .ok_or_else(|| anyhow!("subscription not found"))?;
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;
if body.intent == ManageSubscriptionIntent::StopCancellation { if body.intent == ManageSubscriptionIntent::StopCancellation {
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;
let updated_stripe_subscription = Subscription::update( let updated_stripe_subscription = Subscription::update(
&stripe_client, &stripe_client,
&subscription_id, &subscription_id,
@ -410,6 +439,47 @@ async fn manage_billing_subscription(
let flow = match body.intent { let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None, ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
let zed_pro_price_id = app.config.zed_pro_price_id()?;
let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id()?;
let zed_free_price_id = app.config.zed_free_price_id()?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
let subscription_item_to_update = stripe_subscription
.items
.data
.iter()
.find_map(|item| {
let price = item.price.as_ref()?;
if price.id == zed_free_price_id || price.id == zed_pro_trial_price_id {
Some(item.id.clone())
} else {
None
}
})
.ok_or_else(|| anyhow!("No subscription item to update"))?;
Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
subscription_update_confirm: Some(
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
subscription: subscription.stripe_subscription_id,
items: vec![
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
id: subscription_item_to_update.to_string(),
price: Some(zed_pro_price_id.to_string()),
quantity: Some(1),
},
],
discounts: None,
},
),
..Default::default()
})
}
ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData { ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel, type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion { after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
@ -696,22 +766,25 @@ async fn handle_customer_subscription_event(
log::info!("handling Stripe {} event: {}", event.type_, event.id); log::info!("handling Stripe {} event: {}", event.type_, event.id);
let subscription_kind = let subscription_kind = maybe!({
if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() { let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
let has_zed_pro_price = subscription.items.data.iter().any(|item| { let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id().ok()?;
item.price let zed_free_price_id = app.config.zed_free_price_id().ok()?;
.as_ref()
.map_or(false, |price| price.id.as_str() == zed_pro_price_id)
});
if has_zed_pro_price { subscription.items.data.iter().find_map(|item| {
let price = item.price.as_ref()?;
if price.id == zed_pro_price_id {
Some(SubscriptionKind::ZedPro) Some(SubscriptionKind::ZedPro)
} else if price.id == zed_pro_trial_price_id {
Some(SubscriptionKind::ZedProTrial)
} else if price.id == zed_free_price_id {
Some(SubscriptionKind::ZedFree)
} else { } else {
None None
} }
} else { })
None });
};
let billing_customer = let billing_customer =
find_or_create_billing_customer(app, stripe_client, subscription.customer) find_or_create_billing_customer(app, stripe_client, subscription.customer)

View file

@ -62,11 +62,14 @@ impl Database {
billing_subscription::Entity::update(billing_subscription::ActiveModel { billing_subscription::Entity::update(billing_subscription::ActiveModel {
id: ActiveValue::set(id), id: ActiveValue::set(id),
billing_customer_id: params.billing_customer_id.clone(), billing_customer_id: params.billing_customer_id.clone(),
kind: params.kind.clone(),
stripe_subscription_id: params.stripe_subscription_id.clone(), stripe_subscription_id: params.stripe_subscription_id.clone(),
stripe_subscription_status: params.stripe_subscription_status.clone(), stripe_subscription_status: params.stripe_subscription_status.clone(),
stripe_cancel_at: params.stripe_cancel_at.clone(), stripe_cancel_at: params.stripe_cancel_at.clone(),
stripe_cancellation_reason: params.stripe_cancellation_reason.clone(), stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
..Default::default() stripe_current_period_start: params.stripe_current_period_start.clone(),
stripe_current_period_end: params.stripe_current_period_end.clone(),
created_at: ActiveValue::not_set(),
}) })
.exec(&*tx) .exec(&*tx)
.await?; .await?;

View file

@ -43,6 +43,10 @@ impl ActiveModelBehavior for ActiveModel {}
pub enum SubscriptionKind { pub enum SubscriptionKind {
#[sea_orm(string_value = "zed_pro")] #[sea_orm(string_value = "zed_pro")]
ZedPro, ZedPro,
#[sea_orm(string_value = "zed_pro_trial")]
ZedProTrial,
#[sea_orm(string_value = "zed_free")]
ZedFree,
} }
/// The status of a Stripe subscription. /// The status of a Stripe subscription.

View file

@ -183,6 +183,8 @@ pub struct Config {
pub auto_join_channel_id: Option<ChannelId>, pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>, pub stripe_api_key: Option<String>,
pub stripe_zed_pro_price_id: Option<String>, pub stripe_zed_pro_price_id: Option<String>,
pub stripe_zed_pro_trial_price_id: Option<String>,
pub stripe_zed_free_price_id: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>, pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>, pub user_backfiller_github_access_token: Option<Arc<str>>,
} }
@ -201,6 +203,29 @@ impl Config {
} }
} }
pub fn zed_pro_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref())
}
pub fn zed_pro_trial_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id(
"Zed Pro Trial",
self.stripe_zed_pro_trial_price_id.as_deref(),
)
}
pub fn zed_free_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref())
}
fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result<stripe::PriceId> {
use std::str::FromStr as _;
let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?;
Ok(stripe::PriceId::from_str(price_id)?)
}
#[cfg(test)] #[cfg(test)]
pub fn test() -> Self { pub fn test() -> Self {
Self { Self {
@ -239,6 +264,8 @@ impl Config {
seed_path: None, seed_path: None,
stripe_api_key: None, stripe_api_key: None,
stripe_zed_pro_price_id: None, stripe_zed_pro_price_id: None,
stripe_zed_pro_trial_price_id: None,
stripe_zed_free_price_id: None,
supermaven_admin_api_key: None, supermaven_admin_api_key: None,
user_backfiller_github_access_token: None, user_backfiller_github_access_token: None,
kinesis_region: None, kinesis_region: None,
@ -324,12 +351,9 @@ impl AppState {
llm_db, llm_db,
livekit_client, livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(), blob_store_client: build_blob_store_client(&config).await.log_err(),
stripe_billing: stripe_client.clone().map(|stripe_client| { stripe_billing: stripe_client
Arc::new(StripeBilling::new( .clone()
stripe_client, .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
config.stripe_zed_pro_price_id.clone(),
))
}),
stripe_client, stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)), rate_limiter: Arc::new(RateLimiter::new(db)),
executor, executor,

View file

@ -1,16 +1,16 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{Cents, Result, llm}; use crate::{Cents, Result, llm};
use anyhow::{Context as _, anyhow}; use anyhow::Context as _;
use chrono::{Datelike, Utc}; use chrono::{Datelike, Utc};
use collections::HashMap; use collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use stripe::PriceId;
use tokio::sync::RwLock; use tokio::sync::RwLock;
pub struct StripeBilling { pub struct StripeBilling {
state: RwLock<StripeBillingState>, state: RwLock<StripeBillingState>,
client: Arc<stripe::Client>, client: Arc<stripe::Client>,
zed_pro_price_id: Option<String>,
} }
#[derive(Default)] #[derive(Default)]
@ -32,11 +32,10 @@ struct StripeBillingPrice {
} }
impl StripeBilling { impl StripeBilling {
pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self { pub fn new(client: Arc<stripe::Client>) -> Self {
Self { Self {
client, client,
state: RwLock::default(), state: RwLock::default(),
zed_pro_price_id,
} }
} }
@ -385,23 +384,19 @@ impl StripeBilling {
Ok(session.url.context("no checkout session URL")?) Ok(session.url.context("no checkout session URL")?)
} }
pub async fn checkout_with_zed_pro( pub async fn checkout_with_price(
&self, &self,
price_id: PriceId,
customer_id: stripe::CustomerId, customer_id: stripe::CustomerId,
github_login: &str, github_login: &str,
success_url: &str, success_url: &str,
) -> Result<String> { ) -> Result<String> {
let zed_pro_price_id = self
.zed_pro_price_id
.as_ref()
.ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
let mut params = stripe::CreateCheckoutSession::new(); let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription); params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id); params.customer = Some(customer_id);
params.client_reference_id = Some(github_login); params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems { params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.clone()), price: Some(price_id.to_string()),
quantity: Some(1), quantity: Some(1),
..Default::default() ..Default::default()
}]); }]);

View file

@ -558,6 +558,8 @@ impl TestServer {
seed_path: None, seed_path: None,
stripe_api_key: None, stripe_api_key: None,
stripe_zed_pro_price_id: None, stripe_zed_pro_price_id: None,
stripe_zed_pro_trial_price_id: None,
stripe_zed_free_price_id: None,
supermaven_admin_api_key: None, supermaven_admin_api_key: None,
user_backfiller_github_access_token: None, user_backfiller_github_access_token: None,
kinesis_region: None, kinesis_region: None,