collab: Create Zed Free subscription when issuing an LLM token (#30975)
This PR makes it so we create a Zed Free subscription when issuing an LLM token, if one does not already exist. Release Notes: - N/A --------- Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
83d513aef4
commit
f7a0834f54
5 changed files with 115 additions and 57 deletions
|
@ -17,9 +17,8 @@ use stripe::{
|
|||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
||||
EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
|
||||
SubscriptionStatus,
|
||||
CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
|
||||
Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::{ResultExt, maybe};
|
||||
|
||||
|
@ -310,13 +309,6 @@ async fn create_billing_subscription(
|
|||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
|
@ -351,35 +343,9 @@ async fn create_billing_subscription(
|
|||
CustomerId::from_str(&existing_customer.stripe_customer_id)
|
||||
.context("failed to parse customer ID")?
|
||||
} else {
|
||||
let existing_customer = if let Some(email) = user.email_address.as_deref() {
|
||||
let customers = Customer::list(
|
||||
&stripe_client,
|
||||
&stripe::ListCustomers {
|
||||
email: Some(email),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customers.data.first().cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(existing_customer) = existing_customer {
|
||||
existing_customer.id
|
||||
} else {
|
||||
let customer = Customer::create(
|
||||
&stripe_client,
|
||||
CreateCustomer {
|
||||
email: user.email_address.as_deref(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customer.id
|
||||
}
|
||||
stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?
|
||||
};
|
||||
|
||||
let success_url = format!(
|
||||
|
@ -1487,7 +1453,7 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
|
|||
}
|
||||
|
||||
/// Finds or creates a billing customer using the provided customer.
|
||||
async fn find_or_create_billing_customer(
|
||||
pub async fn find_or_create_billing_customer(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &stripe::Client,
|
||||
customer_or_id: Expandable<Customer>,
|
||||
|
|
|
@ -32,9 +32,9 @@ impl Database {
|
|||
pub async fn create_billing_subscription(
|
||||
&self,
|
||||
params: &CreateBillingSubscriptionParams,
|
||||
) -> Result<()> {
|
||||
) -> Result<billing_subscription::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||
billing_customer_id: ActiveValue::set(params.billing_customer_id),
|
||||
kind: ActiveValue::set(params.kind),
|
||||
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
|
||||
|
@ -44,10 +44,14 @@ impl Database {
|
|||
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_without_returning(&*tx)
|
||||
.await?;
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
.last_insert_id;
|
||||
|
||||
Ok(())
|
||||
Ok(billing_subscription::Entity::find_by_id(id)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
|
|
@ -42,7 +42,7 @@ impl LlmTokenClaims {
|
|||
is_staff: bool,
|
||||
billing_preferences: Option<billing_preference::Model>,
|
||||
feature_flags: &Vec<String>,
|
||||
subscription: Option<billing_subscription::Model>,
|
||||
subscription: billing_subscription::Model,
|
||||
system_id: Option<String>,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
|
@ -54,17 +54,14 @@ impl LlmTokenClaims {
|
|||
let plan = if is_staff {
|
||||
Plan::ZedPro
|
||||
} else {
|
||||
subscription
|
||||
.as_ref()
|
||||
.and_then(|subscription| subscription.kind)
|
||||
.map_or(Plan::ZedFree, |kind| match kind {
|
||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||
})
|
||||
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
|
||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||
})
|
||||
};
|
||||
let subscription_period =
|
||||
billing_subscription::Model::current_period(subscription, is_staff)
|
||||
billing_subscription::Model::current_period(Some(subscription), is_staff)
|
||||
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
|
||||
.ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
mod connection_pool;
|
||||
|
||||
use crate::api::billing::find_or_create_billing_customer;
|
||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::db::LlmDatabase;
|
||||
|
@ -4024,7 +4025,56 @@ async fn get_llm_api_token(
|
|||
Err(anyhow!("terms of service not accepted"))?
|
||||
}
|
||||
|
||||
let billing_subscription = db.get_active_billing_subscription(user.id).await?;
|
||||
let Some(stripe_client) = session.app_state.stripe_client.as_ref() else {
|
||||
Err(anyhow!("failed to retrieve Stripe client"))?
|
||||
};
|
||||
|
||||
let Some(stripe_billing) = session.app_state.stripe_billing.as_ref() else {
|
||||
Err(anyhow!("failed to retrieve Stripe billing object"))?
|
||||
};
|
||||
|
||||
let billing_customer =
|
||||
if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
|
||||
billing_customer
|
||||
} else {
|
||||
let customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?;
|
||||
|
||||
find_or_create_billing_customer(
|
||||
&session.app_state,
|
||||
&stripe_client,
|
||||
stripe::Expandable::Id(customer_id),
|
||||
)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("billing customer not found"))?
|
||||
};
|
||||
|
||||
let billing_subscription =
|
||||
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
||||
billing_subscription
|
||||
} else {
|
||||
let stripe_customer_id = billing_customer
|
||||
.stripe_customer_id
|
||||
.parse::<stripe::CustomerId>()
|
||||
.context("failed to parse Stripe customer ID from database")?;
|
||||
|
||||
let stripe_subscription = stripe_billing
|
||||
.subscribe_to_zed_free(stripe_customer_id)
|
||||
.await?;
|
||||
|
||||
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
|
||||
billing_customer_id: billing_customer.id,
|
||||
kind: Some(SubscriptionKind::ZedFree),
|
||||
stripe_subscription_id: stripe_subscription.id.to_string(),
|
||||
stripe_subscription_status: stripe_subscription.status.into(),
|
||||
stripe_cancellation_reason: None,
|
||||
stripe_current_period_start: Some(stripe_subscription.current_period_start),
|
||||
stripe_current_period_end: Some(stripe_subscription.current_period_end),
|
||||
})
|
||||
.await?
|
||||
};
|
||||
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
|
|
|
@ -7,7 +7,7 @@ use anyhow::{Context as _, anyhow};
|
|||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use stripe::{PriceId, SubscriptionStatus};
|
||||
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -122,6 +122,47 @@ impl StripeBilling {
|
|||
})
|
||||
}
|
||||
|
||||
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
|
||||
/// not already exist.
|
||||
///
|
||||
/// Always returns a new Stripe customer if the email address is `None`.
|
||||
pub async fn find_or_create_customer_by_email(
|
||||
&self,
|
||||
email_address: Option<&str>,
|
||||
) -> Result<CustomerId> {
|
||||
let existing_customer = if let Some(email) = email_address {
|
||||
let customers = Customer::list(
|
||||
&self.client,
|
||||
&stripe::ListCustomers {
|
||||
email: Some(email),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customers.data.first().cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let customer_id = if let Some(existing_customer) = existing_customer {
|
||||
existing_customer.id
|
||||
} else {
|
||||
let customer = Customer::create(
|
||||
&self.client,
|
||||
CreateCustomer {
|
||||
email: email_address,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customer.id
|
||||
};
|
||||
|
||||
Ok(customer_id)
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_price(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue