1
//! A token bucket implementation.
2

            
3
use std::fmt::Debug;
4
use web_time_compat::{Duration, Instant};
5

            
6
/// A token bucket that refills at a fixed rate, up to a maximum capacity.
///
/// All time arithmetic is done at microsecond granularity. Callers will usually want to invoke
/// [`refill()`](Self::refill) before inspecting or modifying the bucket, so that it reflects the
/// tokens earned since it was last refilled.
///
/// The design borrows ideas from tor's `token_bucket_ctr_t`, but the implementation differs
/// substantially: values are wider (for example `u64`), and care is taken that refills happening
/// at times which don't line up with the refill period do not make the effective rate drift.
///
/// These choices cost some memory and computation. They could be relaxed later if that ever
/// turns out to matter, but doing so would make the code harder to reason about and possibly
/// more complex, so it has not been done pre-emptively.
#[derive(Debug)]
pub struct TokenBucket<I> {
    /// Refill rate, in tokens per second.
    rate: u64,
    /// Capacity of the bucket, in tokens (often called the "burst").
    bucket_max: u64,
    /// Number of tokens the bucket currently holds.
    // NOTE: a future version might want this to be allowed to go negative, e.g. to let a few
    // extra bytes over the limit be sent if that means a complete TLS record goes out.
    bucket: u64,
    /// The instant at which the most recently earned token was credited to the bucket.
    ///
    /// This is not simply "the last time the bucket was refilled": it is rounded down to the
    /// moment the last whole token was earned. With one token every 100 ms, starting from 0 ms,
    /// a refill at 510 ms credits 5 tokens and records 500 ms here.
    added_tokens_at: I,
}
41

            
42
impl<I: TokenBucketInstant> TokenBucket<I> {
43
    /// A new [`TokenBucket`] with a given `rate` in tokens/second and a `max` token limit.
44
    ///
45
    /// The bucket will initially be full.
46
    /// The value `max` is commonly referred to as the "burst".
47
162
    pub fn new(config: &TokenBucketConfig, now: I) -> Self {
48
162
        Self {
49
162
            rate: config.rate,
50
162
            bucket_max: config.bucket_max,
51
162
            bucket: config.bucket_max,
52
162
            added_tokens_at: now,
53
162
        }
54
162
    }
55

            
56
    /// Are there no tokens in the bucket?
57
22
    pub fn is_empty(&self) -> bool {
58
22
        self.bucket == 0
59
22
    }
60

            
61
    /// The maximum number of tokens that this bucket can hold.
62
296
    pub fn max(&self) -> u64 {
63
296
        self.bucket_max
64
296
    }
65

            
66
    /// Remove `count` tokens from the bucket.
67
30
    pub fn drain(&mut self, count: u64) -> Result<BecameEmpty, InsufficientTokensError> {
68
30
        Ok(self.claim(count)?.commit())
69
30
    }
70

            
71
    /// Claim a number of tokens.
72
    ///
73
    /// The claim will be held by the returned [`ClaimedTokens`], and committed when dropped.
74
    ///
75
    /// **Note:** You probably want to call [`refill()`](Self::refill) before this.
76
    // Since the `ClaimedTokens` holds a `&mut` to this `TokenBucket`, we don't need to worry about
77
    // other calls accessing the `TokenBucket` before the `ClaimedTokens` are committed.
78
6326
    pub fn claim(&mut self, count: u64) -> Result<ClaimedTokens<I>, InsufficientTokensError> {
79
6326
        if count > self.bucket {
80
520
            return Err(InsufficientTokensError {
81
520
                available: self.bucket,
82
520
            });
83
5806
        }
84

            
85
5806
        Ok(ClaimedTokens::new(self, count))
86
6326
    }
87

            
88
    /// Adjust the refill rate and max tokens of the bucket.
89
    ///
90
    /// The token bucket is refilled up to `now` before changing the rate.
91
    ///
92
    /// If the new max is smaller than the existing number of tokens,
93
    /// the number of tokens will be reduced to the new max.
94
    ///
95
    /// A rate and/or max of 0 is allowed.
96
108
    pub fn adjust(&mut self, now: I, config: &TokenBucketConfig) {
97
        // make sure that the bucket gets the tokens it is owed before we change the rate
98
108
        self.refill(now);
99

            
100
        // If the old rate was small (or 0), the `refill()` might not have updated
101
        // `added_tokens_at`.
102
        //
103
        // For example if the bucket has a rate of 0 and was last refilled 10 seconds ago, it will
104
        // not have gained any tokens in the last 10 seconds. If we were to only update the rate to
105
        // 100 tokens/second now, the bucket would immediately become eligible to refill 1000
106
        // tokens. We only want the rate change to become effective now, not in the past, so we
107
        // ensure this by resetting `added_tokens_at`.
108
108
        self.added_tokens_at = std::cmp::max(self.added_tokens_at, now);
109

            
110
108
        self.rate = config.rate;
111
108
        self.bucket_max = config.bucket_max;
112
108
        self.bucket = std::cmp::min(self.bucket, self.bucket_max);
113
108
    }
114

            
115
    /// An estimated time at which the bucket will have `tokens` available.
116
    ///
117
    /// It is not guaranteed that `tokens` will be available at the returned time.
118
    ///
119
    /// If there are already enough tokens available, a time in the past may be returned.
120
    ///
121
    /// A value of `None` implies "never",
122
    /// for example if the refill rate is 0,
123
    /// the bucket max is too small,
124
    /// or the time is too large to be represented as an `I`.
125
300
    pub fn tokens_available_at(&self, tokens: u64) -> Result<I, NeverEnoughTokensError> {
126
300
        let tokens_needed = tokens.saturating_sub(self.bucket);
127

            
128
        // check if we currently have enough tokens before considering refilling
129
300
        if tokens_needed == 0 {
130
20
            return Ok(self.added_tokens_at);
131
280
        }
132

            
133
        // if the rate is 0, we'll never get more tokens
134
280
        if self.rate == 0 {
135
10
            return Err(NeverEnoughTokensError::ZeroRate);
136
270
        }
137

            
138
        // if more tokens are wanted than the capacity of the bucket, we'll never get enough
139
270
        if tokens > self.bucket_max {
140
4
            return Err(NeverEnoughTokensError::ExceedsMaxTokens);
141
266
        }
142

            
143
        // this may underestimate the time if either argument is very large
144
266
        let time_needed = Self::tokens_to_duration(tokens_needed, self.rate)
145
266
            .ok_or(NeverEnoughTokensError::ZeroRate)?;
146

            
147
        // Always return at least 1 microsecond since:
148
        // 1. We don't want to return `Duration::ZERO` if the tokens aren't ready,
149
        //    which may occur if the rate is very large (<1 ns/token).
150
        // 2. Clocks generally don't operate at <1 us resolution.
151
266
        let time_needed = std::cmp::max(time_needed, Duration::from_micros(1));
152

            
153
266
        self.added_tokens_at
154
266
            .checked_add(time_needed)
155
266
            .ok_or(NeverEnoughTokensError::InstantNotRepresentable)
156
300
    }
157

            
158
    /// Refill the bucket.
159
6204
    pub fn refill(&mut self, now: I) -> BecameNonEmpty {
160
        // time since we last added tokens
161
6204
        let elapsed = now.saturating_duration_since(self.added_tokens_at);
162

            
163
        // If we exceeded the threshold, update the timestamp and return.
164
        // This is taken from tor, which has the comment below:
165
        //
166
        // > Skip over updates that include an overflow or a very large jump. This can happen for
167
        // > platform specific reasons, such as the old ~48 day windows timer.
168
        //
169
        // It's unclear if this type of OS bug is still common enough that this check is useful,
170
        // but it shouldn't hurt.
171
6204
        if elapsed > I::IGNORE_THRESHOLD {
172
            tracing::debug!(
173
                "Time jump of {elapsed:?} is larger than {:?}; not refilling token bucket",
174
                I::IGNORE_THRESHOLD,
175
            );
176
            self.added_tokens_at = now;
177
            return BecameNonEmpty::No;
178
6204
        }
179

            
180
6204
        let old_bucket = self.bucket;
181

            
182
        // Compute how much we should increment the bucket by.
183
        // This may be underestimated in some cases.
184
6204
        let bucket_inc = Self::duration_to_tokens(elapsed, self.rate);
185

            
186
6204
        self.bucket = std::cmp::min(self.bucket_max, self.bucket.saturating_add(bucket_inc));
187

            
188
        // Compute how much we should increment the `last_added_tokens` time by. This avoids
189
        // drifting if the `bucket_inc` was underestimated, and avoids rounding errors which could
190
        // cause the token bucket to effectively use a lower rate. For example if the rate was
191
        // "1 token / sec" and the elapsed time was "1.2 sec", we only want to refill 1 token and
192
        // increment the time by 1 second.
193
        //
194
        // While the docs for `tokens_to_duration` say that a smaller than expected duration may be
195
        // returned, we have a test `test_duration_token_round_trip` which ensures that
196
        // `tokens_to_duration` returns the expected value when used with the result from
197
        // `duration_to_tokens`.
198
6204
        let added_tokens_at_inc =
199
6204
            Self::tokens_to_duration(bucket_inc, self.rate).unwrap_or(Duration::ZERO);
200

            
201
6204
        self.added_tokens_at = self
202
6204
            .added_tokens_at
203
6204
            .checked_add(added_tokens_at_inc)
204
6204
            .expect("overflowed time");
205
6204
        debug_assert!(self.added_tokens_at <= now);
206

            
207
6204
        if old_bucket == 0 && self.bucket != 0 {
208
256
            BecameNonEmpty::Yes
209
        } else {
210
5948
            BecameNonEmpty::No
211
        }
212
6204
    }
213

            
214
    /// How long would it take to refill `tokens` at `rate`?
215
    ///
216
    /// The result is rounded up to the nearest microsecond.
217
    /// If the number of `tokens` is large,
218
    /// the result may be much lower than the expected duration due to saturating 64-bit arithmetic.
219
    ///
220
    /// `None` will be returned if the `rate` is 0.
221
26502
    fn tokens_to_duration(tokens: u64, rate: u64) -> Option<Duration> {
222
        // Perform the calculation in microseconds rather than nanoseconds since timers typically
223
        // have microsecond granularity, and it lowers the chance that the calculation overflows the
224
        // `u64::MAX` limit compared to nanoseconds. In the case that the calculation saturates, the
225
        // returned duration will be shorter than the real value.
226
        //
227
        // For example with `tokens = u64::MAX` and `rate = u64::MAX` we'd expect a result of 1
228
        // second, but:
229
        // u64::MAX.saturating_mul(1000 * 1000).div_ceil(u64::MAX) = 1 microsecond
230
        //
231
        // The `div_ceil` ensures we always round up to the nearest microsecond.
232
        //
233
        // dimensional analysis:
234
        // (tokens) * (microseconds / second) / (tokens / second) = microseconds
235
26502
        if rate == 0 {
236
50
            return None;
237
26452
        }
238
26452
        let micros = tokens.saturating_mul(1000 * 1000).div_ceil(rate);
239
26452
        Some(Duration::from_micros(micros))
240
26502
    }
241

            
242
    /// How many tokens would be refilled within `time` at `rate`?
243
    ///
244
    /// The `time` is truncated to microsecond granularity.
245
    /// If the `time` or `rate` is large,
246
    /// the result may be much lower than the expected number of tokens due to saturating 64-bit
247
    /// arithmetic.
248
46268
    fn duration_to_tokens(time: Duration, rate: u64) -> u64 {
249
46268
        let micros = u64::try_from(time.as_micros()).unwrap_or(u64::MAX);
250
        // dimensional analysis:
251
        // (tokens / second) * (microseconds) / (microseconds / second) = tokens
252
46268
        rate.saturating_mul(micros) / (1000 * 1000)
253
46268
    }
254
}
255

            
256
/// Configuration for a [`TokenBucket`]: how quickly it refills, and how many tokens it can hold.
#[derive(Clone, Debug)]
#[allow(clippy::exhaustive_structs)] // constructed directly by callers configuring the bucket
pub struct TokenBucketConfig {
    /// Refill rate, in tokens per second.
    pub rate: u64,
    /// Capacity of the bucket, in tokens (often called the "burst").
    pub bucket_max: u64,
}
266

            
267
/// A handle to a number of claimed tokens.
268
///
269
/// Dropping this handle will commit the claim.
270
#[derive(Debug)]
271
pub struct ClaimedTokens<'a, I> {
272
    /// The bucket that the claim is for.
273
    bucket: &'a mut TokenBucket<I>,
274
    /// How many tokens to remove from the bucket.
275
    count: u64,
276
}
277

            
278
impl<'a, I> ClaimedTokens<'a, I> {
279
    /// Create a new [`ClaimedTokens`] that will remove `count` tokens from the token `bucket` when
280
    /// dropped.
281
5806
    fn new(bucket: &'a mut TokenBucket<I>, count: u64) -> Self {
282
5806
        Self { bucket, count }
283
5806
    }
284

            
285
    /// Commit the claimed tokens.
286
    ///
287
    /// This is equivalent to just dropping the [`ClaimedTokens`], but also returns whether the
288
    /// token bucket became empty or not.
289
5766
    pub fn commit(mut self) -> BecameEmpty {
290
5766
        self.commit_impl()
291
5766
    }
292

            
293
    /// Reduce the claim to a fewer number of tokens than the original claim.
294
    ///
295
    /// If `count` is larger than the original claim, an error will be returned containing the
296
    /// current number of claimed tokens.
297
5736
    pub fn reduce(&mut self, count: u64) -> Result<(), InsufficientTokensError> {
298
5736
        if count > self.count {
299
            return Err(InsufficientTokensError {
300
                available: self.count,
301
            });
302
5736
        }
303

            
304
5736
        self.count = count;
305
5736
        Ok(())
306
5736
    }
307

            
308
    /// Discard the claim.
309
    ///
310
    /// This does not remove any tokens from the token bucket.
311
40
    pub fn discard(mut self) {
312
40
        self.count = 0;
313
40
    }
314

            
315
    /// The commit implementation.
316
    ///
317
    /// After calling [`commit_impl()`](Self::commit_impl),
318
    /// the [`ClaimedTokens`] should no longer be used and should be dropped immediately.
319
11572
    fn commit_impl(&mut self) -> BecameEmpty {
320
        // when the `ClaimedTokens` was created by the `TokenBucket`, it should have ensured that
321
        // there were enough tokens
322
11572
        self.bucket.bucket = self
323
11572
            .bucket
324
11572
            .bucket
325
11572
            .checked_sub(self.count)
326
11572
            .unwrap_or_else(|| {
327
                panic!(
328
                    "claim commit failed: {}, {}",
329
                    self.count, self.bucket.bucket,
330
                )
331
            });
332

            
333
        // when `self` is dropped some time after this function ends,
334
        // we don't want to subtract again
335
11572
        self.count = 0;
336

            
337
11572
        if self.bucket.bucket > 0 {
338
11052
            BecameEmpty::No
339
        } else {
340
520
            BecameEmpty::Yes
341
        }
342
11572
    }
343
}
344

            
345
impl<'a, I> std::ops::Drop for ClaimedTokens<'a, I> {
346
5806
    fn drop(&mut self) {
347
5806
        self.commit_impl();
348
5806
    }
349
}
350

            
351
/// An operation was attempted to reduce the number of tokens,
352
/// but the token bucket did not have enough tokens.
353
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
354
#[error("insufficient tokens for operation")]
355
pub struct InsufficientTokensError {
356
    /// The number of tokens that are available to drain/commit.
357
    available: u64,
358
}
359

            
360
impl InsufficientTokensError {
361
    /// Get the number of tokens that are available to drain/commit.
362
13780
    pub fn available_tokens(&self) -> u64 {
363
13780
        self.available
364
13780
    }
365
}
366

            
367
/// The token bucket will never have the requested number of tokens.
368
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
369
#[allow(clippy::exhaustive_enums)] // callers exhaustively match on these variants
370
#[error("there will never be enough tokens for this operation")]
371
pub enum NeverEnoughTokensError {
372
    /// The request exceeds the bucket's maximum number of tokens.
373
    ExceedsMaxTokens,
374
    /// The refill rate is 0.
375
    ZeroRate,
376
    /// The time is not representable.
377
    ///
378
    /// For example the if the rate is low and a large number of tokens were requested, it may be
379
    /// too far in the future that it cannot be represented as a time value.
380
    InstantNotRepresentable,
381
}
382

            
383
/// Whether the token bucket transitioned from "empty" to "non-empty".
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[allow(clippy::exhaustive_enums)] // a simple yes/no status that callers match on
pub enum BecameNonEmpty {
    /// The bucket was empty and now holds at least one token.
    Yes,
    /// No such transition happened: the bucket is still empty, or it was not empty to begin with.
    No,
}
392

            
393
/// Whether the token bucket transitioned from "non-empty" to "empty".
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[allow(clippy::exhaustive_enums)] // a simple yes/no status that callers match on
pub enum BecameEmpty {
    /// The token bucket became empty.
    Yes,
    /// The token bucket did not become empty.
    No,
}
402

            
403
/// A point in time usable by a [`TokenBucket`].
///
/// Any type implementing this must represent a measurement of a monotonically nondecreasing
/// clock.
pub trait TokenBucketInstant: Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord {
    /// A time jump so large that it is considered unrealistic.
    ///
    /// Any time change larger than this is assumed to come from a broken monotonic clock, and the
    /// bucket will not be refilled for it.
    const IGNORE_THRESHOLD: Duration;

    /// See [`Instant::checked_add`].
    fn checked_add(&self, duration: Duration) -> Option<Self>;

    /// See [`Instant::checked_duration_since`].
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration>;

    /// See [`Instant::saturating_duration_since`].
    ///
    /// The default implementation returns a zero duration whenever
    /// [`checked_duration_since()`](Self::checked_duration_since) returns `None`.
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
        match self.checked_duration_since(earlier) {
            Some(elapsed) => elapsed,
            None => Duration::ZERO,
        }
    }
}
423

            
424
impl TokenBucketInstant for Instant {
425
    // This value is taken from tor (see `elapsed_ticks <= UINT32_MAX/4` in
426
    // `src/lib/evloop/token_bucket.c`).
427
    const IGNORE_THRESHOLD: Duration = Duration::from_secs((u32::MAX / 4) as u64);
428

            
429
    #[inline]
430
6388
    fn checked_add(&self, duration: Duration) -> Option<Self> {
431
6388
        self.checked_add(duration)
432
6388
    }
433

            
434
    #[inline]
435
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
436
        self.checked_duration_since(earlier)
437
    }
438

            
439
    #[inline]
440
6144
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
441
6144
        self.saturating_duration_since(earlier)
442
6144
    }
443
}
444

            
445
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use rand::RngExt;

    /// A test clock measured in whole milliseconds.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct MillisTimestamp(u64);

    impl TokenBucketInstant for MillisTimestamp {
        const IGNORE_THRESHOLD: Duration = Duration::from_millis(1_000_000_000);

        fn checked_add(&self, duration: Duration) -> Option<Self> {
            // any sub-millisecond part of `duration` is truncated
            let millis = u64::try_from(duration.as_millis()).ok()?;
            self.0.checked_add(millis).map(Self)
        }

        fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
            self.0.checked_sub(earlier.0).map(Duration::from_millis)
        }
    }

    /// Shorthand for building a [`TokenBucketConfig`].
    fn config(rate: u64, bucket_max: u64) -> TokenBucketConfig {
        TokenBucketConfig { rate, bucket_max }
    }

    /// Shorthand for building a [`MillisTimestamp`].
    fn ms(millis: u64) -> MillisTimestamp {
        MillisTimestamp(millis)
    }

    /// Assert the bucket's current token count and capacity.
    #[track_caller]
    fn assert_tokens(tb: &TokenBucket<MillisTimestamp>, bucket: u64, bucket_max: u64) {
        assert_eq!(tb.bucket, bucket);
        assert_eq!(tb.bucket_max, bucket_max);
    }

    #[test]
    fn adjust_now() {
        let now = ms(100);
        let mut tb = TokenBucket::new(&config(10, 100), now);
        assert_tokens(&tb, 100, 100);
        assert_eq!(tb.rate, 10);

        // (new rate, new max, expected tokens afterwards); no time passes between adjustments
        let steps = [(20, 100, 100), (20, 40, 40), (20, 100, 40), (200, 100, 40)];
        for (rate, bucket_max, expected) in steps {
            tb.adjust(now, &config(rate, bucket_max));
            assert_tokens(&tb, expected, bucket_max);
        }
        assert_eq!(tb.rate, 200);
    }

    #[test]
    fn adjust_future() {
        let mut tb = TokenBucket::new(&config(10, 100), ms(100));
        assert_tokens(&tb, 100, 100);
        assert_eq!(tb.rate, 10);

        // at 300 ms: raise rate and max; the bucket was already full, so it gains nothing
        tb.adjust(ms(300), &config(20, 200));
        assert_tokens(&tb, 100, 200);

        // at 500 ms: same config; the refill inside `adjust()` earns 4 tokens at 20 tokens/s
        tb.adjust(ms(500), &config(20, 200));
        assert_tokens(&tb, 104, 200);

        // at 700 ms: zero rate and lower max; the bucket is clamped to the new max
        tb.adjust(ms(700), &config(0, 100));
        assert_tokens(&tb, 100, 100);

        // at 900 ms: raise rate and max; the rate was 0 until now, so nothing is earned
        tb.adjust(ms(900), &config(100, 200));
        assert_tokens(&tb, 100, 200);
    }

    #[test]
    fn adjust_zero() {
        let start = ms(100);
        let much_later = ms(10_000_000);
        let original = config(10, 100);

        // (new rate, new max, expected tokens after adjusting)
        let cases = [(0, 200, 100), (10, 0, 0), (0, 0, 0)];
        for (rate, bucket_max, expected) in cases {
            let mut tb = TokenBucket::new(&original, start);
            tb.adjust(start, &config(rate, bucket_max));
            assert_tokens(&tb, expected, bucket_max);
            assert_eq!(tb.rate, rate);

            // with a zero rate or a zero max, the token count must not change however long we wait
            tb.refill(much_later);
            assert_eq!(tb.bucket, expected);
        }
    }

    #[test]
    fn is_empty() {
        // one token every 100 ms
        let mut tb = TokenBucket::new(&config(10, 100), ms(100));
        assert!(!tb.is_empty());

        tb.drain(99).unwrap();
        assert!(!tb.is_empty());

        tb.drain(1).unwrap();
        assert!(tb.is_empty());

        // the next token is earned at exactly 200 ms
        tb.refill(ms(199));
        assert!(tb.is_empty());

        tb.refill(ms(200));
        assert!(!tb.is_empty());
    }

    #[test]
    fn correctness() {
        // one token every 100 ms
        let mut tb = TokenBucket::new(&config(10, 100), ms(100));

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 50);

        tb.refill(ms(1100));
        assert_eq!(tb.bucket, 60);

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 10);

        // (refill time, expected tokens)
        for (at, expected) in [(2100, 20), (2101, 20), (2199, 20), (2200, 21)] {
            tb.refill(ms(at));
            assert_eq!(tb.bucket, expected);
        }
    }

    #[test]
    fn rounding() {
        // one token every 100 ms
        let mut tb = TokenBucket::new(&config(10, 100), ms(0));
        tb.drain(100).unwrap();

        // Refilling at 150 ms must record the token as earned at 100 ms rather than 150 ms;
        // otherwise the next token would only arrive at 250 ms instead of 200 ms.
        for (at, expected) in [(99, 0), (150, 1), (199, 1), (200, 2)] {
            tb.refill(ms(at));
            assert_eq!(tb.bucket, expected);
        }
    }

    #[test]
    fn tokens_available_at() {
        use NeverEnoughTokensError as NETE;

        // Check the times at which 0, 1 and 2 tokens are available.
        #[track_caller]
        fn check(tb: &TokenBucket<MillisTimestamp>, expected: [u64; 3]) {
            for (tokens, at) in (0..).zip(expected) {
                assert_eq!(tb.tokens_available_at(tokens), Ok(MillisTimestamp(at)));
            }
        }

        // one token every 100 ms
        let mut tb = TokenBucket::new(&config(10, 100), ms(0));

        // empty at 0 ms; next token at 100 ms
        tb.drain(100).unwrap();
        check(&tb, [0, 100, 200]);

        // still empty at 40 ms; next token at 100 ms
        tb.refill(ms(40));
        check(&tb, [0, 100, 200]);

        // one token at 100 ms; next token at 200 ms
        tb.refill(ms(100));
        check(&tb, [100, 100, 200]);

        // empty again at 100 ms; next token at 200 ms
        tb.drain(1).unwrap();
        check(&tb, [100, 200, 300]);

        // still empty at 140 ms; next token at 200 ms
        tb.refill(ms(140));
        check(&tb, [100, 200, 300]);

        // one token at 210 ms; next token at 300 ms
        tb.refill(ms(210));
        check(&tb, [200, 200, 300]);

        assert_eq!(tb.tokens_available_at(100), Ok(ms(10_100)));
        assert_eq!(tb.tokens_available_at(101), Err(NETE::ExceedsMaxTokens));
        assert_eq!(
            tb.tokens_available_at(u64::MAX),
            Err(NETE::ExceedsMaxTokens),
        );

        // stop refilling; note that adjusting the rate also resets `added_tokens_at`
        tb.adjust(ms(210), &config(0, 100));

        assert_eq!(tb.tokens_available_at(0), Ok(ms(210)));
        assert_eq!(tb.tokens_available_at(1), Ok(ms(210)));
        assert_eq!(tb.tokens_available_at(2), Err(NETE::ZeroRate));
    }

    #[test]
    fn test_duration_token_round_trip() {
        let tokens_to_duration = TokenBucket::<Instant>::tokens_to_duration;
        let duration_to_tokens = TokenBucket::<Instant>::duration_to_tokens;

        // every combination of these hand-picked durations and rates
        let durations = [
            Duration::from_nanos(0),
            Duration::from_nanos(1),
            Duration::from_micros(2),
            Duration::MAX,
        ];
        let rates = [1, 3, 1000, u64::MAX];
        let mut duration_rate_pairs: Vec<(Duration, u64)> = rates
            .iter()
            .flat_map(|&rate| durations.iter().map(move |&duration| (duration, rate)))
            .collect();

        // plus some random combinations
        let mut rng = rand::rng();
        for _ in 0..10_000 {
            let secs: u64 = rng.random();
            let nanos: u32 = rng.random();
            // `Duration::new()` panics if the nanoseconds overflow the seconds counter; skip those
            // cases rather than reimplementing its overflow logic here
            let Ok(duration) = std::panic::catch_unwind(|| Duration::new(secs, nanos)) else {
                continue;
            };
            duration_rate_pairs.push((duration, rng.random()));
        }

        // After an initial `duration_to_tokens()` (which may truncate), converting the tokens to a
        // duration and back must not be lossy. This shows that `tokens_to_duration()` returns the
        // expected value rather than one truncated by saturating arithmetic.
        for (original_duration, rate) in duration_rate_pairs {
            // may be fewer tokens than expected (see docs on `TokenBucket::duration_to_tokens`)
            let tokens = duration_to_tokens(original_duration, rate);
            let duration = tokens_to_duration(tokens, rate).unwrap();
            assert_eq!(tokens, duration_to_tokens(duration, rate));
        }
    }
}