collab: Remove unused billing-related database code (#36282)

This PR removes a bunch of unused database code related to billing, as
we no longer need it.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-08-15 18:58:10 -04:00 committed by GitHub
parent bf34e185d5
commit e664a9bc48
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 1 additions and 1419 deletions

1
Cargo.lock generated
View file

@ -3270,7 +3270,6 @@ dependencies = [
"chrono",
"client",
"clock",
"cloud_llm_client",
"collab_ui",
"collections",
"command_palette_hooks",

View file

@ -29,7 +29,6 @@ axum-extra = { version = "0.4", features = ["erased-json"] }
base64.workspace = true
chrono.workspace = true
clock.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
dashmap.workspace = true
envy = "0.4.2"

View file

@ -41,12 +41,7 @@ use worktree_settings_file::LocalSettingsKind;
pub use tests::TestDb;
pub use ids::*;
pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
pub use queries::billing_subscriptions::{
CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
};
pub use queries::contributors::ContributorSelector;
pub use queries::processed_stripe_events::CreateProcessedStripeEventParams;
pub use sea_orm::ConnectOptions;
pub use tables::user::Model as User;
pub use tables::*;

View file

@ -70,9 +70,6 @@ macro_rules! id_type {
}
id_type!(AccessTokenId);
id_type!(BillingCustomerId);
id_type!(BillingSubscriptionId);
id_type!(BillingPreferencesId);
id_type!(BufferId);
id_type!(ChannelBufferCollaboratorId);
id_type!(ChannelChatParticipantId);

View file

@ -1,9 +1,6 @@
use super::*;
pub mod access_tokens;
pub mod billing_customers;
pub mod billing_preferences;
pub mod billing_subscriptions;
pub mod buffers;
pub mod channels;
pub mod contacts;
@ -12,7 +9,6 @@ pub mod embeddings;
pub mod extensions;
pub mod messages;
pub mod notifications;
pub mod processed_stripe_events;
pub mod projects;
pub mod rooms;
pub mod servers;

View file

@ -1,100 +0,0 @@
use super::*;
#[derive(Debug)]
pub struct CreateBillingCustomerParams {
pub user_id: UserId,
pub stripe_customer_id: String,
}
#[derive(Debug, Default)]
pub struct UpdateBillingCustomerParams {
pub user_id: ActiveValue<UserId>,
pub stripe_customer_id: ActiveValue<String>,
pub has_overdue_invoices: ActiveValue<bool>,
pub trial_started_at: ActiveValue<Option<DateTime>>,
}
impl Database {
/// Creates a new billing customer.
pub async fn create_billing_customer(
&self,
params: &CreateBillingCustomerParams,
) -> Result<billing_customer::Model> {
self.transaction(|tx| async move {
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
user_id: ActiveValue::set(params.user_id),
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
..Default::default()
})
.exec_with_returning(&*tx)
.await?;
Ok(customer)
})
.await
}
/// Updates the specified billing customer.
pub async fn update_billing_customer(
&self,
id: BillingCustomerId,
params: &UpdateBillingCustomerParams,
) -> Result<()> {
self.transaction(|tx| async move {
billing_customer::Entity::update(billing_customer::ActiveModel {
id: ActiveValue::set(id),
user_id: params.user_id.clone(),
stripe_customer_id: params.stripe_customer_id.clone(),
has_overdue_invoices: params.has_overdue_invoices.clone(),
trial_started_at: params.trial_started_at.clone(),
created_at: ActiveValue::not_set(),
})
.exec(&*tx)
.await?;
Ok(())
})
.await
}
pub async fn get_billing_customer_by_id(
&self,
id: BillingCustomerId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::Id.eq(id))
.one(&*tx)
.await?)
})
.await
}
/// Returns the billing customer for the user with the specified ID.
pub async fn get_billing_customer_by_user_id(
&self,
user_id: UserId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::UserId.eq(user_id))
.one(&*tx)
.await?)
})
.await
}
/// Returns the billing customer for the user with the specified Stripe customer ID.
pub async fn get_billing_customer_by_stripe_customer_id(
&self,
stripe_customer_id: &str,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
.one(&*tx)
.await?)
})
.await
}
}

View file

@ -1,17 +0,0 @@
use super::*;
impl Database {
/// Returns the billing preferences for the given user, if they exist.
pub async fn get_billing_preferences(
&self,
user_id: UserId,
) -> Result<Option<billing_preference::Model>> {
self.transaction(|tx| async move {
Ok(billing_preference::Entity::find()
.filter(billing_preference::Column::UserId.eq(user_id))
.one(&*tx)
.await?)
})
.await
}
}

View file

@ -1,158 +0,0 @@
use anyhow::Context as _;
use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
use super::*;
#[derive(Debug)]
pub struct CreateBillingSubscriptionParams {
pub billing_customer_id: BillingCustomerId,
pub kind: Option<SubscriptionKind>,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
pub stripe_current_period_start: Option<i64>,
pub stripe_current_period_end: Option<i64>,
}
#[derive(Debug, Default)]
pub struct UpdateBillingSubscriptionParams {
pub billing_customer_id: ActiveValue<BillingCustomerId>,
pub kind: ActiveValue<Option<SubscriptionKind>>,
pub stripe_subscription_id: ActiveValue<String>,
pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
pub stripe_current_period_start: ActiveValue<Option<i64>>,
pub stripe_current_period_end: ActiveValue<Option<i64>>,
}
impl Database {
/// Creates a new billing subscription.
pub async fn create_billing_subscription(
&self,
params: &CreateBillingSubscriptionParams,
) -> Result<billing_subscription::Model> {
self.transaction(|tx| async move {
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
kind: ActiveValue::set(params.kind),
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
..Default::default()
})
.exec(&*tx)
.await?
.last_insert_id;
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?
.context("failed to retrieve inserted billing subscription")?)
})
.await
}
/// Updates the specified billing subscription.
pub async fn update_billing_subscription(
&self,
id: BillingSubscriptionId,
params: &UpdateBillingSubscriptionParams,
) -> Result<()> {
self.transaction(|tx| async move {
billing_subscription::Entity::update(billing_subscription::ActiveModel {
id: ActiveValue::set(id),
billing_customer_id: params.billing_customer_id.clone(),
kind: params.kind.clone(),
stripe_subscription_id: params.stripe_subscription_id.clone(),
stripe_subscription_status: params.stripe_subscription_status.clone(),
stripe_cancel_at: params.stripe_cancel_at.clone(),
stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
stripe_current_period_start: params.stripe_current_period_start.clone(),
stripe_current_period_end: params.stripe_current_period_end.clone(),
created_at: ActiveValue::not_set(),
})
.exec(&*tx)
.await?;
Ok(())
})
.await
}
/// Returns the billing subscription with the specified Stripe subscription ID.
pub async fn get_billing_subscription_by_stripe_subscription_id(
&self,
stripe_subscription_id: &str,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.filter(
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
)
.one(&*tx)
.await?)
})
.await
}
pub async fn get_active_billing_subscription(
&self,
user_id: UserId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
.filter(
Condition::all()
.add(
Condition::any()
.add(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.add(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Trialing),
),
)
.add(billing_subscription::Column::Kind.is_not_null()),
)
.one(&*tx)
.await?)
})
.await
}
/// Returns whether the user has an active billing subscription.
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
}
/// Returns the count of the active billing subscriptions for the user with the specified ID.
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
self.transaction(|tx| async move {
let count = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(
billing_customer::Column::UserId.eq(user_id).and(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active)
.or(billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Trialing)),
),
)
.count(&*tx)
.await?;
Ok(count as usize)
})
.await
}
}

View file

@ -1,69 +0,0 @@
use super::*;
#[derive(Debug)]
pub struct CreateProcessedStripeEventParams {
pub stripe_event_id: String,
pub stripe_event_type: String,
pub stripe_event_created_timestamp: i64,
}
impl Database {
/// Creates a new processed Stripe event.
pub async fn create_processed_stripe_event(
&self,
params: &CreateProcessedStripeEventParams,
) -> Result<()> {
self.transaction(|tx| async move {
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
stripe_event_created_timestamp: ActiveValue::set(
params.stripe_event_created_timestamp,
),
..Default::default()
})
.exec_without_returning(&*tx)
.await?;
Ok(())
})
.await
}
/// Returns the processed Stripe event with the specified event ID.
pub async fn get_processed_stripe_event_by_event_id(
&self,
event_id: &str,
) -> Result<Option<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find_by_id(event_id)
.one(&*tx)
.await?)
})
.await
}
/// Returns the processed Stripe events with the specified event IDs.
pub async fn get_processed_stripe_events_by_event_ids(
&self,
event_ids: &[&str],
) -> Result<Vec<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find()
.filter(
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),
)
.all(&*tx)
.await?)
})
.await
}
/// Returns whether the Stripe event with the specified ID has already been processed.
pub async fn already_processed_stripe_event(&self, event_id: &str) -> Result<bool> {
Ok(self
.get_processed_stripe_event_by_event_id(event_id)
.await?
.is_some())
}
}

View file

@ -1,7 +1,4 @@
pub mod access_token;
pub mod billing_customer;
pub mod billing_preference;
pub mod billing_subscription;
pub mod buffer;
pub mod buffer_operation;
pub mod buffer_snapshot;
@ -23,7 +20,6 @@ pub mod notification;
pub mod notification_kind;
pub mod observed_buffer_edits;
pub mod observed_channel_messages;
pub mod processed_stripe_event;
pub mod project;
pub mod project_collaborator;
pub mod project_repository;

View file

@ -1,41 +0,0 @@
use crate::db::{BillingCustomerId, UserId};
use sea_orm::entity::prelude::*;
/// A billing customer.
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_customers")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingCustomerId,
pub user_id: UserId,
pub stripe_customer_id: String,
pub has_overdue_invoices: bool,
pub trial_started_at: Option<DateTime>,
pub created_at: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
#[sea_orm(has_many = "super::billing_subscription::Entity")]
BillingSubscription,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl Related<super::billing_subscription::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingSubscription.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,32 +0,0 @@
use crate::db::{BillingPreferencesId, UserId};
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_preferences")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingPreferencesId,
pub created_at: DateTime,
pub user_id: UserId,
pub max_monthly_llm_usage_spending_in_cents: i32,
pub model_request_overages_enabled: bool,
pub model_request_overages_spend_limit_in_cents: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,161 +0,0 @@
use crate::db::{BillingCustomerId, BillingSubscriptionId};
use chrono::{Datelike as _, NaiveDate, Utc};
use sea_orm::entity::prelude::*;
use serde::Serialize;
/// A billing subscription.
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "billing_subscriptions")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: BillingSubscriptionId,
pub billing_customer_id: BillingCustomerId,
pub kind: Option<SubscriptionKind>,
pub stripe_subscription_id: String,
pub stripe_subscription_status: StripeSubscriptionStatus,
pub stripe_cancel_at: Option<DateTime>,
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
pub stripe_current_period_start: Option<i64>,
pub stripe_current_period_end: Option<i64>,
pub created_at: DateTime,
}
impl Model {
pub fn current_period_start_at(&self) -> Option<DateTimeUtc> {
let period_start = self.stripe_current_period_start?;
chrono::DateTime::from_timestamp(period_start, 0)
}
pub fn current_period_end_at(&self) -> Option<DateTimeUtc> {
let period_end = self.stripe_current_period_end?;
chrono::DateTime::from_timestamp(period_end, 0)
}
pub fn current_period(
subscription: Option<Self>,
is_staff: bool,
) -> Option<(DateTimeUtc, DateTimeUtc)> {
if is_staff {
let now = Utc::now();
let year = now.year();
let month = now.month();
let first_day_of_this_month =
NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?;
let next_month = if month == 12 { 1 } else { month + 1 };
let next_month_year = if month == 12 { year + 1 } else { year };
let first_day_of_next_month =
NaiveDate::from_ymd_opt(next_month_year, next_month, 1)?.and_hms_opt(23, 59, 59)?;
let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1);
Some((
first_day_of_this_month.and_utc(),
last_day_of_this_month.and_utc(),
))
} else {
let subscription = subscription?;
let period_start_at = subscription.current_period_start_at()?;
let period_end_at = subscription.current_period_end_at()?;
Some((period_start_at, period_end_at))
}
}
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::billing_customer::Entity",
from = "Column::BillingCustomerId",
to = "super::billing_customer::Column::Id"
)]
BillingCustomer,
}
impl Related<super::billing_customer::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingCustomer.def()
}
}
impl ActiveModelBehavior for ActiveModel {}
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionKind {
#[sea_orm(string_value = "zed_pro")]
ZedPro,
#[sea_orm(string_value = "zed_pro_trial")]
ZedProTrial,
#[sea_orm(string_value = "zed_free")]
ZedFree,
}
impl From<SubscriptionKind> for cloud_llm_client::Plan {
fn from(value: SubscriptionKind) -> Self {
match value {
SubscriptionKind::ZedPro => Self::ZedPro,
SubscriptionKind::ZedProTrial => Self::ZedProTrial,
SubscriptionKind::ZedFree => Self::ZedFree,
}
}
}
/// The status of a Stripe subscription.
///
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-status)
#[derive(
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum StripeSubscriptionStatus {
#[default]
#[sea_orm(string_value = "incomplete")]
Incomplete,
#[sea_orm(string_value = "incomplete_expired")]
IncompleteExpired,
#[sea_orm(string_value = "trialing")]
Trialing,
#[sea_orm(string_value = "active")]
Active,
#[sea_orm(string_value = "past_due")]
PastDue,
#[sea_orm(string_value = "canceled")]
Canceled,
#[sea_orm(string_value = "unpaid")]
Unpaid,
#[sea_orm(string_value = "paused")]
Paused,
}
impl StripeSubscriptionStatus {
pub fn is_cancelable(&self) -> bool {
match self {
Self::Trialing | Self::Active | Self::PastDue => true,
Self::Incomplete
| Self::IncompleteExpired
| Self::Canceled
| Self::Unpaid
| Self::Paused => false,
}
}
}
/// The cancellation reason for a Stripe subscription.
///
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-cancellation_details-reason)
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum StripeCancellationReason {
#[sea_orm(string_value = "cancellation_requested")]
CancellationRequested,
#[sea_orm(string_value = "payment_disputed")]
PaymentDisputed,
#[sea_orm(string_value = "payment_failed")]
PaymentFailed,
}

View file

@ -1,16 +0,0 @@
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "processed_stripe_events")]
pub struct Model {
#[sea_orm(primary_key)]
pub stripe_event_id: String,
pub stripe_event_type: String,
pub stripe_event_created_timestamp: i64,
pub processed_at: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -29,8 +29,6 @@ pub struct Model {
pub enum Relation {
#[sea_orm(has_many = "super::access_token::Entity")]
AccessToken,
#[sea_orm(has_one = "super::billing_customer::Entity")]
BillingCustomer,
#[sea_orm(has_one = "super::room_participant::Entity")]
RoomParticipant,
#[sea_orm(has_many = "super::project::Entity")]
@ -68,12 +66,6 @@ impl Related<super::access_token::Entity> for Entity {
}
}
impl Related<super::billing_customer::Entity> for Entity {
fn to() -> RelationDef {
Relation::BillingCustomer.def()
}
}
impl Related<super::room_participant::Entity> for Entity {
fn to() -> RelationDef {
Relation::RoomParticipant.def()

View file

@ -8,7 +8,6 @@ mod embedding_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;
mod user_tests;
use crate::migrations::run_database_migrations;

View file

@ -1,38 +0,0 @@
use std::sync::Arc;
use crate::test_both_dbs;
use super::{CreateProcessedStripeEventParams, Database};
test_both_dbs!(
test_already_processed_stripe_event,
test_already_processed_stripe_event_postgres,
test_already_processed_stripe_event_sqlite
);
async fn test_already_processed_stripe_event(db: &Arc<Database>) {
let unprocessed_event_id = "evt_1PiJOuRxOf7d5PNaw2zzWiyO".to_string();
let processed_event_id = "evt_1PiIfMRxOf7d5PNakHrAUe8P".to_string();
db.create_processed_stripe_event(&CreateProcessedStripeEventParams {
stripe_event_id: processed_event_id.clone(),
stripe_event_type: "customer.created".into(),
stripe_event_created_timestamp: 1722355968,
})
.await
.unwrap();
assert!(
db.already_processed_stripe_event(&processed_event_id)
.await
.unwrap(),
"Expected {processed_event_id} to already be processed"
);
assert!(
!db.already_processed_stripe_event(&unprocessed_event_id)
.await
.unwrap(),
"Expected {unprocessed_event_id} to be unprocessed"
);
}

View file

@ -20,7 +20,6 @@ use axum::{
};
use db::{ChannelId, Database};
use executor::Executor;
use llm::db::LlmDatabase;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
@ -242,7 +241,6 @@ impl ServiceMode {
pub struct AppState {
pub db: Arc<Database>,
pub llm_db: Option<Arc<LlmDatabase>>,
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub executor: Executor,
@ -257,20 +255,6 @@ impl AppState {
let mut db = Database::new(db_options).await?;
db.initialize_notification_kinds().await?;
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
.llm_database_url
.clone()
.zip(config.llm_database_max_connections)
{
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
llm_db_options.max_connections(llm_database_max_connections);
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
llm_db.initialize().await?;
Some(Arc::new(llm_db))
} else {
None
};
let livekit_client = if let Some(((server, key), secret)) = config
.livekit_server
.as_ref()
@ -289,7 +273,6 @@ impl AppState {
let db = Arc::new(db);
let this = Self {
db: db.clone(),
llm_db,
livekit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
executor,

View file

@ -1,30 +1,9 @@
mod ids;
mod queries;
mod seed;
mod tables;
#[cfg(test)]
mod tests;
use cloud_llm_client::LanguageModelProvider;
use collections::HashMap;
pub use ids::*;
pub use seed::*;
pub use tables::*;
#[cfg(test)]
pub use tests::TestLlmDb;
use usage_measure::UsageMeasure;
use std::future::Future;
use std::sync::Arc;
use anyhow::Context;
pub use sea_orm::ConnectOptions;
use sea_orm::prelude::*;
use sea_orm::{
ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
};
use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait};
use crate::Result;
use crate::db::TransactionHandle;
@ -36,9 +15,6 @@ pub struct LlmDatabase {
pool: DatabaseConnection,
#[allow(unused)]
executor: Executor,
provider_ids: HashMap<LanguageModelProvider, ProviderId>,
models: HashMap<(LanguageModelProvider, String), model::Model>,
usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
#[cfg(test)]
runtime: Option<tokio::runtime::Runtime>,
}
@ -51,59 +27,11 @@ impl LlmDatabase {
options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
executor,
provider_ids: HashMap::default(),
models: HashMap::default(),
usage_measure_ids: HashMap::default(),
#[cfg(test)]
runtime: None,
})
}
pub async fn initialize(&mut self) -> Result<()> {
self.initialize_providers().await?;
self.initialize_models().await?;
self.initialize_usage_measures().await?;
Ok(())
}
/// Returns the list of all known models, with their [`LanguageModelProvider`].
pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
self.models
.iter()
.map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
.collect::<Vec<_>>()
}
/// Returns the names of the known models for the given [`LanguageModelProvider`].
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
self.models
.keys()
.filter_map(|(model_provider, model_name)| {
if model_provider == &provider {
Some(model_name)
} else {
None
}
})
.cloned()
.collect::<Vec<_>>()
}
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
Ok(self
.models
.get(&(provider, name.to_string()))
.with_context(|| format!("unknown model {provider:?}:{name}"))?)
}
pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
Ok(self
.models
.values()
.find(|model| model.id == id)
.with_context(|| format!("no model for ID {id:?}"))?)
}
pub fn options(&self) -> &ConnectOptions {
&self.options
}

View file

@ -1,11 +0,0 @@
use sea_orm::{DbErr, entity::prelude::*};
use serde::{Deserialize, Serialize};
use crate::id_type;
id_type!(BillingEventId);
id_type!(ModelId);
id_type!(ProviderId);
id_type!(RevokedAccessTokenId);
id_type!(UsageId);
id_type!(UsageMeasureId);

View file

@ -1,5 +0,0 @@
use super::*;
pub mod providers;
pub mod subscription_usages;
pub mod usages;

View file

@ -1,134 +0,0 @@
use super::*;
use sea_orm::{QueryOrder, sea_query::OnConflict};
use std::str::FromStr;
use strum::IntoEnumIterator as _;
pub struct ModelParams {
pub provider: LanguageModelProvider,
pub name: String,
pub max_requests_per_minute: i64,
pub max_tokens_per_minute: i64,
pub max_tokens_per_day: i64,
pub price_per_million_input_tokens: i32,
pub price_per_million_output_tokens: i32,
}
impl LlmDatabase {
pub async fn initialize_providers(&mut self) -> Result<()> {
self.provider_ids = self
.transaction(|tx| async move {
let existing_providers = provider::Entity::find().all(&*tx).await?;
let mut new_providers = LanguageModelProvider::iter()
.filter(|provider| {
!existing_providers
.iter()
.any(|p| p.name == provider.to_string())
})
.map(|provider| provider::ActiveModel {
name: ActiveValue::set(provider.to_string()),
..Default::default()
})
.peekable();
if new_providers.peek().is_some() {
provider::Entity::insert_many(new_providers)
.exec(&*tx)
.await?;
}
let all_providers: HashMap<_, _> = provider::Entity::find()
.all(&*tx)
.await?
.iter()
.filter_map(|provider| {
LanguageModelProvider::from_str(&provider.name)
.ok()
.map(|p| (p, provider.id))
})
.collect();
Ok(all_providers)
})
.await?;
Ok(())
}
pub async fn initialize_models(&mut self) -> Result<()> {
let all_provider_ids = &self.provider_ids;
self.models = self
.transaction(|tx| async move {
let all_models: HashMap<_, _> = model::Entity::find()
.all(&*tx)
.await?
.into_iter()
.filter_map(|model| {
let provider = all_provider_ids.iter().find_map(|(provider, id)| {
if *id == model.provider_id {
Some(provider)
} else {
None
}
})?;
Some(((*provider, model.name.clone()), model))
})
.collect();
Ok(all_models)
})
.await?;
Ok(())
}
pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> {
let all_provider_ids = &self.provider_ids;
self.transaction(|tx| async move {
model::Entity::insert_many(models.iter().map(|model_params| {
let provider_id = all_provider_ids[&model_params.provider];
model::ActiveModel {
provider_id: ActiveValue::set(provider_id),
name: ActiveValue::set(model_params.name.clone()),
max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute),
max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute),
max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day),
price_per_million_input_tokens: ActiveValue::set(
model_params.price_per_million_input_tokens,
),
price_per_million_output_tokens: ActiveValue::set(
model_params.price_per_million_output_tokens,
),
..Default::default()
}
}))
.on_conflict(
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
.update_columns([
model::Column::MaxRequestsPerMinute,
model::Column::MaxTokensPerMinute,
model::Column::MaxTokensPerDay,
model::Column::PricePerMillionInputTokens,
model::Column::PricePerMillionOutputTokens,
])
.to_owned(),
)
.exec_without_returning(&*tx)
.await?;
Ok(())
})
.await?;
self.initialize_models().await
}
/// Returns the list of LLM providers.
pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
self.transaction(|tx| async move {
Ok(provider::Entity::find()
.order_by_asc(provider::Column::Name)
.all(&*tx)
.await?
.into_iter()
.filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
.collect())
})
.await
}
}

View file

@ -1,38 +0,0 @@
use crate::db::UserId;
use super::*;
impl LlmDatabase {
pub async fn get_subscription_usage_for_period(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
) -> Result<Option<subscription_usage::Model>> {
self.transaction(|tx| async move {
self.get_subscription_usage_for_period_in_tx(
user_id,
period_start_at,
period_end_at,
&tx,
)
.await
})
.await
}
async fn get_subscription_usage_for_period_in_tx(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
tx: &DatabaseTransaction,
) -> Result<Option<subscription_usage::Model>> {
Ok(subscription_usage::Entity::find()
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
.one(tx)
.await?)
}
}

View file

@ -1,44 +0,0 @@
use std::str::FromStr;
use strum::IntoEnumIterator as _;
use super::*;
impl LlmDatabase {
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
let all_measures = self
.transaction(|tx| async move {
let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
let new_measures = UsageMeasure::iter()
.filter(|measure| {
!existing_measures
.iter()
.any(|m| m.name == measure.to_string())
})
.map(|measure| usage_measure::ActiveModel {
name: ActiveValue::set(measure.to_string()),
..Default::default()
})
.collect::<Vec<_>>();
if !new_measures.is_empty() {
usage_measure::Entity::insert_many(new_measures)
.exec(&*tx)
.await?;
}
Ok(usage_measure::Entity::find().all(&*tx).await?)
})
.await?;
self.usage_measure_ids = all_measures
.into_iter()
.filter_map(|measure| {
UsageMeasure::from_str(&measure.name)
.ok()
.map(|um| (um, measure.id))
})
.collect();
Ok(())
}
}

View file

@ -1,45 +0,0 @@
use super::*;
use crate::{Config, Result};
use queries::providers::ModelParams;
pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
db.insert_models(&[
ModelParams {
provider: LanguageModelProvider::Anthropic,
name: "claude-3-5-sonnet".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 20_000,
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 300, // $3.00/MTok
price_per_million_output_tokens: 1500, // $15.00/MTok
},
ModelParams {
provider: LanguageModelProvider::Anthropic,
name: "claude-3-opus".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 10_000,
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 1500, // $15.00/MTok
price_per_million_output_tokens: 7500, // $75.00/MTok
},
ModelParams {
provider: LanguageModelProvider::Anthropic,
name: "claude-3-sonnet".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 20_000,
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 1500, // $15.00/MTok
price_per_million_output_tokens: 7500, // $75.00/MTok
},
ModelParams {
provider: LanguageModelProvider::Anthropic,
name: "claude-3-haiku".into(),
max_requests_per_minute: 5,
max_tokens_per_minute: 25_000,
max_tokens_per_day: 300_000,
price_per_million_input_tokens: 25, // $0.25/MTok
price_per_million_output_tokens: 125, // $1.25/MTok
},
])
.await
}

View file

@ -1,6 +0,0 @@
pub mod model;
pub mod provider;
pub mod subscription_usage;
pub mod subscription_usage_meter;
pub mod usage;
pub mod usage_measure;

View file

@ -1,48 +0,0 @@
use sea_orm::entity::prelude::*;
use crate::llm::db::{ModelId, ProviderId};
/// An LLM model.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "models")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: ModelId,
pub provider_id: ProviderId,
pub name: String,
pub max_requests_per_minute: i64,
pub max_tokens_per_minute: i64,
pub max_input_tokens_per_minute: i64,
pub max_output_tokens_per_minute: i64,
pub max_tokens_per_day: i64,
pub price_per_million_input_tokens: i32,
pub price_per_million_cache_creation_input_tokens: i32,
pub price_per_million_cache_read_input_tokens: i32,
pub price_per_million_output_tokens: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::provider::Entity",
from = "Column::ProviderId",
to = "super::provider::Column::Id"
)]
Provider,
#[sea_orm(has_many = "super::usage::Entity")]
Usages,
}
impl Related<super::provider::Entity> for Entity {
fn to() -> RelationDef {
Relation::Provider.def()
}
}
impl Related<super::usage::Entity> for Entity {
fn to() -> RelationDef {
Relation::Usages.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,25 +0,0 @@
use crate::llm::db::ProviderId;
use sea_orm::entity::prelude::*;
/// An LLM provider.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "providers")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: ProviderId,
pub name: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::model::Entity")]
Models,
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Models.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,22 +0,0 @@
use crate::db::UserId;
use crate::db::billing_subscription::SubscriptionKind;
use sea_orm::entity::prelude::*;
use time::PrimitiveDateTime;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "subscription_usages_v2")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: Uuid,
pub user_id: UserId,
pub period_start_at: PrimitiveDateTime,
pub period_end_at: PrimitiveDateTime,
pub plan: SubscriptionKind,
pub model_requests: i32,
pub edit_predictions: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,55 +0,0 @@
use sea_orm::entity::prelude::*;
use serde::Serialize;
use crate::llm::db::ModelId;
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "subscription_usage_meters_v2")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: Uuid,
pub subscription_usage_id: Uuid,
pub model_id: ModelId,
pub mode: CompletionMode,
pub requests: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::subscription_usage::Entity",
from = "Column::SubscriptionUsageId",
to = "super::subscription_usage::Column::Id"
)]
SubscriptionUsage,
#[sea_orm(
belongs_to = "super::model::Entity",
from = "Column::ModelId",
to = "super::model::Column::Id"
)]
Model,
}
impl Related<super::subscription_usage::Entity> for Entity {
fn to() -> RelationDef {
Relation::SubscriptionUsage.def()
}
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Model.def()
}
}
impl ActiveModelBehavior for ActiveModel {}
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
#[sea_orm(string_value = "normal")]
Normal,
#[sea_orm(string_value = "max")]
Max,
}

View file

@ -1,52 +0,0 @@
use crate::{
db::UserId,
llm::db::{ModelId, UsageId, UsageMeasureId},
};
use sea_orm::entity::prelude::*;
/// An LLM usage record.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "usages")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: UsageId,
/// The ID of the Zed user.
///
/// Corresponds to the `users` table in the primary collab database.
pub user_id: UserId,
pub model_id: ModelId,
pub measure_id: UsageMeasureId,
pub timestamp: DateTime,
pub buckets: Vec<i64>,
pub is_staff: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::model::Entity",
from = "Column::ModelId",
to = "super::model::Column::Id"
)]
Model,
#[sea_orm(
belongs_to = "super::usage_measure::Entity",
from = "Column::MeasureId",
to = "super::usage_measure::Column::Id"
)]
UsageMeasure,
}
impl Related<super::model::Entity> for Entity {
fn to() -> RelationDef {
Relation::Model.def()
}
}
impl Related<super::usage_measure::Entity> for Entity {
fn to() -> RelationDef {
Relation::UsageMeasure.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,36 +0,0 @@
use crate::llm::db::UsageMeasureId;
use sea_orm::entity::prelude::*;
#[derive(
Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
)]
#[strum(serialize_all = "snake_case")]
pub enum UsageMeasure {
RequestsPerMinute,
TokensPerMinute,
InputTokensPerMinute,
OutputTokensPerMinute,
TokensPerDay,
}
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "usage_measures")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: UsageMeasureId,
pub name: String,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(has_many = "super::usage::Entity")]
Usages,
}
impl Related<super::usage::Entity> for Entity {
fn to() -> RelationDef {
Relation::Usages.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -1,107 +0,0 @@
mod provider_tests;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
use rand::prelude::*;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
use std::time::Duration;
use crate::migrations::run_database_migrations;
use super::*;
pub struct TestLlmDb {
pub db: Option<LlmDatabase>,
pub connection: Option<sqlx::AnyConnection>,
}
impl TestLlmDb {
pub fn postgres(background: BackgroundExecutor) -> Self {
static LOCK: Mutex<()> = Mutex::new(());
let _guard = LOCK.lock();
let mut rng = StdRng::from_entropy();
let url = format!(
"postgres://postgres@localhost/zed-llm-test-{}",
rng.r#gen::<u128>()
);
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.unwrap();
let mut db = runtime.block_on(async {
sqlx::Postgres::create_database(&url)
.await
.expect("failed to create test db");
let mut options = ConnectOptions::new(url);
options
.max_connections(5)
.idle_timeout(Duration::from_secs(0));
let db = LlmDatabase::new(options, Executor::Deterministic(background))
.await
.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
run_database_migrations(db.options(), migrations_path)
.await
.unwrap();
db
});
db.runtime = Some(runtime);
Self {
db: Some(db),
connection: None,
}
}
pub fn db(&mut self) -> &mut LlmDatabase {
self.db.as_mut().unwrap()
}
}
#[macro_export]
macro_rules! test_llm_db {
($test_name:ident, $postgres_test_name:ident) => {
#[gpui::test]
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
if !cfg!(target_os = "macos") {
return;
}
let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
$test_name(test_db.db()).await;
}
};
}
impl Drop for TestLlmDb {
fn drop(&mut self) {
let db = self.db.take().unwrap();
if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
db.runtime.as_ref().unwrap().block_on(async {
use util::ResultExt;
let query = "
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE
pg_stat_activity.datname = current_database() AND
pid <> pg_backend_pid();
";
db.pool
.execute(sea_orm::Statement::from_string(
db.pool.get_database_backend(),
query,
))
.await
.log_err();
sqlx::Postgres::drop_database(db.options.get_url())
.await
.log_err();
})
}
}
}

View file

@ -1,31 +0,0 @@
use cloud_llm_client::LanguageModelProvider;
use pretty_assertions::assert_eq;
use crate::llm::db::LlmDatabase;
use crate::test_llm_db;
test_llm_db!(
test_initialize_providers,
test_initialize_providers_postgres
);
async fn test_initialize_providers(db: &mut LlmDatabase) {
let initial_providers = db.list_providers().await.unwrap();
assert_eq!(initial_providers, vec![]);
db.initialize_providers().await.unwrap();
// Do it twice, to make sure the operation is idempotent.
db.initialize_providers().await.unwrap();
let providers = db.list_providers().await.unwrap();
assert_eq!(
providers,
&[
LanguageModelProvider::Anthropic,
LanguageModelProvider::Google,
LanguageModelProvider::OpenAi,
]
)
}

View file

@ -62,13 +62,6 @@ async fn main() -> Result<()> {
db.initialize_notification_kinds().await?;
collab::seed::seed(&config, &db, false).await?;
if let Some(llm_database_url) = config.llm_database_url.clone() {
let db_options = db::ConnectOptions::new(llm_database_url);
let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
db.initialize().await?;
collab::llm::db::seed_database(&config, &mut db, true).await?;
}
}
Some("serve") => {
let mode = match args.next().as_deref() {
@ -263,9 +256,6 @@ async fn setup_llm_database(config: &Config) -> Result<()> {
.llm_database_migrations_path
.as_deref()
.unwrap_or_else(|| {
#[cfg(feature = "sqlite")]
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
#[cfg(not(feature = "sqlite"))]
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
Path::new(default_migrations)

View file

@ -565,7 +565,6 @@ impl TestServer {
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
llm_db: None,
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
blob_store_client: None,
executor,