collab: Create Zed Free subscription when issuing an LLM token (#30975)
This PR makes it so we create a Zed Free subscription when issuing an LLM token, if one does not already exist. Release Notes: - N/A --------- Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
This commit is contained in:
parent
83d513aef4
commit
f7a0834f54
5 changed files with 115 additions and 57 deletions
|
@ -17,9 +17,8 @@ use stripe::{
|
||||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
||||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
||||||
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
|
||||||
EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
|
Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
||||||
SubscriptionStatus,
|
|
||||||
};
|
};
|
||||||
use util::{ResultExt, maybe};
|
use util::{ResultExt, maybe};
|
||||||
|
|
||||||
|
@ -310,13 +309,6 @@ async fn create_billing_subscription(
|
||||||
.await?
|
.await?
|
||||||
.ok_or_else(|| anyhow!("user not found"))?;
|
.ok_or_else(|| anyhow!("user not found"))?;
|
||||||
|
|
||||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
|
||||||
log::error!("failed to retrieve Stripe client");
|
|
||||||
Err(Error::http(
|
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
|
||||||
"not supported".into(),
|
|
||||||
))?
|
|
||||||
};
|
|
||||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||||
log::error!("failed to retrieve Stripe billing object");
|
log::error!("failed to retrieve Stripe billing object");
|
||||||
Err(Error::http(
|
Err(Error::http(
|
||||||
|
@ -351,35 +343,9 @@ async fn create_billing_subscription(
|
||||||
CustomerId::from_str(&existing_customer.stripe_customer_id)
|
CustomerId::from_str(&existing_customer.stripe_customer_id)
|
||||||
.context("failed to parse customer ID")?
|
.context("failed to parse customer ID")?
|
||||||
} else {
|
} else {
|
||||||
let existing_customer = if let Some(email) = user.email_address.as_deref() {
|
stripe_billing
|
||||||
let customers = Customer::list(
|
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||||
&stripe_client,
|
.await?
|
||||||
&stripe::ListCustomers {
|
|
||||||
email: Some(email),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
customers.data.first().cloned()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(existing_customer) = existing_customer {
|
|
||||||
existing_customer.id
|
|
||||||
} else {
|
|
||||||
let customer = Customer::create(
|
|
||||||
&stripe_client,
|
|
||||||
CreateCustomer {
|
|
||||||
email: user.email_address.as_deref(),
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
customer.id
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let success_url = format!(
|
let success_url = format!(
|
||||||
|
@ -1487,7 +1453,7 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Finds or creates a billing customer using the provided customer.
|
/// Finds or creates a billing customer using the provided customer.
|
||||||
async fn find_or_create_billing_customer(
|
pub async fn find_or_create_billing_customer(
|
||||||
app: &Arc<AppState>,
|
app: &Arc<AppState>,
|
||||||
stripe_client: &stripe::Client,
|
stripe_client: &stripe::Client,
|
||||||
customer_or_id: Expandable<Customer>,
|
customer_or_id: Expandable<Customer>,
|
||||||
|
|
|
@ -32,9 +32,9 @@ impl Database {
|
||||||
pub async fn create_billing_subscription(
|
pub async fn create_billing_subscription(
|
||||||
&self,
|
&self,
|
||||||
params: &CreateBillingSubscriptionParams,
|
params: &CreateBillingSubscriptionParams,
|
||||||
) -> Result<()> {
|
) -> Result<billing_subscription::Model> {
|
||||||
self.transaction(|tx| async move {
|
self.transaction(|tx| async move {
|
||||||
billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||||
billing_customer_id: ActiveValue::set(params.billing_customer_id),
|
billing_customer_id: ActiveValue::set(params.billing_customer_id),
|
||||||
kind: ActiveValue::set(params.kind),
|
kind: ActiveValue::set(params.kind),
|
||||||
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
|
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
|
||||||
|
@ -44,10 +44,14 @@ impl Database {
|
||||||
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
|
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
})
|
})
|
||||||
.exec_without_returning(&*tx)
|
.exec(&*tx)
|
||||||
.await?;
|
.await?
|
||||||
|
.last_insert_id;
|
||||||
|
|
||||||
Ok(())
|
Ok(billing_subscription::Entity::find_by_id(id)
|
||||||
|
.one(&*tx)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?)
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,7 @@ impl LlmTokenClaims {
|
||||||
is_staff: bool,
|
is_staff: bool,
|
||||||
billing_preferences: Option<billing_preference::Model>,
|
billing_preferences: Option<billing_preference::Model>,
|
||||||
feature_flags: &Vec<String>,
|
feature_flags: &Vec<String>,
|
||||||
subscription: Option<billing_subscription::Model>,
|
subscription: billing_subscription::Model,
|
||||||
system_id: Option<String>,
|
system_id: Option<String>,
|
||||||
config: &Config,
|
config: &Config,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
|
@ -54,17 +54,14 @@ impl LlmTokenClaims {
|
||||||
let plan = if is_staff {
|
let plan = if is_staff {
|
||||||
Plan::ZedPro
|
Plan::ZedPro
|
||||||
} else {
|
} else {
|
||||||
subscription
|
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
|
||||||
.as_ref()
|
|
||||||
.and_then(|subscription| subscription.kind)
|
|
||||||
.map_or(Plan::ZedFree, |kind| match kind {
|
|
||||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
SubscriptionKind::ZedFree => Plan::ZedFree,
|
||||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
let subscription_period =
|
let subscription_period =
|
||||||
billing_subscription::Model::current_period(subscription, is_staff)
|
billing_subscription::Model::current_period(Some(subscription), is_staff)
|
||||||
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
|
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
|
||||||
.ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
|
.ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
mod connection_pool;
|
mod connection_pool;
|
||||||
|
|
||||||
|
use crate::api::billing::find_or_create_billing_customer;
|
||||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||||
use crate::db::billing_subscription::SubscriptionKind;
|
use crate::db::billing_subscription::SubscriptionKind;
|
||||||
use crate::llm::db::LlmDatabase;
|
use crate::llm::db::LlmDatabase;
|
||||||
|
@ -4024,7 +4025,56 @@ async fn get_llm_api_token(
|
||||||
Err(anyhow!("terms of service not accepted"))?
|
Err(anyhow!("terms of service not accepted"))?
|
||||||
}
|
}
|
||||||
|
|
||||||
let billing_subscription = db.get_active_billing_subscription(user.id).await?;
|
let Some(stripe_client) = session.app_state.stripe_client.as_ref() else {
|
||||||
|
Err(anyhow!("failed to retrieve Stripe client"))?
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(stripe_billing) = session.app_state.stripe_billing.as_ref() else {
|
||||||
|
Err(anyhow!("failed to retrieve Stripe billing object"))?
|
||||||
|
};
|
||||||
|
|
||||||
|
let billing_customer =
|
||||||
|
if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
|
||||||
|
billing_customer
|
||||||
|
} else {
|
||||||
|
let customer_id = stripe_billing
|
||||||
|
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
find_or_create_billing_customer(
|
||||||
|
&session.app_state,
|
||||||
|
&stripe_client,
|
||||||
|
stripe::Expandable::Id(customer_id),
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| anyhow!("billing customer not found"))?
|
||||||
|
};
|
||||||
|
|
||||||
|
let billing_subscription =
|
||||||
|
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
||||||
|
billing_subscription
|
||||||
|
} else {
|
||||||
|
let stripe_customer_id = billing_customer
|
||||||
|
.stripe_customer_id
|
||||||
|
.parse::<stripe::CustomerId>()
|
||||||
|
.context("failed to parse Stripe customer ID from database")?;
|
||||||
|
|
||||||
|
let stripe_subscription = stripe_billing
|
||||||
|
.subscribe_to_zed_free(stripe_customer_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
|
||||||
|
billing_customer_id: billing_customer.id,
|
||||||
|
kind: Some(SubscriptionKind::ZedFree),
|
||||||
|
stripe_subscription_id: stripe_subscription.id.to_string(),
|
||||||
|
stripe_subscription_status: stripe_subscription.status.into(),
|
||||||
|
stripe_cancellation_reason: None,
|
||||||
|
stripe_current_period_start: Some(stripe_subscription.current_period_start),
|
||||||
|
stripe_current_period_end: Some(stripe_subscription.current_period_end),
|
||||||
|
})
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||||
|
|
||||||
let token = LlmTokenClaims::create(
|
let token = LlmTokenClaims::create(
|
||||||
|
|
|
@ -7,7 +7,7 @@ use anyhow::{Context as _, anyhow};
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use collections::HashMap;
|
use collections::HashMap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use stripe::{PriceId, SubscriptionStatus};
|
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
@ -122,6 +122,47 @@ impl StripeBilling {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
|
||||||
|
/// not already exist.
|
||||||
|
///
|
||||||
|
/// Always returns a new Stripe customer if the email address is `None`.
|
||||||
|
pub async fn find_or_create_customer_by_email(
|
||||||
|
&self,
|
||||||
|
email_address: Option<&str>,
|
||||||
|
) -> Result<CustomerId> {
|
||||||
|
let existing_customer = if let Some(email) = email_address {
|
||||||
|
let customers = Customer::list(
|
||||||
|
&self.client,
|
||||||
|
&stripe::ListCustomers {
|
||||||
|
email: Some(email),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
customers.data.first().cloned()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let customer_id = if let Some(existing_customer) = existing_customer {
|
||||||
|
existing_customer.id
|
||||||
|
} else {
|
||||||
|
let customer = Customer::create(
|
||||||
|
&self.client,
|
||||||
|
CreateCustomer {
|
||||||
|
email: email_address,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
customer.id
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(customer_id)
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn subscribe_to_price(
|
pub async fn subscribe_to_price(
|
||||||
&self,
|
&self,
|
||||||
subscription_id: &stripe::SubscriptionId,
|
subscription_id: &stripe::SubscriptionId,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue