collab: Create Zed Free subscription when issuing an LLM token (#30975)

This PR makes it so we create a Zed Free subscription when issuing an
LLM token, if one does not already exist.

Release Notes:

- N/A

---------

Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
Marshall Bowers 2025-05-19 18:45:22 -04:00 committed by GitHub
parent 83d513aef4
commit f7a0834f54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 115 additions and 57 deletions

View file

@ -17,9 +17,8 @@ use stripe::{
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
SubscriptionStatus,
CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::{ResultExt, maybe};
@ -310,13 +309,6 @@ async fn create_billing_subscription(
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::error!("failed to retrieve Stripe billing object");
Err(Error::http(
@ -351,35 +343,9 @@ async fn create_billing_subscription(
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
} else {
let existing_customer = if let Some(email) = user.email_address.as_deref() {
let customers = Customer::list(
&stripe_client,
&stripe::ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
customers.data.first().cloned()
} else {
None
};
if let Some(existing_customer) = existing_customer {
existing_customer.id
} else {
let customer = Customer::create(
&stripe_client,
CreateCustomer {
email: user.email_address.as_deref(),
..Default::default()
},
)
.await?;
customer.id
}
stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?
};
let success_url = format!(
@ -1487,7 +1453,7 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
}
/// Finds or creates a billing customer using the provided customer.
async fn find_or_create_billing_customer(
pub async fn find_or_create_billing_customer(
app: &Arc<AppState>,
stripe_client: &stripe::Client,
customer_or_id: Expandable<Customer>,

View file

@ -32,9 +32,9 @@ impl Database {
pub async fn create_billing_subscription(
&self,
params: &CreateBillingSubscriptionParams,
) -> Result<()> {
) -> Result<billing_subscription::Model> {
self.transaction(|tx| async move {
billing_subscription::Entity::insert(billing_subscription::ActiveModel {
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
kind: ActiveValue::set(params.kind),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
@ -44,10 +44,14 @@ impl Database {
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
..Default::default()
})
.exec_without_returning(&*tx)
.await?;
.exec(&*tx)
.await?
.last_insert_id;
Ok(())
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?
.ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?)
})
.await
}

View file

@ -42,7 +42,7 @@ impl LlmTokenClaims {
is_staff: bool,
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
subscription: Option<billing_subscription::Model>,
subscription: billing_subscription::Model,
system_id: Option<String>,
config: &Config,
) -> Result<String> {
@ -54,17 +54,14 @@ impl LlmTokenClaims {
let plan = if is_staff {
Plan::ZedPro
} else {
subscription
.as_ref()
.and_then(|subscription| subscription.kind)
.map_or(Plan::ZedFree, |kind| match kind {
SubscriptionKind::ZedFree => Plan::ZedFree,
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
})
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
SubscriptionKind::ZedFree => Plan::ZedFree,
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
})
};
let subscription_period =
billing_subscription::Model::current_period(subscription, is_staff)
billing_subscription::Model::current_period(Some(subscription), is_staff)
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
.ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;

View file

@ -1,5 +1,6 @@
mod connection_pool;
use crate::api::billing::find_or_create_billing_customer;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
@ -4024,7 +4025,56 @@ async fn get_llm_api_token(
Err(anyhow!("terms of service not accepted"))?
}
let billing_subscription = db.get_active_billing_subscription(user.id).await?;
let Some(stripe_client) = session.app_state.stripe_client.as_ref() else {
Err(anyhow!("failed to retrieve Stripe client"))?
};
let Some(stripe_billing) = session.app_state.stripe_billing.as_ref() else {
Err(anyhow!("failed to retrieve Stripe billing object"))?
};
let billing_customer =
if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
billing_customer
} else {
let customer_id = stripe_billing
.find_or_create_customer_by_email(user.email_address.as_deref())
.await?;
find_or_create_billing_customer(
&session.app_state,
&stripe_client,
stripe::Expandable::Id(customer_id),
)
.await?
.ok_or_else(|| anyhow!("billing customer not found"))?
};
let billing_subscription =
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
billing_subscription
} else {
let stripe_customer_id = billing_customer
.stripe_customer_id
.parse::<stripe::CustomerId>()
.context("failed to parse Stripe customer ID from database")?;
let stripe_subscription = stripe_billing
.subscribe_to_zed_free(stripe_customer_id)
.await?;
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
billing_customer_id: billing_customer.id,
kind: Some(SubscriptionKind::ZedFree),
stripe_subscription_id: stripe_subscription.id.to_string(),
stripe_subscription_status: stripe_subscription.status.into(),
stripe_cancellation_reason: None,
stripe_current_period_start: Some(stripe_subscription.current_period_start),
stripe_current_period_end: Some(stripe_subscription.current_period_end),
})
.await?
};
let billing_preferences = db.get_billing_preferences(user.id).await?;
let token = LlmTokenClaims::create(

View file

@ -7,7 +7,7 @@ use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
use serde::{Deserialize, Serialize};
use stripe::{PriceId, SubscriptionStatus};
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
use tokio::sync::RwLock;
use uuid::Uuid;
@ -122,6 +122,47 @@ impl StripeBilling {
})
}
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
/// not already exist.
///
/// Always returns a new Stripe customer if the email address is `None`.
pub async fn find_or_create_customer_by_email(
&self,
email_address: Option<&str>,
) -> Result<CustomerId> {
let existing_customer = if let Some(email) = email_address {
let customers = Customer::list(
&self.client,
&stripe::ListCustomers {
email: Some(email),
..Default::default()
},
)
.await?;
customers.data.first().cloned()
} else {
None
};
let customer_id = if let Some(existing_customer) = existing_customer {
existing_customer.id
} else {
let customer = Customer::create(
&self.client,
CreateCustomer {
email: email_address,
..Default::default()
},
)
.await?;
customer.id
};
Ok(customer_id)
}
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,