Fix logic errors in RateLimiter (#12421)

This pull request fixes two issues in `RateLimiter` that caused
excessive rate-limiting to take place:

- c19083a35c fixes a mistake that caused
us to load buckets from the database incorrectly and set the
`refill_time_per_token` to equal the `refill_duration`. This was the
primary reason why rate limiting was acting oddly.
- 34b88d14f6 fixes another slight logic
error that caused tokens to be underprovisioned. This was minor compared
to the bug above.

Release Notes:

- N/A
This commit is contained in:
Antonio Scandurra 2024-05-29 12:05:40 +02:00 committed by GitHub
parent 3c6c850390
commit 4acfab689e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -62,7 +62,7 @@ impl RateLimiter {
let mut bucket = self let mut bucket = self
.buckets .buckets
.entry(bucket_key.clone()) .entry(bucket_key.clone())
.or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now)); .or_insert_with(|| RateBucket::new::<T>(now));
if bucket.value_mut().allow(now) { if bucket.value_mut().allow(now) {
self.dirty_buckets.insert(bucket_key); self.dirty_buckets.insert(bucket_key);
@ -72,19 +72,19 @@ impl RateLimiter {
} }
} }
async fn load_bucket<K: RateLimit>( async fn load_bucket<T: RateLimit>(
&self, &self,
user_id: UserId, user_id: UserId,
) -> Result<Option<RateBucket>, Error> { ) -> Result<Option<RateBucket>, Error> {
Ok(self Ok(self
.db .db
.get_rate_bucket(user_id, K::db_name()) .get_rate_bucket(user_id, T::db_name())
.await? .await?
.map(|saved_bucket| RateBucket { .map(|saved_bucket| {
capacity: K::capacity(), RateBucket::from_db::<T>(
refill_time_per_token: K::refill_duration(), saved_bucket.token_count as usize,
token_count: saved_bucket.token_count as usize, DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc), )
})) }))
} }
@ -124,15 +124,24 @@ struct RateBucket {
} }
impl RateBucket { impl RateBucket {
fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self { fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
RateBucket { Self {
capacity, capacity: T::capacity(),
token_count: capacity, token_count: T::capacity(),
refill_time_per_token: refill_duration / capacity as i32, refill_time_per_token: T::refill_duration() / T::capacity() as i32,
last_refill: now, last_refill: now,
} }
} }
fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
Self {
capacity: T::capacity(),
token_count,
refill_time_per_token: T::refill_duration() / T::capacity() as i32,
last_refill,
}
}
fn allow(&mut self, now: DateTimeUtc) -> bool { fn allow(&mut self, now: DateTimeUtc) -> bool {
self.refill(now); self.refill(now);
if self.token_count > 0 { if self.token_count > 0 {
@ -148,9 +157,12 @@ impl RateBucket {
if elapsed >= self.refill_time_per_token { if elapsed >= self.refill_time_per_token {
let new_tokens = let new_tokens =
elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds(); elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
self.token_count = (self.token_count + new_tokens as usize).min(self.capacity); self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
self.last_refill = now;
let unused_refill_time = Duration::milliseconds(
elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
);
self.last_refill = now - unused_refill_time;
} }
} }
} }
@ -218,8 +230,19 @@ mod tests {
.await .await
.unwrap(); .unwrap();
// After one second, user 1 can make another request before being rate-limited again. // After 1.5s, user 1 can make another request before being rate-limited again.
now += Duration::seconds(1); now += Duration::milliseconds(1500);
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap();
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap_err();
// After 500ms, user 1 can make another request before being rate-limited again.
now += Duration::milliseconds(500);
rate_limiter rate_limiter
.check_internal::<RateLimitA>(user_1, now) .check_internal::<RateLimitA>(user_1, now)
.await .await
@ -238,6 +261,17 @@ mod tests {
.check_internal::<RateLimitA>(user_1, now) .check_internal::<RateLimitA>(user_1, now)
.await .await
.unwrap_err(); .unwrap_err();
// After 1s, user 1 can make another request before being rate-limited again.
now += Duration::seconds(1);
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap();
rate_limiter
.check_internal::<RateLimitA>(user_1, now)
.await
.unwrap_err();
} }
struct RateLimitA; struct RateLimitA;