diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index dd0d5f79d9..a6e37b1bd5 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -17,9 +17,8 @@ use stripe::{ CreateBillingPortalSessionFlowDataAfterCompletionRedirect, CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm, CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems, - CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject, - EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, - SubscriptionStatus, + CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType, + Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus, }; use util::{ResultExt, maybe}; @@ -310,13 +309,6 @@ async fn create_billing_subscription( .await? .ok_or_else(|| anyhow!("user not found"))?; - let Some(stripe_client) = app.stripe_client.clone() else { - log::error!("failed to retrieve Stripe client"); - Err(Error::http( - StatusCode::NOT_IMPLEMENTED, - "not supported".into(), - ))? - }; let Some(stripe_billing) = app.stripe_billing.clone() else { log::error!("failed to retrieve Stripe billing object"); Err(Error::http( @@ -351,35 +343,9 @@ async fn create_billing_subscription( CustomerId::from_str(&existing_customer.stripe_customer_id) .context("failed to parse customer ID")? } else { - let existing_customer = if let Some(email) = user.email_address.as_deref() { - let customers = Customer::list( - &stripe_client, - &stripe::ListCustomers { - email: Some(email), - ..Default::default() - }, - ) - .await?; - - customers.data.first().cloned() - } else { - None - }; - - if let Some(existing_customer) = existing_customer { - existing_customer.id - } else { - let customer = Customer::create( - &stripe_client, - CreateCustomer { - email: user.email_address.as_deref(), - ..Default::default() - }, - ) - .await?; - - customer.id - } + stripe_billing + .find_or_create_customer_by_email(user.email_address.as_deref()) + .await? }; let success_url = format!( @@ -1487,7 +1453,7 @@ impl From for StripeCancellationReason { } /// Finds or creates a billing customer using the provided customer. -async fn find_or_create_billing_customer( +pub async fn find_or_create_billing_customer( app: &Arc, stripe_client: &stripe::Client, customer_or_id: Expandable, diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index a79bc7bc7b..87076ba299 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -32,9 +32,9 @@ impl Database { pub async fn create_billing_subscription( &self, params: &CreateBillingSubscriptionParams, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { - billing_subscription::Entity::insert(billing_subscription::ActiveModel { + let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel { billing_customer_id: ActiveValue::set(params.billing_customer_id), kind: ActiveValue::set(params.kind), stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), @@ -44,10 +44,14 @@ impl Database { stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end), ..Default::default() }) - .exec_without_returning(&*tx) - .await?; + .exec(&*tx) + .await? + .last_insert_id; - Ok(()) + Ok(billing_subscription::Entity::find_by_id(id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?) }) .await } diff --git a/crates/collab/src/llm/token.rs b/crates/collab/src/llm/token.rs index 7c0b3c5efd..c35b954503 100644 --- a/crates/collab/src/llm/token.rs +++ b/crates/collab/src/llm/token.rs @@ -42,7 +42,7 @@ impl LlmTokenClaims { is_staff: bool, billing_preferences: Option, feature_flags: &Vec, - subscription: Option, + subscription: billing_subscription::Model, system_id: Option, config: &Config, ) -> Result { @@ -54,17 +54,14 @@ impl LlmTokenClaims { let plan = if is_staff { Plan::ZedPro } else { - subscription - .as_ref() - .and_then(|subscription| subscription.kind) - .map_or(Plan::ZedFree, |kind| match kind { - SubscriptionKind::ZedFree => Plan::ZedFree, - SubscriptionKind::ZedPro => Plan::ZedPro, - SubscriptionKind::ZedProTrial => Plan::ZedProTrial, - }) + subscription.kind.map_or(Plan::ZedFree, |kind| match kind { + SubscriptionKind::ZedFree => Plan::ZedFree, + SubscriptionKind::ZedPro => Plan::ZedPro, + SubscriptionKind::ZedProTrial => Plan::ZedProTrial, + }) }; let subscription_period = - billing_subscription::Model::current_period(subscription, is_staff) + billing_subscription::Model::current_period(Some(subscription), is_staff) .map(|(start, end)| (start.naive_utc(), end.naive_utc())) .ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 06b011f0f9..1e9b7141f9 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,5 +1,6 @@ mod connection_pool; +use crate::api::billing::find_or_create_billing_customer; use crate::api::{CloudflareIpCountryHeader, SystemIdHeader}; use crate::db::billing_subscription::SubscriptionKind; use crate::llm::db::LlmDatabase; @@ -4024,7 +4025,56 @@ async fn get_llm_api_token( Err(anyhow!("terms of service not accepted"))? } - let billing_subscription = db.get_active_billing_subscription(user.id).await?; + let Some(stripe_client) = session.app_state.stripe_client.as_ref() else { + Err(anyhow!("failed to retrieve Stripe client"))? + }; + + let Some(stripe_billing) = session.app_state.stripe_billing.as_ref() else { + Err(anyhow!("failed to retrieve Stripe billing object"))? + }; + + let billing_customer = + if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? { + billing_customer + } else { + let customer_id = stripe_billing + .find_or_create_customer_by_email(user.email_address.as_deref()) + .await?; + + find_or_create_billing_customer( + &session.app_state, + &stripe_client, + stripe::Expandable::Id(customer_id), + ) + .await? + .ok_or_else(|| anyhow!("billing customer not found"))? + }; + + let billing_subscription = + if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? { + billing_subscription + } else { + let stripe_customer_id = billing_customer + .stripe_customer_id + .parse::() + .context("failed to parse Stripe customer ID from database")?; + + let stripe_subscription = stripe_billing + .subscribe_to_zed_free(stripe_customer_id) + .await?; + + db.create_billing_subscription(&db::CreateBillingSubscriptionParams { + billing_customer_id: billing_customer.id, + kind: Some(SubscriptionKind::ZedFree), + stripe_subscription_id: stripe_subscription.id.to_string(), + stripe_subscription_status: stripe_subscription.status.into(), + stripe_cancellation_reason: None, + stripe_current_period_start: Some(stripe_subscription.current_period_start), + stripe_current_period_end: Some(stripe_subscription.current_period_end), + }) + .await? + }; + let billing_preferences = db.get_billing_preferences(user.id).await?; let token = LlmTokenClaims::create( diff --git a/crates/collab/src/stripe_billing.rs b/crates/collab/src/stripe_billing.rs index 78680faf57..a538adf401 100644 --- a/crates/collab/src/stripe_billing.rs +++ b/crates/collab/src/stripe_billing.rs @@ -7,7 +7,7 @@ use anyhow::{Context as _, anyhow}; use chrono::Utc; use collections::HashMap; use serde::{Deserialize, Serialize}; -use stripe::{PriceId, SubscriptionStatus}; +use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus}; use tokio::sync::RwLock; use uuid::Uuid; @@ -122,6 +122,47 @@ impl StripeBilling { }) } + /// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does + /// not already exist. + /// + /// Always returns a new Stripe customer if the email address is `None`. + pub async fn find_or_create_customer_by_email( + &self, + email_address: Option<&str>, + ) -> Result { + let existing_customer = if let Some(email) = email_address { + let customers = Customer::list( + &self.client, + &stripe::ListCustomers { + email: Some(email), + ..Default::default() + }, + ) + .await?; + + customers.data.first().cloned() + } else { + None + }; + + let customer_id = if let Some(existing_customer) = existing_customer { + existing_customer.id + } else { + let customer = Customer::create( + &self.client, + CreateCustomer { + email: email_address, + ..Default::default() + }, + ) + .await?; + + customer.id + }; + + Ok(customer_id) + } + pub async fn subscribe_to_price( &self, subscription_id: &stripe::SubscriptionId,