collab: Remove unused RateLimiter (#29343)

This PR removes the `RateLimiter` from the collab codebase, as it is no
longer used.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-04-24 10:23:17 -04:00 committed by GitHub
parent fd8eeb537d
commit ea5ce2a1a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 3 additions and 424 deletions

View file

@ -14,7 +14,6 @@ pub mod messages;
pub mod notifications;
pub mod processed_stripe_events;
pub mod projects;
pub mod rate_buckets;
pub mod rooms;
pub mod servers;
pub mod users;

View file

@ -1,58 +0,0 @@
use super::*;
use crate::db::tables::rate_buckets;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
impl Database {
/// Saves the rate limit for the given user and rate limit name if the last_refill is later
/// than the currently saved timestamp.
pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
if buckets.is_empty() {
return Ok(());
}
self.transaction(|tx| async move {
rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
rate_buckets::ActiveModel {
user_id: ActiveValue::Set(bucket.user_id),
rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
token_count: ActiveValue::Set(bucket.token_count),
last_refill: ActiveValue::Set(bucket.last_refill),
}
}))
.on_conflict(
OnConflict::columns([
rate_buckets::Column::UserId,
rate_buckets::Column::RateLimitName,
])
.update_columns([
rate_buckets::Column::TokenCount,
rate_buckets::Column::LastRefill,
])
.to_owned(),
)
.exec(&*tx)
.await?;
Ok(())
})
.await
}
/// Retrieves the rate limit for the given user and rate limit name.
pub async fn get_rate_bucket(
&self,
user_id: UserId,
rate_limit_name: &str,
) -> Result<Option<rate_buckets::Model>> {
self.transaction(|tx| async move {
let rate_limit = rate_buckets::Entity::find()
.filter(rate_buckets::Column::UserId.eq(user_id))
.filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
.one(&*tx)
.await?;
Ok(rate_limit)
})
.await
}
}

View file

@ -28,7 +28,6 @@ pub mod project;
pub mod project_collaborator;
pub mod project_repository;
pub mod project_repository_statuses;
pub mod rate_buckets;
pub mod room;
pub mod room_participant;
pub mod server;

View file

@ -1,31 +0,0 @@
use crate::db::UserId;
use sea_orm::entity::prelude::*;
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "rate_buckets")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub user_id: UserId,
#[sea_orm(primary_key, auto_increment = false)]
pub rate_limit_name: String,
pub token_count: i32,
pub last_refill: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {
#[sea_orm(
belongs_to = "super::user::Entity",
from = "Column::UserId",
to = "super::user::Column::Id"
)]
User,
}
impl Related<super::user::Entity> for Entity {
fn to() -> RelationDef {
Relation::User.def()
}
}
impl ActiveModelBehavior for ActiveModel {}

View file

@ -6,7 +6,6 @@ pub mod env;
pub mod executor;
pub mod llm;
pub mod migrations;
mod rate_limiter;
pub mod rpc;
pub mod seed;
pub mod stripe_billing;
@ -25,7 +24,6 @@ pub use cents::*;
use db::{ChannelId, Database};
use executor::Executor;
use llm::db::LlmDatabase;
pub use rate_limiter::*;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
@ -295,7 +293,6 @@ pub struct AppState {
pub blob_store_client: Option<aws_sdk_s3::Client>,
pub stripe_client: Option<Arc<stripe::Client>>,
pub stripe_billing: Option<Arc<StripeBilling>>,
pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
pub config: Config,
@ -348,7 +345,6 @@ impl AppState {
.clone()
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
stripe_client,
rate_limiter: Arc::new(RateLimiter::new(db)),
executor,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()

View file

@ -13,8 +13,8 @@ use collab::llm::db::LlmDatabase;
use collab::migrations::run_database_migrations;
use collab::user_backfiller::spawn_user_backfiller;
use collab::{
AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
env, executor::Executor, rpc::ResultExt,
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
executor::Executor, rpc::ResultExt,
};
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
use db::Database;
@ -111,10 +111,6 @@ async fn main() -> Result<()> {
if mode.is_collab() {
state.db.purge_old_embeddings().await.trace_err();
RateLimiter::save_periodically(
state.rate_limiter.clone(),
state.executor.clone(),
);
let epoch = state
.db

View file

@ -1,321 +0,0 @@
use crate::{Database, Error, Result, db::UserId, executor::Executor};
use chrono::{DateTime, Duration, Utc};
use dashmap::{DashMap, DashSet};
use rpc::ErrorCodeExt;
use sea_orm::prelude::DateTimeUtc;
use std::sync::Arc;
use util::ResultExt;
pub trait RateLimit: Send + Sync {
fn capacity(&self) -> usize;
fn refill_duration(&self) -> Duration;
fn db_name(&self) -> &'static str;
}
/// Used to enforce per-user rate limits
pub struct RateLimiter {
buckets: DashMap<(UserId, String), RateBucket>,
dirty_buckets: DashSet<(UserId, String)>,
db: Arc<Database>,
}
impl RateLimiter {
pub fn new(db: Arc<Database>) -> Self {
RateLimiter {
buckets: DashMap::new(),
dirty_buckets: DashSet::new(),
db,
}
}
/// Spawns a new task that periodically saves rate limit data to the database.
pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
executor.clone().spawn_detached(async move {
loop {
executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
rate_limiter.save().await.log_err();
}
});
}
/// Returns an error if the user has exceeded the specified `RateLimit`.
/// Attempts to read the from the database if no cached RateBucket currently exists.
pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
self.check_internal(limit, user_id, Utc::now()).await
}
async fn check_internal(
&self,
limit: &dyn RateLimit,
user_id: UserId,
now: DateTimeUtc,
) -> Result<()> {
let bucket_key = (user_id, limit.db_name().to_string());
// Attempt to fetch the bucket from the database if it hasn't been cached.
// For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
// but this enforces limits across restarts so long as the database is reachable.
if !self.buckets.contains_key(&bucket_key) {
if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
self.buckets.insert(bucket_key.clone(), bucket);
self.dirty_buckets.insert(bucket_key.clone());
}
}
let mut bucket = self
.buckets
.entry(bucket_key.clone())
.or_insert_with(|| RateBucket::new(limit, now));
if bucket.value_mut().allow(now) {
self.dirty_buckets.insert(bucket_key);
Ok(())
} else {
Err(rpc::proto::ErrorCode::RateLimitExceeded
.message("rate limit exceeded".into())
.anyhow())?
}
}
async fn load_bucket(
&self,
limit: &dyn RateLimit,
user_id: UserId,
) -> Result<Option<RateBucket>, Error> {
Ok(self
.db
.get_rate_bucket(user_id, limit.db_name())
.await?
.map(|saved_bucket| {
RateBucket::from_db(
limit,
saved_bucket.token_count as usize,
DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
)
}))
}
pub async fn save(&self) -> Result<()> {
let mut buckets = Vec::new();
self.dirty_buckets.retain(|key| {
if let Some(bucket) = self.buckets.get(key) {
buckets.push(crate::db::rate_buckets::Model {
user_id: key.0,
rate_limit_name: key.1.clone(),
token_count: bucket.token_count as i32,
last_refill: bucket.last_refill.naive_utc(),
});
}
false
});
match self.db.save_rate_buckets(&buckets).await {
Ok(()) => Ok(()),
Err(err) => {
for bucket in buckets {
self.dirty_buckets
.insert((bucket.user_id, bucket.rate_limit_name));
}
Err(err)
}
}
}
}
#[derive(Clone, Debug)]
struct RateBucket {
capacity: usize,
token_count: usize,
refill_time_per_token: Duration,
last_refill: DateTimeUtc,
}
impl RateBucket {
fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
Self {
capacity: limit.capacity(),
token_count: limit.capacity(),
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
last_refill: now,
}
}
fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
Self {
capacity: limit.capacity(),
token_count,
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
last_refill,
}
}
fn allow(&mut self, now: DateTimeUtc) -> bool {
self.refill(now);
if self.token_count > 0 {
self.token_count -= 1;
true
} else {
false
}
}
fn refill(&mut self, now: DateTimeUtc) {
let elapsed = now - self.last_refill;
if elapsed >= self.refill_time_per_token {
let new_tokens =
elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
let unused_refill_time = Duration::milliseconds(
elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
);
self.last_refill = now - unused_refill_time;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::db::{NewUserParams, TestDb};
use gpui::TestAppContext;
#[gpui::test]
async fn test_rate_limiter(cx: &mut TestAppContext) {
let test_db = TestDb::sqlite(cx.executor().clone());
let db = test_db.db().clone();
let user_1 = db
.create_user(
"user-1@zed.dev",
None,
false,
NewUserParams {
github_login: "user-1".into(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;
let user_2 = db
.create_user(
"user-2@zed.dev",
None,
false,
NewUserParams {
github_login: "user-2".into(),
github_user_id: 2,
},
)
.await
.unwrap()
.user_id;
let mut now = Utc::now();
let rate_limiter = RateLimiter::new(db.clone());
let rate_limit_a = Box::new(RateLimitA);
let rate_limit_b = Box::new(RateLimitB);
// User 1 can access resource A two times before being rate-limited.
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// User 2 can access resource A and user 1 can access resource B.
rate_limiter
.check_internal(&*rate_limit_b, user_2, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_b, user_1, now)
.await
.unwrap();
// After 1.5s, user 1 can make another request before being rate-limited again.
now += Duration::milliseconds(1500);
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// After 500ms, user 1 can make another request before being rate-limited again.
now += Duration::milliseconds(500);
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
rate_limiter.save().await.unwrap();
// Rate limits are reloaded from the database, so user A is still rate-limited
// for resource A.
let rate_limiter = RateLimiter::new(db.clone());
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// After 1s, user 1 can make another request before being rate-limited again.
now += Duration::seconds(1);
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
.check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
}
struct RateLimitA;
impl RateLimit for RateLimitA {
fn capacity(&self) -> usize {
2
}
fn refill_duration(&self) -> Duration {
Duration::seconds(2)
}
fn db_name(&self) -> &'static str {
"rate-limit-a"
}
}
struct RateLimitB;
impl RateLimit for RateLimitB {
fn capacity(&self) -> usize {
10
}
fn refill_duration(&self) -> Duration {
Duration::seconds(3)
}
fn db_name(&self) -> &'static str {
"rate-limit-b"
}
}
}

View file

@ -1,5 +1,5 @@
use crate::{
AppState, Config, RateLimiter,
AppState, Config,
db::{NewUserParams, UserId, tests::TestDb},
executor::Executor,
rpc::{CLEANUP_TIMEOUT, Principal, RECONNECT_TIMEOUT, Server, ZedVersion},
@ -517,7 +517,6 @@ impl TestServer {
blob_store_client: None,
stripe_client: None,
stripe_billing: None,
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
executor,
kinesis_client: None,
config: Config {