collab: Remove unused billing-related database code (#36282)
This PR removes a bunch of unused database code related to billing, as we no longer need it. Release Notes: - N/A
This commit is contained in:
parent
bf34e185d5
commit
e664a9bc48
36 changed files with 1 additions and 1419 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -3270,7 +3270,6 @@ dependencies = [
|
|||
"chrono",
|
||||
"client",
|
||||
"clock",
|
||||
"cloud_llm_client",
|
||||
"collab_ui",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
|
|
|
@ -29,7 +29,6 @@ axum-extra = { version = "0.4", features = ["erased-json"] }
|
|||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
clock.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap.workspace = true
|
||||
envy = "0.4.2"
|
||||
|
|
|
@ -41,12 +41,7 @@ use worktree_settings_file::LocalSettingsKind;
|
|||
pub use tests::TestDb;
|
||||
|
||||
pub use ids::*;
|
||||
pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
|
||||
pub use queries::billing_subscriptions::{
|
||||
CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
|
||||
};
|
||||
pub use queries::contributors::ContributorSelector;
|
||||
pub use queries::processed_stripe_events::CreateProcessedStripeEventParams;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
pub use tables::user::Model as User;
|
||||
pub use tables::*;
|
||||
|
|
|
@ -70,9 +70,6 @@ macro_rules! id_type {
|
|||
}
|
||||
|
||||
id_type!(AccessTokenId);
|
||||
id_type!(BillingCustomerId);
|
||||
id_type!(BillingSubscriptionId);
|
||||
id_type!(BillingPreferencesId);
|
||||
id_type!(BufferId);
|
||||
id_type!(ChannelBufferCollaboratorId);
|
||||
id_type!(ChannelChatParticipantId);
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
use super::*;
|
||||
|
||||
pub mod access_tokens;
|
||||
pub mod billing_customers;
|
||||
pub mod billing_preferences;
|
||||
pub mod billing_subscriptions;
|
||||
pub mod buffers;
|
||||
pub mod channels;
|
||||
pub mod contacts;
|
||||
|
@ -12,7 +9,6 @@ pub mod embeddings;
|
|||
pub mod extensions;
|
||||
pub mod messages;
|
||||
pub mod notifications;
|
||||
pub mod processed_stripe_events;
|
||||
pub mod projects;
|
||||
pub mod rooms;
|
||||
pub mod servers;
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateBillingCustomerParams {
|
||||
pub user_id: UserId,
|
||||
pub stripe_customer_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UpdateBillingCustomerParams {
|
||||
pub user_id: ActiveValue<UserId>,
|
||||
pub stripe_customer_id: ActiveValue<String>,
|
||||
pub has_overdue_invoices: ActiveValue<bool>,
|
||||
pub trial_started_at: ActiveValue<Option<DateTime>>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Creates a new billing customer.
|
||||
pub async fn create_billing_customer(
|
||||
&self,
|
||||
params: &CreateBillingCustomerParams,
|
||||
) -> Result<billing_customer::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
|
||||
user_id: ActiveValue::set(params.user_id),
|
||||
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(customer)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Updates the specified billing customer.
|
||||
pub async fn update_billing_customer(
|
||||
&self,
|
||||
id: BillingCustomerId,
|
||||
params: &UpdateBillingCustomerParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_customer::Entity::update(billing_customer::ActiveModel {
|
||||
id: ActiveValue::set(id),
|
||||
user_id: params.user_id.clone(),
|
||||
stripe_customer_id: params.stripe_customer_id.clone(),
|
||||
has_overdue_invoices: params.has_overdue_invoices.clone(),
|
||||
trial_started_at: params.trial_started_at.clone(),
|
||||
created_at: ActiveValue::not_set(),
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_billing_customer_by_id(
|
||||
&self,
|
||||
id: BillingCustomerId,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::Id.eq(id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing customer for the user with the specified ID.
|
||||
pub async fn get_billing_customer_by_user_id(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing customer for the user with the specified Stripe customer ID.
|
||||
pub async fn get_billing_customer_by_stripe_customer_id(
|
||||
&self,
|
||||
stripe_customer_id: &str,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
use super::*;
|
||||
|
||||
impl Database {
|
||||
/// Returns the billing preferences for the given user, if they exist.
|
||||
pub async fn get_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_preference::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_preference::Entity::find()
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
|
@ -1,158 +0,0 @@
|
|||
use anyhow::Context as _;
|
||||
|
||||
use crate::db::billing_subscription::{
|
||||
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateBillingSubscriptionParams {
|
||||
pub billing_customer_id: BillingCustomerId,
|
||||
pub kind: Option<SubscriptionKind>,
|
||||
pub stripe_subscription_id: String,
|
||||
pub stripe_subscription_status: StripeSubscriptionStatus,
|
||||
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
|
||||
pub stripe_current_period_start: Option<i64>,
|
||||
pub stripe_current_period_end: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UpdateBillingSubscriptionParams {
|
||||
pub billing_customer_id: ActiveValue<BillingCustomerId>,
|
||||
pub kind: ActiveValue<Option<SubscriptionKind>>,
|
||||
pub stripe_subscription_id: ActiveValue<String>,
|
||||
pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
|
||||
pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
|
||||
pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
|
||||
pub stripe_current_period_start: ActiveValue<Option<i64>>,
|
||||
pub stripe_current_period_end: ActiveValue<Option<i64>>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Creates a new billing subscription.
|
||||
pub async fn create_billing_subscription(
|
||||
&self,
|
||||
params: &CreateBillingSubscriptionParams,
|
||||
) -> Result<billing_subscription::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||
billing_customer_id: ActiveValue::set(params.billing_customer_id),
|
||||
kind: ActiveValue::set(params.kind),
|
||||
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
|
||||
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
|
||||
stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
|
||||
stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
|
||||
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
.last_insert_id;
|
||||
|
||||
Ok(billing_subscription::Entity::find_by_id(id)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.context("failed to retrieve inserted billing subscription")?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Updates the specified billing subscription.
|
||||
pub async fn update_billing_subscription(
|
||||
&self,
|
||||
id: BillingSubscriptionId,
|
||||
params: &UpdateBillingSubscriptionParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_subscription::Entity::update(billing_subscription::ActiveModel {
|
||||
id: ActiveValue::set(id),
|
||||
billing_customer_id: params.billing_customer_id.clone(),
|
||||
kind: params.kind.clone(),
|
||||
stripe_subscription_id: params.stripe_subscription_id.clone(),
|
||||
stripe_subscription_status: params.stripe_subscription_status.clone(),
|
||||
stripe_cancel_at: params.stripe_cancel_at.clone(),
|
||||
stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
|
||||
stripe_current_period_start: params.stripe_current_period_start.clone(),
|
||||
stripe_current_period_end: params.stripe_current_period_end.clone(),
|
||||
created_at: ActiveValue::not_set(),
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing subscription with the specified Stripe subscription ID.
|
||||
pub async fn get_billing_subscription_by_stripe_subscription_id(
|
||||
&self,
|
||||
stripe_subscription_id: &str,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find()
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_billing_subscription(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
.filter(
|
||||
Condition::all()
|
||||
.add(
|
||||
Condition::any()
|
||||
.add(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.add(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Trialing),
|
||||
),
|
||||
)
|
||||
.add(billing_subscription::Column::Kind.is_not_null()),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns whether the user has an active billing subscription.
|
||||
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
|
||||
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
|
||||
}
|
||||
|
||||
/// Returns the count of the active billing subscriptions for the user with the specified ID.
|
||||
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
|
||||
self.transaction(|tx| async move {
|
||||
let count = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(
|
||||
billing_customer::Column::UserId.eq(user_id).and(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active)
|
||||
.or(billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Trialing)),
|
||||
),
|
||||
)
|
||||
.count(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(count as usize)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateProcessedStripeEventParams {
|
||||
pub stripe_event_id: String,
|
||||
pub stripe_event_type: String,
|
||||
pub stripe_event_created_timestamp: i64,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Creates a new processed Stripe event.
|
||||
pub async fn create_processed_stripe_event(
|
||||
&self,
|
||||
params: &CreateProcessedStripeEventParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
|
||||
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
|
||||
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
|
||||
stripe_event_created_timestamp: ActiveValue::set(
|
||||
params.stripe_event_created_timestamp,
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_without_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the processed Stripe event with the specified event ID.
|
||||
pub async fn get_processed_stripe_event_by_event_id(
|
||||
&self,
|
||||
event_id: &str,
|
||||
) -> Result<Option<processed_stripe_event::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(processed_stripe_event::Entity::find_by_id(event_id)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the processed Stripe events with the specified event IDs.
|
||||
pub async fn get_processed_stripe_events_by_event_ids(
|
||||
&self,
|
||||
event_ids: &[&str],
|
||||
) -> Result<Vec<processed_stripe_event::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(processed_stripe_event::Entity::find()
|
||||
.filter(
|
||||
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),
|
||||
)
|
||||
.all(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns whether the Stripe event with the specified ID has already been processed.
|
||||
pub async fn already_processed_stripe_event(&self, event_id: &str) -> Result<bool> {
|
||||
Ok(self
|
||||
.get_processed_stripe_event_by_event_id(event_id)
|
||||
.await?
|
||||
.is_some())
|
||||
}
|
||||
}
|
|
@ -1,7 +1,4 @@
|
|||
pub mod access_token;
|
||||
pub mod billing_customer;
|
||||
pub mod billing_preference;
|
||||
pub mod billing_subscription;
|
||||
pub mod buffer;
|
||||
pub mod buffer_operation;
|
||||
pub mod buffer_snapshot;
|
||||
|
@ -23,7 +20,6 @@ pub mod notification;
|
|||
pub mod notification_kind;
|
||||
pub mod observed_buffer_edits;
|
||||
pub mod observed_channel_messages;
|
||||
pub mod processed_stripe_event;
|
||||
pub mod project;
|
||||
pub mod project_collaborator;
|
||||
pub mod project_repository;
|
||||
|
|
|
@ -1,41 +0,0 @@
|
|||
use crate::db::{BillingCustomerId, UserId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
/// A billing customer.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_customers")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingCustomerId,
|
||||
pub user_id: UserId,
|
||||
pub stripe_customer_id: String,
|
||||
pub has_overdue_invoices: bool,
|
||||
pub trial_started_at: Option<DateTime>,
|
||||
pub created_at: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
#[sea_orm(has_many = "super::billing_subscription::Entity")]
|
||||
BillingSubscription,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::billing_subscription::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingSubscription.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,32 +0,0 @@
|
|||
use crate::db::{BillingPreferencesId, UserId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_preferences")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingPreferencesId,
|
||||
pub created_at: DateTime,
|
||||
pub user_id: UserId,
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
pub model_request_overages_enabled: bool,
|
||||
pub model_request_overages_spend_limit_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,161 +0,0 @@
|
|||
use crate::db::{BillingCustomerId, BillingSubscriptionId};
|
||||
use chrono::{Datelike as _, NaiveDate, Utc};
|
||||
use sea_orm::entity::prelude::*;
|
||||
use serde::Serialize;
|
||||
|
||||
/// A billing subscription.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_subscriptions")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingSubscriptionId,
|
||||
pub billing_customer_id: BillingCustomerId,
|
||||
pub kind: Option<SubscriptionKind>,
|
||||
pub stripe_subscription_id: String,
|
||||
pub stripe_subscription_status: StripeSubscriptionStatus,
|
||||
pub stripe_cancel_at: Option<DateTime>,
|
||||
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
|
||||
pub stripe_current_period_start: Option<i64>,
|
||||
pub stripe_current_period_end: Option<i64>,
|
||||
pub created_at: DateTime,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn current_period_start_at(&self) -> Option<DateTimeUtc> {
|
||||
let period_start = self.stripe_current_period_start?;
|
||||
chrono::DateTime::from_timestamp(period_start, 0)
|
||||
}
|
||||
|
||||
pub fn current_period_end_at(&self) -> Option<DateTimeUtc> {
|
||||
let period_end = self.stripe_current_period_end?;
|
||||
chrono::DateTime::from_timestamp(period_end, 0)
|
||||
}
|
||||
|
||||
pub fn current_period(
|
||||
subscription: Option<Self>,
|
||||
is_staff: bool,
|
||||
) -> Option<(DateTimeUtc, DateTimeUtc)> {
|
||||
if is_staff {
|
||||
let now = Utc::now();
|
||||
let year = now.year();
|
||||
let month = now.month();
|
||||
|
||||
let first_day_of_this_month =
|
||||
NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?;
|
||||
|
||||
let next_month = if month == 12 { 1 } else { month + 1 };
|
||||
let next_month_year = if month == 12 { year + 1 } else { year };
|
||||
let first_day_of_next_month =
|
||||
NaiveDate::from_ymd_opt(next_month_year, next_month, 1)?.and_hms_opt(23, 59, 59)?;
|
||||
|
||||
let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1);
|
||||
|
||||
Some((
|
||||
first_day_of_this_month.and_utc(),
|
||||
last_day_of_this_month.and_utc(),
|
||||
))
|
||||
} else {
|
||||
let subscription = subscription?;
|
||||
let period_start_at = subscription.current_period_start_at()?;
|
||||
let period_end_at = subscription.current_period_end_at()?;
|
||||
|
||||
Some((period_start_at, period_end_at))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::billing_customer::Entity",
|
||||
from = "Column::BillingCustomerId",
|
||||
to = "super::billing_customer::Column::Id"
|
||||
)]
|
||||
BillingCustomer,
|
||||
}
|
||||
|
||||
impl Related<super::billing_customer::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingCustomer.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
|
||||
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SubscriptionKind {
|
||||
#[sea_orm(string_value = "zed_pro")]
|
||||
ZedPro,
|
||||
#[sea_orm(string_value = "zed_pro_trial")]
|
||||
ZedProTrial,
|
||||
#[sea_orm(string_value = "zed_free")]
|
||||
ZedFree,
|
||||
}
|
||||
|
||||
impl From<SubscriptionKind> for cloud_llm_client::Plan {
|
||||
fn from(value: SubscriptionKind) -> Self {
|
||||
match value {
|
||||
SubscriptionKind::ZedPro => Self::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Self::ZedProTrial,
|
||||
SubscriptionKind::ZedFree => Self::ZedFree,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The status of a Stripe subscription.
|
||||
///
|
||||
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-status)
|
||||
#[derive(
|
||||
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
|
||||
)]
|
||||
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StripeSubscriptionStatus {
|
||||
#[default]
|
||||
#[sea_orm(string_value = "incomplete")]
|
||||
Incomplete,
|
||||
#[sea_orm(string_value = "incomplete_expired")]
|
||||
IncompleteExpired,
|
||||
#[sea_orm(string_value = "trialing")]
|
||||
Trialing,
|
||||
#[sea_orm(string_value = "active")]
|
||||
Active,
|
||||
#[sea_orm(string_value = "past_due")]
|
||||
PastDue,
|
||||
#[sea_orm(string_value = "canceled")]
|
||||
Canceled,
|
||||
#[sea_orm(string_value = "unpaid")]
|
||||
Unpaid,
|
||||
#[sea_orm(string_value = "paused")]
|
||||
Paused,
|
||||
}
|
||||
|
||||
impl StripeSubscriptionStatus {
|
||||
pub fn is_cancelable(&self) -> bool {
|
||||
match self {
|
||||
Self::Trialing | Self::Active | Self::PastDue => true,
|
||||
Self::Incomplete
|
||||
| Self::IncompleteExpired
|
||||
| Self::Canceled
|
||||
| Self::Unpaid
|
||||
| Self::Paused => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The cancellation reason for a Stripe subscription.
|
||||
///
|
||||
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-cancellation_details-reason)
|
||||
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
|
||||
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StripeCancellationReason {
|
||||
#[sea_orm(string_value = "cancellation_requested")]
|
||||
CancellationRequested,
|
||||
#[sea_orm(string_value = "payment_disputed")]
|
||||
PaymentDisputed,
|
||||
#[sea_orm(string_value = "payment_failed")]
|
||||
PaymentFailed,
|
||||
}
|
|
@ -1,16 +0,0 @@
|
|||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "processed_stripe_events")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub stripe_event_id: String,
|
||||
pub stripe_event_type: String,
|
||||
pub stripe_event_created_timestamp: i64,
|
||||
pub processed_at: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -29,8 +29,6 @@ pub struct Model {
|
|||
pub enum Relation {
|
||||
#[sea_orm(has_many = "super::access_token::Entity")]
|
||||
AccessToken,
|
||||
#[sea_orm(has_one = "super::billing_customer::Entity")]
|
||||
BillingCustomer,
|
||||
#[sea_orm(has_one = "super::room_participant::Entity")]
|
||||
RoomParticipant,
|
||||
#[sea_orm(has_many = "super::project::Entity")]
|
||||
|
@ -68,12 +66,6 @@ impl Related<super::access_token::Entity> for Entity {
|
|||
}
|
||||
}
|
||||
|
||||
impl Related<super::billing_customer::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingCustomer.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::room_participant::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::RoomParticipant.def()
|
||||
|
|
|
@ -8,7 +8,6 @@ mod embedding_tests;
|
|||
mod extension_tests;
|
||||
mod feature_flag_tests;
|
||||
mod message_tests;
|
||||
mod processed_stripe_event_tests;
|
||||
mod user_tests;
|
||||
|
||||
use crate::migrations::run_database_migrations;
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::test_both_dbs;
|
||||
|
||||
use super::{CreateProcessedStripeEventParams, Database};
|
||||
|
||||
test_both_dbs!(
|
||||
test_already_processed_stripe_event,
|
||||
test_already_processed_stripe_event_postgres,
|
||||
test_already_processed_stripe_event_sqlite
|
||||
);
|
||||
|
||||
async fn test_already_processed_stripe_event(db: &Arc<Database>) {
|
||||
let unprocessed_event_id = "evt_1PiJOuRxOf7d5PNaw2zzWiyO".to_string();
|
||||
let processed_event_id = "evt_1PiIfMRxOf7d5PNakHrAUe8P".to_string();
|
||||
|
||||
db.create_processed_stripe_event(&CreateProcessedStripeEventParams {
|
||||
stripe_event_id: processed_event_id.clone(),
|
||||
stripe_event_type: "customer.created".into(),
|
||||
stripe_event_created_timestamp: 1722355968,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
db.already_processed_stripe_event(&processed_event_id)
|
||||
.await
|
||||
.unwrap(),
|
||||
"Expected {processed_event_id} to already be processed"
|
||||
);
|
||||
|
||||
assert!(
|
||||
!db.already_processed_stripe_event(&unprocessed_event_id)
|
||||
.await
|
||||
.unwrap(),
|
||||
"Expected {unprocessed_event_id} to be unprocessed"
|
||||
);
|
||||
}
|
|
@ -20,7 +20,6 @@ use axum::{
|
|||
};
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
use llm::db::LlmDatabase;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
@ -242,7 +241,6 @@ impl ServiceMode {
|
|||
|
||||
pub struct AppState {
|
||||
pub db: Arc<Database>,
|
||||
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub executor: Executor,
|
||||
|
@ -257,20 +255,6 @@ impl AppState {
|
|||
let mut db = Database::new(db_options).await?;
|
||||
db.initialize_notification_kinds().await?;
|
||||
|
||||
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
|
||||
.llm_database_url
|
||||
.clone()
|
||||
.zip(config.llm_database_max_connections)
|
||||
{
|
||||
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
|
||||
llm_db_options.max_connections(llm_database_max_connections);
|
||||
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
|
||||
llm_db.initialize().await?;
|
||||
Some(Arc::new(llm_db))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let livekit_client = if let Some(((server, key), secret)) = config
|
||||
.livekit_server
|
||||
.as_ref()
|
||||
|
@ -289,7 +273,6 @@ impl AppState {
|
|||
let db = Arc::new(db);
|
||||
let this = Self {
|
||||
db: db.clone(),
|
||||
llm_db,
|
||||
livekit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
executor,
|
||||
|
|
|
@ -1,30 +1,9 @@
|
|||
mod ids;
|
||||
mod queries;
|
||||
mod seed;
|
||||
mod tables;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use collections::HashMap;
|
||||
pub use ids::*;
|
||||
pub use seed::*;
|
||||
pub use tables::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use tests::TestLlmDb;
|
||||
use usage_measure::UsageMeasure;
|
||||
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::prelude::*;
|
||||
use sea_orm::{
|
||||
ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
|
||||
};
|
||||
use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait};
|
||||
|
||||
use crate::Result;
|
||||
use crate::db::TransactionHandle;
|
||||
|
@ -36,9 +15,6 @@ pub struct LlmDatabase {
|
|||
pool: DatabaseConnection,
|
||||
#[allow(unused)]
|
||||
executor: Executor,
|
||||
provider_ids: HashMap<LanguageModelProvider, ProviderId>,
|
||||
models: HashMap<(LanguageModelProvider, String), model::Model>,
|
||||
usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
|
||||
#[cfg(test)]
|
||||
runtime: Option<tokio::runtime::Runtime>,
|
||||
}
|
||||
|
@ -51,59 +27,11 @@ impl LlmDatabase {
|
|||
options: options.clone(),
|
||||
pool: sea_orm::Database::connect(options).await?,
|
||||
executor,
|
||||
provider_ids: HashMap::default(),
|
||||
models: HashMap::default(),
|
||||
usage_measure_ids: HashMap::default(),
|
||||
#[cfg(test)]
|
||||
runtime: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn initialize(&mut self) -> Result<()> {
|
||||
self.initialize_providers().await?;
|
||||
self.initialize_models().await?;
|
||||
self.initialize_usage_measures().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the list of all known models, with their [`LanguageModelProvider`].
|
||||
pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
|
||||
self.models
|
||||
.iter()
|
||||
.map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
/// Returns the names of the known models for the given [`LanguageModelProvider`].
|
||||
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
|
||||
self.models
|
||||
.keys()
|
||||
.filter_map(|(model_provider, model_name)| {
|
||||
if model_provider == &provider {
|
||||
Some(model_name)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
|
||||
Ok(self
|
||||
.models
|
||||
.get(&(provider, name.to_string()))
|
||||
.with_context(|| format!("unknown model {provider:?}:{name}"))?)
|
||||
}
|
||||
|
||||
pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
|
||||
Ok(self
|
||||
.models
|
||||
.values()
|
||||
.find(|model| model.id == id)
|
||||
.with_context(|| format!("no model for ID {id:?}"))?)
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.options
|
||||
}
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
use sea_orm::{DbErr, entity::prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::id_type;
|
||||
|
||||
id_type!(BillingEventId);
|
||||
id_type!(ModelId);
|
||||
id_type!(ProviderId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
id_type!(UsageId);
|
||||
id_type!(UsageMeasureId);
|
|
@ -1,5 +0,0 @@
|
|||
use super::*;
|
||||
|
||||
pub mod providers;
|
||||
pub mod subscription_usages;
|
||||
pub mod usages;
|
|
@ -1,134 +0,0 @@
|
|||
use super::*;
|
||||
use sea_orm::{QueryOrder, sea_query::OnConflict};
|
||||
use std::str::FromStr;
|
||||
use strum::IntoEnumIterator as _;
|
||||
|
||||
pub struct ModelParams {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub name: String,
|
||||
pub max_requests_per_minute: i64,
|
||||
pub max_tokens_per_minute: i64,
|
||||
pub max_tokens_per_day: i64,
|
||||
pub price_per_million_input_tokens: i32,
|
||||
pub price_per_million_output_tokens: i32,
|
||||
}
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn initialize_providers(&mut self) -> Result<()> {
|
||||
self.provider_ids = self
|
||||
.transaction(|tx| async move {
|
||||
let existing_providers = provider::Entity::find().all(&*tx).await?;
|
||||
|
||||
let mut new_providers = LanguageModelProvider::iter()
|
||||
.filter(|provider| {
|
||||
!existing_providers
|
||||
.iter()
|
||||
.any(|p| p.name == provider.to_string())
|
||||
})
|
||||
.map(|provider| provider::ActiveModel {
|
||||
name: ActiveValue::set(provider.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.peekable();
|
||||
|
||||
if new_providers.peek().is_some() {
|
||||
provider::Entity::insert_many(new_providers)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let all_providers: HashMap<_, _> = provider::Entity::find()
|
||||
.all(&*tx)
|
||||
.await?
|
||||
.iter()
|
||||
.filter_map(|provider| {
|
||||
LanguageModelProvider::from_str(&provider.name)
|
||||
.ok()
|
||||
.map(|p| (p, provider.id))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(all_providers)
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn initialize_models(&mut self) -> Result<()> {
|
||||
let all_provider_ids = &self.provider_ids;
|
||||
self.models = self
|
||||
.transaction(|tx| async move {
|
||||
let all_models: HashMap<_, _> = model::Entity::find()
|
||||
.all(&*tx)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|model| {
|
||||
let provider = all_provider_ids.iter().find_map(|(provider, id)| {
|
||||
if *id == model.provider_id {
|
||||
Some(provider)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})?;
|
||||
Some(((*provider, model.name.clone()), model))
|
||||
})
|
||||
.collect();
|
||||
Ok(all_models)
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert_models(&mut self, models: &[ModelParams]) -> Result<()> {
|
||||
let all_provider_ids = &self.provider_ids;
|
||||
self.transaction(|tx| async move {
|
||||
model::Entity::insert_many(models.iter().map(|model_params| {
|
||||
let provider_id = all_provider_ids[&model_params.provider];
|
||||
model::ActiveModel {
|
||||
provider_id: ActiveValue::set(provider_id),
|
||||
name: ActiveValue::set(model_params.name.clone()),
|
||||
max_requests_per_minute: ActiveValue::set(model_params.max_requests_per_minute),
|
||||
max_tokens_per_minute: ActiveValue::set(model_params.max_tokens_per_minute),
|
||||
max_tokens_per_day: ActiveValue::set(model_params.max_tokens_per_day),
|
||||
price_per_million_input_tokens: ActiveValue::set(
|
||||
model_params.price_per_million_input_tokens,
|
||||
),
|
||||
price_per_million_output_tokens: ActiveValue::set(
|
||||
model_params.price_per_million_output_tokens,
|
||||
),
|
||||
..Default::default()
|
||||
}
|
||||
}))
|
||||
.on_conflict(
|
||||
OnConflict::columns([model::Column::ProviderId, model::Column::Name])
|
||||
.update_columns([
|
||||
model::Column::MaxRequestsPerMinute,
|
||||
model::Column::MaxTokensPerMinute,
|
||||
model::Column::MaxTokensPerDay,
|
||||
model::Column::PricePerMillionInputTokens,
|
||||
model::Column::PricePerMillionOutputTokens,
|
||||
])
|
||||
.to_owned(),
|
||||
)
|
||||
.exec_without_returning(&*tx)
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
self.initialize_models().await
|
||||
}
|
||||
|
||||
/// Returns the list of LLM providers.
|
||||
pub async fn list_providers(&self) -> Result<Vec<LanguageModelProvider>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(provider::Entity::find()
|
||||
.order_by_asc(provider::Column::Name)
|
||||
.all(&*tx)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|p| LanguageModelProvider::from_str(&p.name).ok())
|
||||
.collect())
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
|
@ -1,38 +0,0 @@
|
|||
use crate::db::UserId;
|
||||
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn get_subscription_usage_for_period(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
period_start_at: DateTimeUtc,
|
||||
period_end_at: DateTimeUtc,
|
||||
) -> Result<Option<subscription_usage::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
self.get_subscription_usage_for_period_in_tx(
|
||||
user_id,
|
||||
period_start_at,
|
||||
period_end_at,
|
||||
&tx,
|
||||
)
|
||||
.await
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_subscription_usage_for_period_in_tx(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
period_start_at: DateTimeUtc,
|
||||
period_end_at: DateTimeUtc,
|
||||
tx: &DatabaseTransaction,
|
||||
) -> Result<Option<subscription_usage::Model>> {
|
||||
Ok(subscription_usage::Entity::find()
|
||||
.filter(subscription_usage::Column::UserId.eq(user_id))
|
||||
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
|
||||
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
|
||||
.one(tx)
|
||||
.await?)
|
||||
}
|
||||
}
|
|
@ -1,44 +0,0 @@
|
|||
use std::str::FromStr;
|
||||
use strum::IntoEnumIterator as _;
|
||||
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn initialize_usage_measures(&mut self) -> Result<()> {
|
||||
let all_measures = self
|
||||
.transaction(|tx| async move {
|
||||
let existing_measures = usage_measure::Entity::find().all(&*tx).await?;
|
||||
|
||||
let new_measures = UsageMeasure::iter()
|
||||
.filter(|measure| {
|
||||
!existing_measures
|
||||
.iter()
|
||||
.any(|m| m.name == measure.to_string())
|
||||
})
|
||||
.map(|measure| usage_measure::ActiveModel {
|
||||
name: ActiveValue::set(measure.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !new_measures.is_empty() {
|
||||
usage_measure::Entity::insert_many(new_measures)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(usage_measure::Entity::find().all(&*tx).await?)
|
||||
})
|
||||
.await?;
|
||||
|
||||
self.usage_measure_ids = all_measures
|
||||
.into_iter()
|
||||
.filter_map(|measure| {
|
||||
UsageMeasure::from_str(&measure.name)
|
||||
.ok()
|
||||
.map(|um| (um, measure.id))
|
||||
})
|
||||
.collect();
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
use super::*;
|
||||
use crate::{Config, Result};
|
||||
use queries::providers::ModelParams;
|
||||
|
||||
pub async fn seed_database(_config: &Config, db: &mut LlmDatabase, _force: bool) -> Result<()> {
|
||||
db.insert_models(&[
|
||||
ModelParams {
|
||||
provider: LanguageModelProvider::Anthropic,
|
||||
name: "claude-3-5-sonnet".into(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 20_000,
|
||||
max_tokens_per_day: 300_000,
|
||||
price_per_million_input_tokens: 300, // $3.00/MTok
|
||||
price_per_million_output_tokens: 1500, // $15.00/MTok
|
||||
},
|
||||
ModelParams {
|
||||
provider: LanguageModelProvider::Anthropic,
|
||||
name: "claude-3-opus".into(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 10_000,
|
||||
max_tokens_per_day: 300_000,
|
||||
price_per_million_input_tokens: 1500, // $15.00/MTok
|
||||
price_per_million_output_tokens: 7500, // $75.00/MTok
|
||||
},
|
||||
ModelParams {
|
||||
provider: LanguageModelProvider::Anthropic,
|
||||
name: "claude-3-sonnet".into(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 20_000,
|
||||
max_tokens_per_day: 300_000,
|
||||
price_per_million_input_tokens: 1500, // $15.00/MTok
|
||||
price_per_million_output_tokens: 7500, // $75.00/MTok
|
||||
},
|
||||
ModelParams {
|
||||
provider: LanguageModelProvider::Anthropic,
|
||||
name: "claude-3-haiku".into(),
|
||||
max_requests_per_minute: 5,
|
||||
max_tokens_per_minute: 25_000,
|
||||
max_tokens_per_day: 300_000,
|
||||
price_per_million_input_tokens: 25, // $0.25/MTok
|
||||
price_per_million_output_tokens: 125, // $1.25/MTok
|
||||
},
|
||||
])
|
||||
.await
|
||||
}
|
|
@ -1,6 +0,0 @@
|
|||
pub mod model;
|
||||
pub mod provider;
|
||||
pub mod subscription_usage;
|
||||
pub mod subscription_usage_meter;
|
||||
pub mod usage;
|
||||
pub mod usage_measure;
|
|
@ -1,48 +0,0 @@
|
|||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::{ModelId, ProviderId};
|
||||
|
||||
/// An LLM model.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "models")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: ModelId,
|
||||
pub provider_id: ProviderId,
|
||||
pub name: String,
|
||||
pub max_requests_per_minute: i64,
|
||||
pub max_tokens_per_minute: i64,
|
||||
pub max_input_tokens_per_minute: i64,
|
||||
pub max_output_tokens_per_minute: i64,
|
||||
pub max_tokens_per_day: i64,
|
||||
pub price_per_million_input_tokens: i32,
|
||||
pub price_per_million_cache_creation_input_tokens: i32,
|
||||
pub price_per_million_cache_read_input_tokens: i32,
|
||||
pub price_per_million_output_tokens: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::provider::Entity",
|
||||
from = "Column::ProviderId",
|
||||
to = "super::provider::Column::Id"
|
||||
)]
|
||||
Provider,
|
||||
#[sea_orm(has_many = "super::usage::Entity")]
|
||||
Usages,
|
||||
}
|
||||
|
||||
impl Related<super::provider::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Provider.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::usage::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Usages.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,25 +0,0 @@
|
|||
use crate::llm::db::ProviderId;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
/// An LLM provider.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "providers")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: ProviderId,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(has_many = "super::model::Entity")]
|
||||
Models,
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Models.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,22 +0,0 @@
|
|||
use crate::db::UserId;
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use sea_orm::entity::prelude::*;
|
||||
use time::PrimitiveDateTime;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "subscription_usages_v2")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: Uuid,
|
||||
pub user_id: UserId,
|
||||
pub period_start_at: PrimitiveDateTime,
|
||||
pub period_end_at: PrimitiveDateTime,
|
||||
pub plan: SubscriptionKind,
|
||||
pub model_requests: i32,
|
||||
pub edit_predictions: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,55 +0,0 @@
|
|||
use sea_orm::entity::prelude::*;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::llm::db::ModelId;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "subscription_usage_meters_v2")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: Uuid,
|
||||
pub subscription_usage_id: Uuid,
|
||||
pub model_id: ModelId,
|
||||
pub mode: CompletionMode,
|
||||
pub requests: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::subscription_usage::Entity",
|
||||
from = "Column::SubscriptionUsageId",
|
||||
to = "super::subscription_usage::Column::Id"
|
||||
)]
|
||||
SubscriptionUsage,
|
||||
#[sea_orm(
|
||||
belongs_to = "super::model::Entity",
|
||||
from = "Column::ModelId",
|
||||
to = "super::model::Column::Id"
|
||||
)]
|
||||
Model,
|
||||
}
|
||||
|
||||
impl Related<super::subscription_usage::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::SubscriptionUsage.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Model.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
|
||||
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionMode {
|
||||
#[sea_orm(string_value = "normal")]
|
||||
Normal,
|
||||
#[sea_orm(string_value = "max")]
|
||||
Max,
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
use crate::{
|
||||
db::UserId,
|
||||
llm::db::{ModelId, UsageId, UsageMeasureId},
|
||||
};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
/// An LLM usage record.
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "usages")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: UsageId,
|
||||
/// The ID of the Zed user.
|
||||
///
|
||||
/// Corresponds to the `users` table in the primary collab database.
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub measure_id: UsageMeasureId,
|
||||
pub timestamp: DateTime,
|
||||
pub buckets: Vec<i64>,
|
||||
pub is_staff: bool,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::model::Entity",
|
||||
from = "Column::ModelId",
|
||||
to = "super::model::Column::Id"
|
||||
)]
|
||||
Model,
|
||||
#[sea_orm(
|
||||
belongs_to = "super::usage_measure::Entity",
|
||||
from = "Column::MeasureId",
|
||||
to = "super::usage_measure::Column::Id"
|
||||
)]
|
||||
UsageMeasure,
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Model.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::usage_measure::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::UsageMeasure.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,36 +0,0 @@
|
|||
use crate::llm::db::UsageMeasureId;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(
|
||||
Copy, Clone, Debug, PartialEq, Eq, Hash, strum::EnumString, strum::Display, strum::EnumIter,
|
||||
)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum UsageMeasure {
|
||||
RequestsPerMinute,
|
||||
TokensPerMinute,
|
||||
InputTokensPerMinute,
|
||||
OutputTokensPerMinute,
|
||||
TokensPerDay,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "usage_measures")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: UsageMeasureId,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(has_many = "super::usage::Entity")]
|
||||
Usages,
|
||||
}
|
||||
|
||||
impl Related<super::usage::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Usages.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
|
@ -1,107 +0,0 @@
|
|||
mod provider_tests;
|
||||
|
||||
use gpui::BackgroundExecutor;
|
||||
use parking_lot::Mutex;
|
||||
use rand::prelude::*;
|
||||
use sea_orm::ConnectionTrait;
|
||||
use sqlx::migrate::MigrateDatabase;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::migrations::run_database_migrations;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub struct TestLlmDb {
|
||||
pub db: Option<LlmDatabase>,
|
||||
pub connection: Option<sqlx::AnyConnection>,
|
||||
}
|
||||
|
||||
impl TestLlmDb {
|
||||
pub fn postgres(background: BackgroundExecutor) -> Self {
|
||||
static LOCK: Mutex<()> = Mutex::new(());
|
||||
|
||||
let _guard = LOCK.lock();
|
||||
let mut rng = StdRng::from_entropy();
|
||||
let url = format!(
|
||||
"postgres://postgres@localhost/zed-llm-test-{}",
|
||||
rng.r#gen::<u128>()
|
||||
);
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let mut db = runtime.block_on(async {
|
||||
sqlx::Postgres::create_database(&url)
|
||||
.await
|
||||
.expect("failed to create test db");
|
||||
let mut options = ConnectOptions::new(url);
|
||||
options
|
||||
.max_connections(5)
|
||||
.idle_timeout(Duration::from_secs(0));
|
||||
let db = LlmDatabase::new(options, Executor::Deterministic(background))
|
||||
.await
|
||||
.unwrap();
|
||||
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
|
||||
run_database_migrations(db.options(), migrations_path)
|
||||
.await
|
||||
.unwrap();
|
||||
db
|
||||
});
|
||||
|
||||
db.runtime = Some(runtime);
|
||||
|
||||
Self {
|
||||
db: Some(db),
|
||||
connection: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn db(&mut self) -> &mut LlmDatabase {
|
||||
self.db.as_mut().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! test_llm_db {
|
||||
($test_name:ident, $postgres_test_name:ident) => {
|
||||
#[gpui::test]
|
||||
async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
|
||||
if !cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut test_db = $crate::llm::db::TestLlmDb::postgres(cx.executor().clone());
|
||||
$test_name(test_db.db()).await;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl Drop for TestLlmDb {
|
||||
fn drop(&mut self) {
|
||||
let db = self.db.take().unwrap();
|
||||
if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
|
||||
db.runtime.as_ref().unwrap().block_on(async {
|
||||
use util::ResultExt;
|
||||
let query = "
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE
|
||||
pg_stat_activity.datname = current_database() AND
|
||||
pid <> pg_backend_pid();
|
||||
";
|
||||
db.pool
|
||||
.execute(sea_orm::Statement::from_string(
|
||||
db.pool.get_database_backend(),
|
||||
query,
|
||||
))
|
||||
.await
|
||||
.log_err();
|
||||
sqlx::Postgres::drop_database(db.options.get_url())
|
||||
.await
|
||||
.log_err();
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
use cloud_llm_client::LanguageModelProvider;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::test_llm_db;
|
||||
|
||||
test_llm_db!(
|
||||
test_initialize_providers,
|
||||
test_initialize_providers_postgres
|
||||
);
|
||||
|
||||
async fn test_initialize_providers(db: &mut LlmDatabase) {
|
||||
let initial_providers = db.list_providers().await.unwrap();
|
||||
assert_eq!(initial_providers, vec![]);
|
||||
|
||||
db.initialize_providers().await.unwrap();
|
||||
|
||||
// Do it twice, to make sure the operation is idempotent.
|
||||
db.initialize_providers().await.unwrap();
|
||||
|
||||
let providers = db.list_providers().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
providers,
|
||||
&[
|
||||
LanguageModelProvider::Anthropic,
|
||||
LanguageModelProvider::Google,
|
||||
LanguageModelProvider::OpenAi,
|
||||
]
|
||||
)
|
||||
}
|
|
@ -62,13 +62,6 @@ async fn main() -> Result<()> {
|
|||
db.initialize_notification_kinds().await?;
|
||||
|
||||
collab::seed::seed(&config, &db, false).await?;
|
||||
|
||||
if let Some(llm_database_url) = config.llm_database_url.clone() {
|
||||
let db_options = db::ConnectOptions::new(llm_database_url);
|
||||
let mut db = LlmDatabase::new(db_options.clone(), Executor::Production).await?;
|
||||
db.initialize().await?;
|
||||
collab::llm::db::seed_database(&config, &mut db, true).await?;
|
||||
}
|
||||
}
|
||||
Some("serve") => {
|
||||
let mode = match args.next().as_deref() {
|
||||
|
@ -263,9 +256,6 @@ async fn setup_llm_database(config: &Config) -> Result<()> {
|
|||
.llm_database_migrations_path
|
||||
.as_deref()
|
||||
.unwrap_or_else(|| {
|
||||
#[cfg(feature = "sqlite")]
|
||||
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm.sqlite");
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations_llm");
|
||||
|
||||
Path::new(default_migrations)
|
||||
|
|
|
@ -565,7 +565,6 @@ impl TestServer {
|
|||
) -> Arc<AppState> {
|
||||
Arc::new(AppState {
|
||||
db: test_db.db().clone(),
|
||||
llm_db: None,
|
||||
livekit_client: Some(Arc::new(livekit_test_server.create_api_client())),
|
||||
blob_store_client: None,
|
||||
executor,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue