1
//! A token bucket implementation.
2

            
3
use std::fmt::Debug;
4
use std::time::{Duration, Instant};
5

            
6
/// A token bucket.
///
/// Calculations are performed at microsecond resolution.
/// You likely want to call [`refill()`](Self::refill) each time you want to access or perform an
/// operation on the token bucket.
///
/// This is partially inspired by tor's `token_bucket_ctr_t`,
/// but the implementation is quite a bit different.
/// We use larger values here (for example `u64`),
/// and we aim to avoid drift when refills occur at times that aren't exactly in period with the
/// refill rate.
///
/// It's possible that we could relax these requirements to reduce memory usage and computation
/// complexity, but that optimization should probably only be made if/when needed since it would
/// make the code more difficult to reason about, and possibly more complex.
#[derive(Debug)]
pub(crate) struct TokenBucket<I> {
    /// The refill rate in tokens/second.
    rate: u64,
    /// The max amount of tokens in the bucket.
    /// Commonly referred to as the "burst".
    bucket_max: u64,
    /// Current amount of tokens in the bucket.
    // It's possible that in the future we may want a token bucket to allow negative values. For
    // example we might want to send a few extra bytes over the allowed limit if it would mean that
    // we send a complete TLS record.
    bucket: u64,
    /// Time that the most recent token was added to the bucket.
    ///
    /// While this can be thought of as the last time the bucket was partially refilled, it more
    /// specifically is the time that the most recent token was added. For example if the bucket
    /// refills one token every 100 ms, and the bucket is refilled at time 510 ms, the bucket would
    /// gain 5 tokens and the stored time would be 500 ms.
    added_tokens_at: I,
}
41

            
42
impl<I: TokenBucketInstant> TokenBucket<I> {
43
    /// A new [`TokenBucket`] with a given `rate` in tokens/second and a `max` token limit.
44
    ///
45
    /// The bucket will initially be full.
46
    /// The value `max` is commonly referred to as the "burst".
47
158
    pub(crate) fn new(config: &TokenBucketConfig, now: I) -> Self {
48
158
        Self {
49
158
            rate: config.rate,
50
158
            bucket_max: config.bucket_max,
51
158
            bucket: config.bucket_max,
52
158
            added_tokens_at: now,
53
158
        }
54
158
    }
55

            
56
    /// Are there no tokens in the bucket?
57
    // remove this if we use it in the future
58
    #[cfg_attr(not(test), expect(dead_code))]
59
22
    pub(crate) fn is_empty(&self) -> bool {
60
22
        self.bucket == 0
61
22
    }
62

            
63
    /// The maximum number of tokens that this bucket can hold.
64
296
    pub(crate) fn max(&self) -> u64 {
65
296
        self.bucket_max
66
296
    }
67

            
68
    /// Remove `count` tokens from the bucket.
69
    // remove this if we use it in the future
70
    #[cfg_attr(not(test), expect(dead_code))]
71
30
    pub(crate) fn drain(&mut self, count: u64) -> Result<BecameEmpty, InsufficientTokensError> {
72
30
        Ok(self.claim(count)?.commit())
73
30
    }
74

            
75
    /// Claim a number of tokens.
76
    ///
77
    /// The claim will be held by the returned [`ClaimedTokens`], and committed when dropped.
78
    ///
79
    /// **Note:** You probably want to call [`refill()`](Self::refill) before this.
80
    // Since the `ClaimedTokens` holds a `&mut` to this `TokenBucket`, we don't need to worry about
81
    // other calls accessing the `TokenBucket` before the `ClaimedTokens` are committed.
82
6326
    pub(crate) fn claim(
83
6326
        &mut self,
84
6326
        count: u64,
85
6326
    ) -> Result<ClaimedTokens<I>, InsufficientTokensError> {
86
6326
        if count > self.bucket {
87
520
            return Err(InsufficientTokensError {
88
520
                available: self.bucket,
89
520
            });
90
5806
        }
91

            
92
5806
        Ok(ClaimedTokens::new(self, count))
93
6326
    }
94

            
95
    /// Adjust the refill rate and max tokens of the bucket.
96
    ///
97
    /// The token bucket is refilled up to `now` before changing the rate.
98
    ///
99
    /// If the new max is smaller than the existing number of tokens,
100
    /// the number of tokens will be reduced to the new max.
101
    ///
102
    /// A rate and/or max of 0 is allowed.
103
108
    pub(crate) fn adjust(&mut self, now: I, config: &TokenBucketConfig) {
104
        // make sure that the bucket gets the tokens it is owed before we change the rate
105
108
        self.refill(now);
106

            
107
        // If the old rate was small (or 0), the `refill()` might not have updated
108
        // `added_tokens_at`.
109
        //
110
        // For example if the bucket has a rate of 0 and was last refilled 10 seconds ago, it will
111
        // not have gained any tokens in the last 10 seconds. If we were to only update the rate to
112
        // 100 tokens/second now, the bucket would immediately become eligible to refill 1000
113
        // tokens. We only want the rate change to become effective now, not in the past, so we
114
        // ensure this by resetting `added_tokens_at`.
115
108
        self.added_tokens_at = std::cmp::max(self.added_tokens_at, now);
116

            
117
108
        self.rate = config.rate;
118
108
        self.bucket_max = config.bucket_max;
119
108
        self.bucket = std::cmp::min(self.bucket, self.bucket_max);
120
108
    }
121

            
122
    /// An estimated time at which the bucket will have `tokens` available.
123
    ///
124
    /// It is not guaranteed that `tokens` will be available at the returned time.
125
    ///
126
    /// If there are already enough tokens available, a time in the past may be returned.
127
    ///
128
    /// A value of `None` implies "never",
129
    /// for example if the refill rate is 0,
130
    /// the bucket max is too small,
131
    /// or the time is too large to be represented as an `I`.
132
300
    pub(crate) fn tokens_available_at(&self, tokens: u64) -> Result<I, NeverEnoughTokensError> {
133
300
        let tokens_needed = tokens.saturating_sub(self.bucket);
134

            
135
        // check if we currently have enough tokens before considering refilling
136
300
        if tokens_needed == 0 {
137
20
            return Ok(self.added_tokens_at);
138
280
        }
139

            
140
        // if the rate is 0, we'll never get more tokens
141
280
        if self.rate == 0 {
142
10
            return Err(NeverEnoughTokensError::ZeroRate);
143
270
        }
144

            
145
        // if more tokens are wanted than the capacity of the bucket, we'll never get enough
146
270
        if tokens > self.bucket_max {
147
4
            return Err(NeverEnoughTokensError::ExceedsMaxTokens);
148
266
        }
149

            
150
        // this may underestimate the time if either argument is very large
151
266
        let time_needed = Self::tokens_to_duration(tokens_needed, self.rate)
152
266
            .ok_or(NeverEnoughTokensError::ZeroRate)?;
153

            
154
        // Always return at least 1 microsecond since:
155
        // 1. We don't want to return `Duration::ZERO` if the tokens aren't ready,
156
        //    which may occur if the rate is very large (<1 ns/token).
157
        // 2. Clocks generally don't operate at <1 us resolution.
158
266
        let time_needed = std::cmp::max(time_needed, Duration::from_micros(1));
159

            
160
266
        self.added_tokens_at
161
266
            .checked_add(time_needed)
162
266
            .ok_or(NeverEnoughTokensError::InstantNotRepresentable)
163
300
    }
164

            
165
    /// Refill the bucket.
166
6204
    pub(crate) fn refill(&mut self, now: I) -> BecameNonEmpty {
167
        // time since we last added tokens
168
6204
        let elapsed = now.saturating_duration_since(self.added_tokens_at);
169

            
170
        // If we exceeded the threshold, update the timestamp and return.
171
        // This is taken from tor, which has the comment below:
172
        //
173
        // > Skip over updates that include an overflow or a very large jump. This can happen for
174
        // > platform specific reasons, such as the old ~48 day windows timer.
175
        //
176
        // It's unclear if this type of OS bug is still common enough that this check is useful,
177
        // but it shouldn't hurt.
178
6204
        if elapsed > I::IGNORE_THRESHOLD {
179
            tracing::debug!(
180
                "Time jump of {elapsed:?} is larger than {:?}; not refilling token bucket",
181
                I::IGNORE_THRESHOLD,
182
            );
183
            self.added_tokens_at = now;
184
            return BecameNonEmpty::No;
185
6204
        }
186

            
187
6204
        let old_bucket = self.bucket;
188

            
189
        // Compute how much we should increment the bucket by.
190
        // This may be underestimated in some cases.
191
6204
        let bucket_inc = Self::duration_to_tokens(elapsed, self.rate);
192

            
193
6204
        self.bucket = std::cmp::min(self.bucket_max, self.bucket.saturating_add(bucket_inc));
194

            
195
        // Compute how much we should increment the `last_added_tokens` time by. This avoids
196
        // drifting if the `bucket_inc` was underestimated, and avoids rounding errors which could
197
        // cause the token bucket to effectively use a lower rate. For example if the rate was
198
        // "1 token / sec" and the elapsed time was "1.2 sec", we only want to refill 1 token and
199
        // increment the time by 1 second.
200
        //
201
        // While the docs for `tokens_to_duration` say that a smaller than expected duration may be
202
        // returned, we have a test `test_duration_token_round_trip` which ensures that
203
        // `tokens_to_duration` returns the expected value when used with the result from
204
        // `duration_to_tokens`.
205
6204
        let added_tokens_at_inc =
206
6204
            Self::tokens_to_duration(bucket_inc, self.rate).unwrap_or(Duration::ZERO);
207

            
208
6204
        self.added_tokens_at = self
209
6204
            .added_tokens_at
210
6204
            .checked_add(added_tokens_at_inc)
211
6204
            .expect("overflowed time");
212
6204
        debug_assert!(self.added_tokens_at <= now);
213

            
214
6204
        if old_bucket == 0 && self.bucket != 0 {
215
256
            BecameNonEmpty::Yes
216
        } else {
217
5948
            BecameNonEmpty::No
218
        }
219
6204
    }
220

            
221
    /// How long would it take to refill `tokens` at `rate`?
222
    ///
223
    /// The result is rounded up to the nearest microsecond.
224
    /// If the number of `tokens` is large,
225
    /// the result may be much lower than the expected duration due to saturating 64-bit arithmetic.
226
    ///
227
    /// `None` will be returned if the `rate` is 0.
228
26502
    fn tokens_to_duration(tokens: u64, rate: u64) -> Option<Duration> {
229
        // Perform the calculation in microseconds rather than nanoseconds since timers typically
230
        // have microsecond granularity, and it lowers the chance that the calculation overflows the
231
        // `u64::MAX` limit compared to nanoseconds. In the case that the calculation saturates, the
232
        // returned duration will be shorter than the real value.
233
        //
234
        // For example with `tokens = u64::MAX` and `rate = u64::MAX` we'd expect a result of 1
235
        // second, but:
236
        // u64::MAX.saturating_mul(1000 * 1000).div_ceil(u64::MAX) = 1 microsecond
237
        //
238
        // The `div_ceil` ensures we always round up to the nearest microsecond.
239
        //
240
        // dimensional analysis:
241
        // (tokens) * (microseconds / second) / (tokens / second) = microseconds
242
26502
        if rate == 0 {
243
50
            return None;
244
26452
        }
245
26452
        let micros = tokens.saturating_mul(1000 * 1000).div_ceil(rate);
246
26452
        Some(Duration::from_micros(micros))
247
26502
    }
248

            
249
    /// How many tokens would be refilled within `time` at `rate`?
250
    ///
251
    /// The `time` is truncated to microsecond granularity.
252
    /// If the `time` or `rate` is large,
253
    /// the result may be much lower than the expected number of tokens due to saturating 64-bit
254
    /// arithmetic.
255
46268
    fn duration_to_tokens(time: Duration, rate: u64) -> u64 {
256
46268
        let micros = u64::try_from(time.as_micros()).unwrap_or(u64::MAX);
257
        // dimensional analysis:
258
        // (tokens / second) * (microseconds) / (microseconds / second) = tokens
259
46268
        rate.saturating_mul(micros) / (1000 * 1000)
260
46268
    }
261
}
262

            
263
/// The refill rate and token max for a [`TokenBucket`].
#[derive(Clone, Debug)]
pub(crate) struct TokenBucketConfig {
    /// The refill rate in tokens/second.
    pub(crate) rate: u64,
    /// The max amount of tokens in the bucket.
    /// Commonly referred to as the "burst".
    pub(crate) bucket_max: u64,
}
272

            
273
/// A handle to a number of claimed tokens.
274
///
275
/// Dropping this handle will commit the claim.
276
#[derive(Debug)]
277
pub(crate) struct ClaimedTokens<'a, I> {
278
    /// The bucket that the claim is for.
279
    bucket: &'a mut TokenBucket<I>,
280
    /// How many tokens to remove from the bucket.
281
    count: u64,
282
}
283

            
284
impl<'a, I> ClaimedTokens<'a, I> {
285
    /// Create a new [`ClaimedTokens`] that will remove `count` tokens from the token `bucket` when
286
    /// dropped.
287
5806
    fn new(bucket: &'a mut TokenBucket<I>, count: u64) -> Self {
288
5806
        Self { bucket, count }
289
5806
    }
290

            
291
    /// Commit the claimed tokens.
292
    ///
293
    /// This is equivalent to just dropping the [`ClaimedTokens`], but also returns whether the
294
    /// token bucket became empty or not.
295
5766
    pub(crate) fn commit(mut self) -> BecameEmpty {
296
5766
        self.commit_impl()
297
5766
    }
298

            
299
    /// Reduce the claim to a fewer number of tokens than the original claim.
300
    ///
301
    /// If `count` is larger than the original claim, an error will be returned containing the
302
    /// current number of claimed tokens.
303
5736
    pub(crate) fn reduce(&mut self, count: u64) -> Result<(), InsufficientTokensError> {
304
5736
        if count > self.count {
305
            return Err(InsufficientTokensError {
306
                available: self.count,
307
            });
308
5736
        }
309

            
310
5736
        self.count = count;
311
5736
        Ok(())
312
5736
    }
313

            
314
    /// Discard the claim.
315
    ///
316
    /// This does not remove any tokens from the token bucket.
317
40
    pub(crate) fn discard(mut self) {
318
40
        self.count = 0;
319
40
    }
320

            
321
    /// The commit implementation.
322
    ///
323
    /// After calling [`commit_impl()`](Self::commit_impl),
324
    /// the [`ClaimedTokens`] should no longer be used and should be dropped immediately.
325
11572
    fn commit_impl(&mut self) -> BecameEmpty {
326
        // when the `ClaimedTokens` was created by the `TokenBucket`, it should have ensured that
327
        // there were enough tokens
328
11572
        self.bucket.bucket = self
329
11572
            .bucket
330
11572
            .bucket
331
11572
            .checked_sub(self.count)
332
11572
            .unwrap_or_else(|| {
333
                panic!(
334
                    "claim commit failed: {}, {}",
335
                    self.count, self.bucket.bucket,
336
                )
337
            });
338

            
339
        // when `self` is dropped some time after this function ends,
340
        // we don't want to subtract again
341
11572
        self.count = 0;
342

            
343
11572
        if self.bucket.bucket > 0 {
344
11052
            BecameEmpty::No
345
        } else {
346
520
            BecameEmpty::Yes
347
        }
348
11572
    }
349
}
350

            
351
impl<'a, I> std::ops::Drop for ClaimedTokens<'a, I> {
352
5806
    fn drop(&mut self) {
353
5806
        self.commit_impl();
354
5806
    }
355
}
356

            
357
/// An operation was attempted to reduce the number of tokens,
358
/// but the token bucket did not have enough tokens.
359
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
360
#[error("insufficient tokens for operation")]
361
pub(crate) struct InsufficientTokensError {
362
    /// The number of tokens that are available to drain/commit.
363
    available: u64,
364
}
365

            
366
impl InsufficientTokensError {
367
    /// Get the number of tokens that are available to drain/commit.
368
520
    pub(crate) fn available_tokens(&self) -> u64 {
369
520
        self.available
370
520
    }
371
}
372

            
373
/// The token bucket will never have the requested number of tokens.
374
#[derive(Copy, Clone, Debug, PartialEq, Eq, thiserror::Error)]
375
#[error("there will never be enough tokens for this operation")]
376
pub(crate) enum NeverEnoughTokensError {
377
    /// The request exceeds the bucket's maximum number of tokens.
378
    ExceedsMaxTokens,
379
    /// The refill rate is 0.
380
    ZeroRate,
381
    /// The time is not representable.
382
    ///
383
    /// For example the if the rate is low and a large number of tokens were requested, it may be
384
    /// too far in the future that it cannot be represented as a time value.
385
    InstantNotRepresentable,
386
}
387

            
388
/// Whether the token bucket transitioned from "empty" to "non-empty".
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BecameNonEmpty {
    /// Token bucket became non-empty.
    Yes,
    /// Token bucket did not become non-empty:
    /// it either remains empty or already held tokens.
    No,
}
396

            
397
/// Whether the token bucket transitioned from "non-empty" to "empty".
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BecameEmpty {
    /// Token bucket is now empty.
    Yes,
    /// Token bucket remains non-empty.
    No,
}
405

            
406
/// Any type implementing this must be represented as a measurement of a monotonically nondecreasing
/// clock.
pub(crate) trait TokenBucketInstant:
    Copy + Clone + Debug + PartialEq + Eq + PartialOrd + Ord
{
    /// An unrealistically large time jump.
    ///
    /// We assume that any time change larger than this indicates a broken monotonic clock,
    /// and the bucket will not be refilled.
    const IGNORE_THRESHOLD: Duration;

    /// The time `duration` after `self`, or `None` if it cannot be represented.
    ///
    /// See [`Instant::checked_add`].
    fn checked_add(&self, duration: Duration) -> Option<Self>;

    /// The time elapsed from `earlier` to `self`, or `None` if `earlier` is later than `self`.
    ///
    /// See [`Instant::checked_duration_since`].
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration>;

    /// The time elapsed from `earlier` to `self`, or a zero duration if
    /// [`checked_duration_since`](Self::checked_duration_since) returns `None`.
    ///
    /// See [`Instant::saturating_duration_since`].
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
        self.checked_duration_since(earlier).unwrap_or_default()
    }
}
428

            
429
impl TokenBucketInstant for Instant {
430
    // This value is taken from tor (see `elapsed_ticks <= UINT32_MAX/4` in
431
    // `src/lib/evloop/token_bucket.c`).
432
    const IGNORE_THRESHOLD: Duration = Duration::from_secs((u32::MAX / 4) as u64);
433

            
434
    #[inline]
435
6388
    fn checked_add(&self, duration: Duration) -> Option<Self> {
436
6388
        self.checked_add(duration)
437
6388
    }
438

            
439
    #[inline]
440
    fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
441
        self.checked_duration_since(earlier)
442
    }
443

            
444
    #[inline]
445
6144
    fn saturating_duration_since(&self, earlier: Self) -> Duration {
446
6144
        self.saturating_duration_since(earlier)
447
6144
    }
448
}
449

            
450
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use rand::Rng;

    /// A fake clock value in milliseconds, so that tests fully control the passage of time.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct MillisTimestamp(u64);

    impl TokenBucketInstant for MillisTimestamp {
        const IGNORE_THRESHOLD: Duration = Duration::from_millis(1_000_000_000);

        fn checked_add(&self, duration: Duration) -> Option<Self> {
            let duration = u64::try_from(duration.as_millis()).ok()?;
            self.0.checked_add(duration).map(Self)
        }

        fn checked_duration_since(&self, earlier: Self) -> Option<Duration> {
            Some(Duration::from_millis(self.0.checked_sub(earlier.0)?))
        }
    }

    #[test]
    fn adjust_now() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, time);
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 40,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 40);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);

        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 200,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 40);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 200);
    }

    #[test]
    fn adjust_future() {
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);
        assert_eq!(tb.rate, 10);

        // at 300 ms: increase rate and max; bucket was already full, so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(300),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);

        // at 500 ms: no changes; bucket is refilled during `adjust()`, so gains 4 tokens
        tb.adjust(
            MillisTimestamp(500),
            &TokenBucketConfig {
                rate: 20,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 104);
        assert_eq!(tb.bucket_max, 200);

        // at 700 ms: lower rate and max; bucket is lowered to new max, so loses 4 tokens
        tb.adjust(
            MillisTimestamp(700),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 100);

        // at 900 ms: raise rate and max; rate was previously 0 so doesn't gain any tokens
        tb.adjust(
            MillisTimestamp(900),
            &TokenBucketConfig {
                rate: 100,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
    }

    #[test]
    fn adjust_zero() {
        let time = MillisTimestamp(100);

        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 200,
            },
        );
        assert_eq!(tb.bucket, 100);
        assert_eq!(tb.bucket_max, 200);
        assert_eq!(tb.rate, 0);
        // bucket should not increase
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 100);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 10);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);

        let mut tb = TokenBucket::new(&config, time);
        tb.adjust(
            time,
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        );
        assert_eq!(tb.bucket, 0);
        assert_eq!(tb.bucket_max, 0);
        assert_eq!(tb.rate, 0);
        // bucket should stay empty
        tb.refill(MillisTimestamp(10_000_000));
        assert_eq!(tb.bucket, 0);
    }

    #[test]
    fn is_empty() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));
        assert!(!tb.is_empty());

        tb.drain(99).unwrap();
        assert!(!tb.is_empty());

        tb.drain(1).unwrap();
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(199));
        assert!(tb.is_empty());

        tb.refill(MillisTimestamp(200));
        assert!(!tb.is_empty());
    }

    #[test]
    fn correctness() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(100));

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 50);

        tb.refill(MillisTimestamp(1100));
        assert_eq!(tb.bucket, 60);

        tb.drain(50).unwrap();
        assert_eq!(tb.bucket, 10);

        tb.refill(MillisTimestamp(2100));
        assert_eq!(tb.bucket, 20);

        tb.refill(MillisTimestamp(2101));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2199));
        assert_eq!(tb.bucket, 20);
        tb.refill(MillisTimestamp(2200));
        assert_eq!(tb.bucket, 21);
    }

    #[test]
    fn rounding() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));
        tb.drain(100).unwrap();

        // ensure that refilling at 150 ms does not change the `added_tokens_at` time to 150 ms,
        // otherwise the next refill wouldn't occur until 250 ms instead of 200 ms
        tb.refill(MillisTimestamp(99));
        assert_eq!(tb.bucket, 0);
        tb.refill(MillisTimestamp(150));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(199));
        assert_eq!(tb.bucket, 1);
        tb.refill(MillisTimestamp(200));
        assert_eq!(tb.bucket, 2);
    }

    #[test]
    fn tokens_available_at() {
        // increases 10 tokens/second (one every 100 ms)
        let config = TokenBucketConfig {
            rate: 10,
            bucket_max: 100,
        };
        let mut tb = TokenBucket::new(&config, MillisTimestamp(0));

        // bucket is empty at 0 ms, next token at 100 ms
        tb.drain(100).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is still empty at 40 ms, next token at 100 ms
        tb.refill(MillisTimestamp(40));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(0)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket has 1 token at 100 ms, next token at 200 ms
        tb.refill(MillisTimestamp(100));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(200)));

        // bucket is empty at 100 ms, next token at 200 ms
        tb.drain(1).unwrap();

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket is empty at 140 ms, next token at 200 ms
        tb.refill(MillisTimestamp(140));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(100)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        // bucket has 1 token at 210 ms, next token at 300 ms
        tb.refill(MillisTimestamp(210));

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(200)));
        assert_eq!(tb.tokens_available_at(2), Ok(MillisTimestamp(300)));

        use NeverEnoughTokensError as NETE;

        assert_eq!(tb.tokens_available_at(100), Ok(MillisTimestamp(10_100)));
        assert_eq!(tb.tokens_available_at(101), Err(NETE::ExceedsMaxTokens));
        assert_eq!(
            tb.tokens_available_at(u64::MAX),
            Err(NETE::ExceedsMaxTokens),
        );

        // set the refill rate to 0; note that adjusting the rate also resets `added_tokens_at`
        tb.adjust(
            MillisTimestamp(210),
            &TokenBucketConfig {
                rate: 0,
                bucket_max: 100,
            },
        );

        assert_eq!(tb.tokens_available_at(0), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(1), Ok(MillisTimestamp(210)));
        assert_eq!(tb.tokens_available_at(2), Err(NETE::ZeroRate));
    }

    #[test]
    fn test_duration_token_round_trip() {
        let tokens_to_duration = TokenBucket::<Instant>::tokens_to_duration;
        let duration_to_tokens = TokenBucket::<Instant>::duration_to_tokens;

        // start with some hand-picked cases
        let mut duration_rate_pairs = vec![
            (Duration::from_nanos(0), 1),
            (Duration::from_nanos(1), 1),
            (Duration::from_micros(2), 1),
            (Duration::MAX, 1),
            (Duration::from_nanos(0), 3),
            (Duration::from_nanos(1), 3),
            (Duration::from_micros(2), 3),
            (Duration::MAX, 3),
            (Duration::from_nanos(0), 1000),
            (Duration::from_nanos(1), 1000),
            (Duration::from_micros(2), 1000),
            (Duration::MAX, 1000),
            (Duration::from_nanos(0), u64::MAX),
            (Duration::from_nanos(1), u64::MAX),
            (Duration::from_micros(2), u64::MAX),
            (Duration::MAX, u64::MAX),
        ];

        let mut rng = rand::rng();

        // add some fuzzing
        for _ in 0..10_000 {
            let secs = rng.random();
            let nanos = rng.random();
            // Duration::new() may panic, so just skip if there's a panic rather than trying to
            // write our own logic to avoid the panic in the first place
            let Ok(random_duration) = std::panic::catch_unwind(|| Duration::new(secs, nanos))
            else {
                continue;
            };
            let random_rate = rng.random();
            duration_rate_pairs.push((random_duration, random_rate));
        }

        // for various combinations of durations and rates, we ensure that after an initial
        // `duration_to_tokens` calculation which may truncate, a round-trip between
        // `tokens_to_duration` and `duration_to_tokens` isn't lossy
        for (original_duration, rate) in duration_rate_pairs {
            // this may give a smaller number of tokens than expected (see docs on
            // `TokenBucket::duration_to_tokens`)
            let tokens = duration_to_tokens(original_duration, rate);

            // we want to ensure that converting these `tokens` to a duration and then back to
            // tokens is not lossy, which implies that `tokens_to_duration` is returning the
            // expected value and not a truncated value due to saturating arithmetic
            let duration = tokens_to_duration(tokens, rate).unwrap();
            assert_eq!(tokens, duration_to_tokens(duration, rate));
        }
    }
}