diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs
index bfcd111e3f..64b627e475 100644
--- a/crates/collab/src/db/queries.rs
+++ b/crates/collab/src/db/queries.rs
@@ -14,7 +14,6 @@ pub mod messages;
 pub mod notifications;
 pub mod processed_stripe_events;
 pub mod projects;
-pub mod rate_buckets;
 pub mod rooms;
 pub mod servers;
 pub mod users;
diff --git a/crates/collab/src/db/queries/rate_buckets.rs b/crates/collab/src/db/queries/rate_buckets.rs
deleted file mode 100644
index 58b62170f4..0000000000
--- a/crates/collab/src/db/queries/rate_buckets.rs
+++ /dev/null
@@ -1,58 +0,0 @@
-use super::*;
-use crate::db::tables::rate_buckets;
-use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
-
-impl Database {
-    /// Saves the rate limit for the given user and rate limit name if the last_refill is later
-    /// than the currently saved timestamp.
-    pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
-        if buckets.is_empty() {
-            return Ok(());
-        }
-
-        self.transaction(|tx| async move {
-            rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
-                rate_buckets::ActiveModel {
-                    user_id: ActiveValue::Set(bucket.user_id),
-                    rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
-                    token_count: ActiveValue::Set(bucket.token_count),
-                    last_refill: ActiveValue::Set(bucket.last_refill),
-                }
-            }))
-            .on_conflict(
-                OnConflict::columns([
-                    rate_buckets::Column::UserId,
-                    rate_buckets::Column::RateLimitName,
-                ])
-                .update_columns([
-                    rate_buckets::Column::TokenCount,
-                    rate_buckets::Column::LastRefill,
-                ])
-                .to_owned(),
-            )
-            .exec(&*tx)
-            .await?;
-
-            Ok(())
-        })
-        .await
-    }
-
-    /// Retrieves the rate limit for the given user and rate limit name.
-    pub async fn get_rate_bucket(
-        &self,
-        user_id: UserId,
-        rate_limit_name: &str,
-    ) -> Result<Option<rate_buckets::Model>> {
-        self.transaction(|tx| async move {
-            let rate_limit = rate_buckets::Entity::find()
-                .filter(rate_buckets::Column::UserId.eq(user_id))
-                .filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
-                .one(&*tx)
-                .await?;
-
-            Ok(rate_limit)
-        })
-        .await
-    }
-}
diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs
index f3dfa6c3ab..d87ab174bd 100644
--- a/crates/collab/src/db/tables.rs
+++ b/crates/collab/src/db/tables.rs
@@ -28,7 +28,6 @@ pub mod project;
 pub mod project_collaborator;
 pub mod project_repository;
 pub mod project_repository_statuses;
-pub mod rate_buckets;
 pub mod room;
 pub mod room_participant;
 pub mod server;
diff --git a/crates/collab/src/db/tables/rate_buckets.rs b/crates/collab/src/db/tables/rate_buckets.rs
deleted file mode 100644
index e16db36814..0000000000
--- a/crates/collab/src/db/tables/rate_buckets.rs
+++ /dev/null
@@ -1,31 +0,0 @@
-use crate::db::UserId;
-use sea_orm::entity::prelude::*;
-
-#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
-#[sea_orm(table_name = "rate_buckets")]
-pub struct Model {
-    #[sea_orm(primary_key, auto_increment = false)]
-    pub user_id: UserId,
-    #[sea_orm(primary_key, auto_increment = false)]
-    pub rate_limit_name: String,
-    pub token_count: i32,
-    pub last_refill: DateTime,
-}
-
-#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {
-    #[sea_orm(
-        belongs_to = "super::user::Entity",
-        from = "Column::UserId",
-        to = "super::user::Column::Id"
-    )]
-    User,
-}
-
-impl Related<super::user::Entity> for Entity {
-    fn to() -> RelationDef {
-        Relation::User.def()
-    }
-}
-
-impl ActiveModelBehavior for ActiveModel {}
diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs
index 9e2e967eb8..1d95cbaab1 100644
--- a/crates/collab/src/lib.rs
+++ b/crates/collab/src/lib.rs
@@ -6,7 +6,6 @@ pub mod env;
 pub mod executor;
 pub mod llm;
 pub mod migrations;
-mod rate_limiter;
 pub mod rpc;
 pub mod seed;
 pub mod stripe_billing;
@@ -25,7 +24,6 @@ pub use cents::*;
 use db::{ChannelId, Database};
 use executor::Executor;
 use llm::db::LlmDatabase;
-pub use rate_limiter::*;
 use serde::Deserialize;
 use std::{path::PathBuf, sync::Arc};
 use util::ResultExt;
@@ -295,7 +293,6 @@ pub struct AppState {
     pub blob_store_client: Option<aws_sdk_s3::Client>,
     pub stripe_client: Option>,
     pub stripe_billing: Option<Arc<StripeBilling>>,
-    pub rate_limiter: Arc<RateLimiter>,
     pub executor: Executor,
     pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
     pub config: Config,
@@ -348,7 +345,6 @@ impl AppState {
                 .clone()
                 .map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
             stripe_client,
-            rate_limiter: Arc::new(RateLimiter::new(db)),
             executor,
             kinesis_client: if config.kinesis_access_key.is_some() {
                 build_kinesis_client(&config).await.log_err()
diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs
index 8f850ee847..a363d556a1 100644
--- a/crates/collab/src/main.rs
+++ b/crates/collab/src/main.rs
@@ -13,8 +13,8 @@ use collab::llm::db::LlmDatabase;
 use collab::migrations::run_database_migrations;
 use collab::user_backfiller::spawn_user_backfiller;
 use collab::{
-    AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
-    env, executor::Executor, rpc::ResultExt,
+    AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
+    executor::Executor, rpc::ResultExt,
 };
 use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
 use db::Database;
@@ -111,10 +111,6 @@ async fn main() -> Result<()> {
 
     if mode.is_collab() {
         state.db.purge_old_embeddings().await.trace_err();
-        RateLimiter::save_periodically(
-            state.rate_limiter.clone(),
-            state.executor.clone(),
-        );
 
         let epoch = state
             .db
diff --git a/crates/collab/src/rate_limiter.rs b/crates/collab/src/rate_limiter.rs
deleted file mode 100644
index 889910f8ea..0000000000
--- a/crates/collab/src/rate_limiter.rs
+++ /dev/null
@@ -1,321 +0,0 @@
-use crate::{Database, Error, Result, db::UserId, executor::Executor};
-use chrono::{DateTime, Duration, Utc};
-use dashmap::{DashMap, DashSet};
-use rpc::ErrorCodeExt;
-use sea_orm::prelude::DateTimeUtc;
-use std::sync::Arc;
-use util::ResultExt;
-
-pub trait RateLimit: Send + Sync {
-    fn capacity(&self) -> usize;
-    fn refill_duration(&self) -> Duration;
-    fn db_name(&self) -> &'static str;
-}
-
-/// Used to enforce per-user rate limits
-pub struct RateLimiter {
-    buckets: DashMap<(UserId, String), RateBucket>,
-    dirty_buckets: DashSet<(UserId, String)>,
-    db: Arc<Database>,
-}
-
-impl RateLimiter {
-    pub fn new(db: Arc<Database>) -> Self {
-        RateLimiter {
-            buckets: DashMap::new(),
-            dirty_buckets: DashSet::new(),
-            db,
-        }
-    }
-
-    /// Spawns a new task that periodically saves rate limit data to the database.
-    pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
-        const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
-
-        executor.clone().spawn_detached(async move {
-            loop {
-                executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
-                rate_limiter.save().await.log_err();
-            }
-        });
-    }
-
-    /// Returns an error if the user has exceeded the specified `RateLimit`.
-    /// Attempts to read from the database if no cached RateBucket currently exists.
-    pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
-        self.check_internal(limit, user_id, Utc::now()).await
-    }
-
-    async fn check_internal(
-        &self,
-        limit: &dyn RateLimit,
-        user_id: UserId,
-        now: DateTimeUtc,
-    ) -> Result<()> {
-        let bucket_key = (user_id, limit.db_name().to_string());
-
-        // Attempt to fetch the bucket from the database if it hasn't been cached.
-        // For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
-        // but this enforces limits across restarts so long as the database is reachable.
-        if !self.buckets.contains_key(&bucket_key) {
-            if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
-                self.buckets.insert(bucket_key.clone(), bucket);
-                self.dirty_buckets.insert(bucket_key.clone());
-            }
-        }
-
-        let mut bucket = self
-            .buckets
-            .entry(bucket_key.clone())
-            .or_insert_with(|| RateBucket::new(limit, now));
-
-        if bucket.value_mut().allow(now) {
-            self.dirty_buckets.insert(bucket_key);
-            Ok(())
-        } else {
-            Err(rpc::proto::ErrorCode::RateLimitExceeded
-                .message("rate limit exceeded".into())
-                .anyhow())?
-        }
-    }
-
-    async fn load_bucket(
-        &self,
-        limit: &dyn RateLimit,
-        user_id: UserId,
-    ) -> Result<Option<RateBucket>, Error> {
-        Ok(self
-            .db
-            .get_rate_bucket(user_id, limit.db_name())
-            .await?
-            .map(|saved_bucket| {
-                RateBucket::from_db(
-                    limit,
-                    saved_bucket.token_count as usize,
-                    DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
-                )
-            }))
-    }
-
-    pub async fn save(&self) -> Result<()> {
-        let mut buckets = Vec::new();
-        self.dirty_buckets.retain(|key| {
-            if let Some(bucket) = self.buckets.get(key) {
-                buckets.push(crate::db::rate_buckets::Model {
-                    user_id: key.0,
-                    rate_limit_name: key.1.clone(),
-                    token_count: bucket.token_count as i32,
-                    last_refill: bucket.last_refill.naive_utc(),
-                });
-            }
-            false
-        });
-
-        match self.db.save_rate_buckets(&buckets).await {
-            Ok(()) => Ok(()),
-            Err(err) => {
-                for bucket in buckets {
-                    self.dirty_buckets
-                        .insert((bucket.user_id, bucket.rate_limit_name));
-                }
-                Err(err)
-            }
-        }
-    }
-}
-
-#[derive(Clone, Debug)]
-struct RateBucket {
-    capacity: usize,
-    token_count: usize,
-    refill_time_per_token: Duration,
-    last_refill: DateTimeUtc,
-}
-
-impl RateBucket {
-    fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
-        Self {
-            capacity: limit.capacity(),
-            token_count: limit.capacity(),
-            refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
-            last_refill: now,
-        }
-    }
-
-    fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
-        Self {
-            capacity: limit.capacity(),
-            token_count,
-            refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
-            last_refill,
-        }
-    }
-
-    fn allow(&mut self, now: DateTimeUtc) -> bool {
-        self.refill(now);
-        if self.token_count > 0 {
-            self.token_count -= 1;
-            true
-        } else {
-            false
-        }
-    }
-
-    fn refill(&mut self, now: DateTimeUtc) {
-        let elapsed = now - self.last_refill;
-        if elapsed >= self.refill_time_per_token {
-            let new_tokens =
-                elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
-            self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
-
-            let unused_refill_time = Duration::milliseconds(
-                elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
-            );
-            self.last_refill = now - unused_refill_time;
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::db::{NewUserParams, TestDb};
-    use gpui::TestAppContext;
-
-    #[gpui::test]
-    async fn test_rate_limiter(cx: &mut TestAppContext) {
-        let test_db = TestDb::sqlite(cx.executor().clone());
-        let db = test_db.db().clone();
-        let user_1 = db
-            .create_user(
-                "user-1@zed.dev",
-                None,
-                false,
-                NewUserParams {
-                    github_login: "user-1".into(),
-                    github_user_id: 1,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
-        let user_2 = db
-            .create_user(
-                "user-2@zed.dev",
-                None,
-                false,
-                NewUserParams {
-                    github_login: "user-2".into(),
-                    github_user_id: 2,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
-
-        let mut now = Utc::now();
-
-        let rate_limiter = RateLimiter::new(db.clone());
-        let rate_limit_a = Box::new(RateLimitA);
-        let rate_limit_b = Box::new(RateLimitB);
-
-        // User 1 can access resource A two times before being rate-limited.
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap();
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap();
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap_err();
-
-        // User 2 can access resource A and user 1 can access resource B.
-        rate_limiter
-            .check_internal(&*rate_limit_b, user_2, now)
-            .await
-            .unwrap();
-        rate_limiter
-            .check_internal(&*rate_limit_b, user_1, now)
-            .await
-            .unwrap();
-
-        // After 1.5s, user 1 can make another request before being rate-limited again.
-        now += Duration::milliseconds(1500);
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap();
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap_err();
-
-        // After 500ms, user 1 can make another request before being rate-limited again.
-        now += Duration::milliseconds(500);
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap();
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap_err();
-
-        rate_limiter.save().await.unwrap();
-
-        // Rate limits are reloaded from the database, so user 1 is still rate-limited
-        // for resource A.
-        let rate_limiter = RateLimiter::new(db.clone());
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap_err();
-
-        // After 1s, user 1 can make another request before being rate-limited again.
-        now += Duration::seconds(1);
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap();
-        rate_limiter
-            .check_internal(&*rate_limit_a, user_1, now)
-            .await
-            .unwrap_err();
-    }
-
-    struct RateLimitA;
-
-    impl RateLimit for RateLimitA {
-        fn capacity(&self) -> usize {
-            2
-        }
-
-        fn refill_duration(&self) -> Duration {
-            Duration::seconds(2)
-        }
-
-        fn db_name(&self) -> &'static str {
-            "rate-limit-a"
-        }
-    }
-
-    struct RateLimitB;
-
-    impl RateLimit for RateLimitB {
-        fn capacity(&self) -> usize {
-            10
-        }
-
-        fn refill_duration(&self) -> Duration {
-            Duration::seconds(3)
-        }
-
-        fn db_name(&self) -> &'static str {
-            "rate-limit-b"
-        }
-    }
-}
diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs
index 835d637b9f..ca94312e0f 100644
--- a/crates/collab/src/tests/test_server.rs
+++ b/crates/collab/src/tests/test_server.rs
@@ -1,5 +1,5 @@
 use crate::{
-    AppState, Config, RateLimiter,
+    AppState, Config,
     db::{NewUserParams, UserId, tests::TestDb},
     executor::Executor,
     rpc::{CLEANUP_TIMEOUT, Principal, RECONNECT_TIMEOUT, Server, ZedVersion},
@@ -517,7 +517,6 @@
             blob_store_client: None,
             stripe_client: None,
             stripe_billing: None,
-            rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
             executor,
            kinesis_client: None,
             config: Config {
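
For readers of this patch: the refill arithmetic implemented by the deleted RateBucket can be sketched in isolation. The snippet below is an illustrative, standalone Rust sketch (not part of the diff above); the field names mirror the removed code, while the Bucket type and the main driver are hypothetical. A bucket starts full at capacity tokens, one token becomes available every refill_duration / capacity, and the unused remainder of the elapsed interval is carried forward by backdating last_refill.

use chrono::{DateTime, Duration, Utc};

// Illustrative token bucket mirroring the deleted RateBucket (names assumed from the diff).
struct Bucket {
    capacity: usize,
    token_count: usize,
    refill_time_per_token: Duration,
    last_refill: DateTime<Utc>,
}

impl Bucket {
    fn new(capacity: usize, refill_duration: Duration, now: DateTime<Utc>) -> Self {
        Self {
            capacity,
            token_count: capacity,
            // One token becomes available every `refill_duration / capacity`.
            refill_time_per_token: refill_duration / capacity as i32,
            last_refill: now,
        }
    }

    // Refills tokens for the elapsed time, then consumes one if any are available.
    fn allow(&mut self, now: DateTime<Utc>) -> bool {
        let elapsed = now - self.last_refill;
        if elapsed >= self.refill_time_per_token {
            let new_tokens =
                elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
            self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
            // Carry the unused fraction of a refill interval forward by backdating last_refill.
            let unused = Duration::milliseconds(
                elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
            );
            self.last_refill = now - unused;
        }
        if self.token_count > 0 {
            self.token_count -= 1;
            true
        } else {
            false
        }
    }
}

fn main() {
    let now = Utc::now();
    // Capacity 2 refilled over 2 seconds => one token per second, as in the deleted RateLimitA test.
    let mut bucket = Bucket::new(2, Duration::seconds(2), now);
    assert!(bucket.allow(now));
    assert!(bucket.allow(now));
    assert!(!bucket.allow(now)); // exhausted
    assert!(bucket.allow(now + Duration::milliseconds(1500))); // one token has refilled
}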