1
//! An [`AsyncWrite`] rate limiter.
2

            
3
use std::future::Future;
4
use std::num::NonZero;
5
use std::pin::Pin;
6
use std::task::{Context, Poll};
7
use web_time_compat::{Duration, Instant};
8

            
9
use futures::AsyncWrite;
10
use futures::io::Error;
11
use sync_wrapper::SyncFuture;
12
use tor_rtcompat::SleepProvider;
13

            
14
use tor_basic_utils::token_bucket::{NeverEnoughTokensError, TokenBucket, TokenBucketConfig};
15

            
16
/// A rate-limited async [writer](AsyncWrite).
17
///
18
/// This can be used as a wrapper around an existing [`AsyncWrite`] writer.
19
#[derive(educe::Educe)]
20
#[educe(Debug)]
21
#[pin_project::pin_project]
22
pub struct RateLimitedWriter<W: AsyncWrite, P: SleepProvider> {
23
    /// The token bucket.
24
    bucket: TokenBucket<Instant>,
25
    /// The sleep provider, for getting the current time and creating new sleep futures.
26
    ///
27
    /// While we use [`Instant`] for the time, we should always get the time from this
28
    /// [`SleepProvider`].
29
    /// For example, use [`SleepProvider::now()`],
30
    /// not [`Instant::now()`](std::time::Instant::now) or
31
    /// [`InstantExt::get`](web_time_compat::InstantExt::get).
32
    #[educe(Debug(ignore))]
33
    sleep_provider: P,
34
    /// See [`RateLimitedWriterConfig::wake_when_bytes_available`].
35
    wake_when_bytes_available: NonZero<u64>,
36
    /// The inner writer.
37
    #[educe(Debug(ignore))]
38
    #[pin]
39
    inner: W,
40
    /// We need to store the sleep future if [`AsyncWrite::poll_write()`] blocks.
41
    #[educe(Debug(ignore))]
42
    #[pin]
43
    sleep_fut: Option<SyncFuture<P::SleepFuture>>,
44
}
45

            
46
impl<W, P> RateLimitedWriter<W, P>
47
where
48
    W: AsyncWrite,
49
    P: SleepProvider,
50
{
51
    /// Create a new [`RateLimitedWriter`].
52
    // We take the rate and bucket max directly rather than a `TokenBucket` to ensure that the token
53
    // bucket only ever uses times from `sleep_provider`.
54
128
    pub fn new(writer: W, config: &RateLimitedWriterConfig, sleep_provider: P) -> Self {
55
128
        let bucket_config = TokenBucketConfig {
56
128
            rate: config.rate,
57
128
            bucket_max: config.burst,
58
128
        };
59
128
        Self::from_token_bucket(
60
128
            writer,
61
128
            TokenBucket::new(&bucket_config, sleep_provider.now()),
62
128
            config.wake_when_bytes_available,
63
128
            sleep_provider,
64
        )
65
128
    }
66

            
67
    /// Create a new [`RateLimitedWriter`] from a [`TokenBucket`].
68
    ///
69
    /// The token bucket must have only been used with times created by `sleep_provider`.
70
20
    #[cfg_attr(test, visibility::make(pub(super)))]
71
144
    fn from_token_bucket(
72
144
        writer: W,
73
144
        bucket: TokenBucket<Instant>,
74
144
        wake_when_bytes_available: NonZero<u64>,
75
144
        sleep_provider: P,
76
144
    ) -> Self {
77
144
        Self {
78
144
            bucket,
79
144
            sleep_provider,
80
144
            wake_when_bytes_available,
81
144
            inner: writer,
82
144
            sleep_fut: None,
83
144
        }
84
144
    }
85

            
86
    /// Access the inner [`AsyncWrite`] writer.
87
    pub fn inner(&self) -> &W {
88
        &self.inner
89
    }
90

            
91
    /// Adjust the refill rate and burst.
92
    ///
93
    /// A rate and/or burst of 0 is allowed.
94
84
    pub fn adjust(self: &mut Pin<&mut Self>, now: Instant, config: &RateLimitedWriterConfig) {
95
84
        let self_ = self.as_mut().project();
96

            
97
        // destructuring allows us to make sure we aren't forgetting to handle any fields
98
        let RateLimitedWriterConfig {
99
84
            rate,
100
84
            burst,
101
84
            wake_when_bytes_available,
102
84
        } = *config;
103

            
104
84
        let bucket_config = TokenBucketConfig {
105
84
            rate,
106
84
            bucket_max: burst,
107
84
        };
108

            
109
84
        self_.bucket.adjust(now, &bucket_config);
110
84
        *self_.wake_when_bytes_available = wake_when_bytes_available;
111
84
    }
112

            
113
    /// The sleep provider.
114
    ///
115
    /// We don't want this to be generally accessible, only to other token bucket-related modules
116
    /// like [`DynamicRateLimitedWriter`](super::dynamic_writer::DynamicRateLimitedWriter).
117
84
    pub(super) fn sleep_provider(&self) -> &P {
118
84
        &self.sleep_provider
119
84
    }
120

            
121
    /// Configure this writer to sleep for `duration`.
122
    ///
123
    /// A `duration` of `None` is interpreted as "forever".
124
    ///
125
    /// It's considered a bug if asked to sleep for `Duration::ZERO` time.
126
284
    fn register_sleep(
127
284
        sleep_fut: &mut Pin<&mut Option<SyncFuture<P::SleepFuture>>>,
128
284
        sleep_provider: &mut P,
129
284
        cx: &mut Context<'_>,
130
284
        duration: Option<Duration>,
131
284
    ) -> Poll<()> {
132
284
        match duration {
133
            None => {
134
40
                sleep_fut.as_mut().set(None);
135
40
                Poll::Pending
136
            }
137
244
            Some(duration) => {
138
244
                debug_assert_ne!(duration, Duration::ZERO, "asked to sleep for 0 time");
139
244
                sleep_fut
140
244
                    .as_mut()
141
244
                    .set(Some(SyncFuture::new(sleep_provider.sleep(duration))));
142
244
                sleep_fut
143
244
                    .as_mut()
144
244
                    .as_pin_mut()
145
244
                    .expect("but we just set it to `Some`?!")
146
244
                    .poll(cx)
147
            }
148
        }
149
284
    }
150
}
151

            
152
impl<W, P> AsyncWrite for RateLimitedWriter<W, P>
153
where
154
    W: AsyncWrite,
155
    P: SleepProvider,
156
{
157
6060
    fn poll_write(
158
6060
        mut self: Pin<&mut Self>,
159
6060
        cx: &mut Context<'_>,
160
6060
        mut buf: &[u8],
161
6060
    ) -> Poll<Result<usize, Error>> {
162
6060
        let mut self_ = self.as_mut().project();
163

            
164
        // this should be optimized to a no-op on at least x86-64
165
288316
        fn to_u64(x: usize) -> u64 {
166
288316
            x.try_into().expect("failed usize to u64 conversion")
167
288316
        }
168

            
169
        // for an empty buffer, just defer to the inner writer's impl
170
6060
        if buf.is_empty() {
171
            return self_.inner.poll_write(cx, buf);
172
6060
        }
173

            
174
6060
        let now = self_.sleep_provider.now();
175

            
176
        // refill the bucket and attempt to claim all of the bytes
177
6060
        self_.bucket.refill(now);
178
6060
        let claim = self_.bucket.claim(to_u64(buf.len()));
179

            
180
6060
        let mut claim = match claim {
181
            // claim was successful
182
5540
            Ok(x) => x,
183
            // not enough tokens, so let's use a smaller buffer
184
520
            Err(e) => {
185
520
                let available = e.available_tokens();
186

            
187
                // need to drop the old claim so that we can access the token bucket again
188
520
                drop(claim);
189

            
190
                // if no tokens in bucket, we must sleep
191
520
                if available == 0 {
192
                    // number of tokens we'll wait for
193
284
                    let wake_at_tokens = to_u64(buf.len());
194

            
195
                    // If the user wants to write X tokens, we don't necessarily want to sleep until
196
                    // we have room for X tokens. We also don't want to wake every time that a
197
                    // single byte can be written. We allow the user to configure this threshold
198
                    // with `RateLimitedWriterConfig::wake_when_bytes_available`.
199
284
                    let wake_at_tokens =
200
284
                        std::cmp::min(wake_at_tokens, self_.wake_when_bytes_available.get());
201

            
202
                    // max number of tokens the bucket can hold
203
284
                    let bucket_max = self_.bucket.max();
204

            
205
                    // how long to sleep for; `None` indicates to sleep forever
206
284
                    let sleep_for = if bucket_max == 0 {
207
                        // bucket can't hold any tokens, so sleep forever
208
32
                        None
209
                    } else {
210
                        // if the bucket has a max of X tokens, we should never try to wait for >X
211
                        // tokens
212
252
                        let wake_at_tokens = std::cmp::min(wake_at_tokens, bucket_max);
213

            
214
                        // if we asked for 0 tokens, we'd get a time of ~now, which is not what we
215
                        // want
216
252
                        debug_assert!(wake_at_tokens > 0);
217

            
218
252
                        let wake_at = self_.bucket.tokens_available_at(wake_at_tokens);
219
252
                        let sleep_for = wake_at.map(|x| x.saturating_duration_since(now));
220

            
221
8
                        match sleep_for {
222
244
                            Ok(x) => Some(x),
223
                            Err(NeverEnoughTokensError::ExceedsMaxTokens) => {
224
                                panic!(
225
                                    "exceeds max tokens, but we took the max into account above"
226
                                );
227
                            }
228
                            // we aren't refilling, so sleep forever
229
8
                            Err(NeverEnoughTokensError::ZeroRate) => None,
230
                            // too far in the future to be represented, so sleep forever
231
                            Err(NeverEnoughTokensError::InstantNotRepresentable) => None,
232
                        }
233
                    };
234

            
235
                    // configure the sleep future and poll it to register
236
284
                    let poll = Self::register_sleep(
237
284
                        &mut self_.sleep_fut,
238
284
                        self_.sleep_provider,
239
284
                        cx,
240
284
                        sleep_for,
241
                    );
242
284
                    return match poll {
243
                        // wait for the sleep to finish
244
284
                        Poll::Pending => Poll::Pending,
245
                        // The sleep is already ready?! A recursive call here isn't great, but
246
                        // there's not much else we can do here. Hopefully this second `poll_write`
247
                        // will succeed since we should now have enough tokens.
248
                        Poll::Ready(()) => self.poll_write(cx, buf),
249
                    };
250
236
                }
251

            
252
                /// Convert a `u64` to `usize`, saturating if size of `usize` is smaller than `u64`.
253
                // This is a separate function to ensure we don't accidentally try to convert a
254
                // signed integer into a `usize`, in which case `unwrap_or(MAX)` wouldn't make
255
                // sense.
256
236
                fn to_usize_saturating(x: u64) -> usize {
257
236
                    x.try_into().unwrap_or(usize::MAX)
258
236
                }
259

            
260
                // There are tokens, so try to write as many as are available.
261
236
                let available_usize = to_usize_saturating(available);
262
236
                buf = &buf[0..available_usize];
263
236
                self_.bucket.claim(to_u64(buf.len())).unwrap_or_else(|_| {
264
                    panic!(
265
                        "bucket has {available} tokens available, but can't claim {}?",
266
                        buf.len(),
267
                    )
268
                })
269
            }
270
        };
271

            
272
5776
        let rv = self_.inner.poll_write(cx, buf);
273

            
274
5736
        match rv {
275
            // no bytes were written, so discard the claim
276
40
            Poll::Pending | Poll::Ready(Err(_)) => claim.discard(),
277
            // `x` bytes were written, so only commit those tokens
278
5736
            Poll::Ready(Ok(x)) => {
279
5736
                if x <= buf.len() {
280
5736
                    claim
281
5736
                        .reduce(to_u64(x))
282
5736
                        .expect("can't commit fewer tokens?!");
283
5736
                    claim.commit();
284
5736
                } else {
285
                    cfg_if::cfg_if! {
286
                        if #[cfg(debug_assertions)] {
287
                            panic!(
288
                                "Writer is claiming it wrote more bytes {x} than we gave it {}",
289
                                buf.len(),
290
                            );
291
                        } else {
292
                            // the best we can do is to just claim the original amount
293
                            claim.commit();
294
                        }
295
                    }
296
                }
297
            }
298
        };
299

            
300
5776
        rv
301
6060
    }
302

            
303
32
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
304
32
        self.project().inner.poll_flush(cx)
305
32
    }
306

            
307
16
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
308
        // some implementers of `AsyncWrite` (like `Vec`) don't do anything other than flush when
309
        // closed and will continue to accept bytes even after being closed, so we must continue to
310
        // apply rate limiting even after being closed
311
16
        self.project().inner.poll_close(cx)
312
16
    }
313
}
314

            
315
/// Groups the tokio trait implementations in one module, so that `cfg()` conditionals don't
/// have to be put everywhere.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use tokio::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    use std::io::Result as IoResult;

    // Each method wraps `self` in the `tokio_util` compat adapter and calls the tokio method of
    // the same name on that adapter. The calls are written with the explicit `TokioAsyncWrite::`
    // path because the futures `AsyncWrite` trait is also in scope in this module.
    impl<W, P> TokioAsyncWrite for RateLimitedWriter<W, P>
    where
        W: AsyncWrite,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            let mut compat = self.compat_write();
            TokioAsyncWrite::poll_write(Pin::new(&mut compat), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut compat = self.compat_write();
            TokioAsyncWrite::poll_flush(Pin::new(&mut compat), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut compat = self.compat_write();
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut compat), cx)
        }
    }
}
348

            
349
/// The refill rate, burst, and wake threshold of a [`RateLimitedWriter`].
#[derive(Clone, Debug)]
// NOTE(review): presumably allowed on purpose so that callers can build this with a plain struct
// literal; confirm before adding `#[non_exhaustive]`.
#[allow(clippy::exhaustive_structs)]
pub struct RateLimitedWriterConfig {
    /// How quickly the writer refills, in bytes/second.
    pub rate: u64,
    /// The "burst", in bytes.
    pub burst: u64,
    /// When polled, block until at most this many bytes are available.
    ///
    /// Put differently: wake as soon as this many bytes can be written, even if the buffer that
    /// was provided is larger.
    ///
    /// For example, when a user attempts to write a large buffer we usually don't want to block
    /// until the whole buffer can be written; several partial writes are preferable to a single
    /// large write. So rather than blocking until the entire buffer fits, we only block until at
    /// most this many bytes are available.
    pub wake_when_bytes_available: NonZero<u64>,
}
368

            
369
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::{AsyncWriteExt, FutureExt};
    use tor_rtcompat::SpawnExt;

    /// Test that a write blocks until the wake threshold is reached, and that a later write is
    /// limited to the tokens that accumulated in the meantime.
    #[test]
    fn writer() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let start = rt.now();

            // refills at 10 tokens/second (one token every 100 ms)
            let mut tb = TokenBucket::new(
                &TokenBucketConfig {
                    rate: 10,
                    bucket_max: 100,
                },
                start,
            );
            // begin with an empty bucket
            tb.drain(100).unwrap();

            let mut sink = Vec::new();
            let mut writer = RateLimitedWriter::from_token_bucket(
                &mut sink,
                tb,
                NonZero::new(15).unwrap(),
                rt.clone(),
            );

            // in the background, move time from 0 to 20_000 ms in steps of 50 ms
            let clock = rt.clone();
            rt.spawn(async move {
                for _ in 0..400 {
                    clock.progress_until_stalled().await;
                    clock.advance_by(Duration::from_millis(50)).await;
                }
            })
            .unwrap();

            // a 60 byte write sleeps until at least 15 of the bytes can be written
            let first = writer.write(&[0; 60]).await.unwrap();
            assert_eq!(15, first);
            assert_eq!(1500, rt.now().duration_since(start).as_millis());

            // let 2 seconds pass
            rt.sleep(Duration::from_millis(2000)).await;

            // the next write must complete immediately, and must write
            // 2000 ms / (100 ms/token) = 20 bytes
            let second = writer.write(&[0; 60]).now_or_never().map(Result::unwrap);
            assert_eq!(Some(20), second);
        });
    }

    /// Test that writing to a token bucket which has a rate and/or max of 0 works as expected.
    #[test]
    fn rate_burst_zero() {
        // (rate, bucket_max) pairs:
        // non-zero rate with zero max, zero rate with non-zero max, and both zero
        let configs = [(10, 0), (0, 10), (0, 0)]
            .map(|(rate, bucket_max)| TokenBucketConfig { rate, bucket_max });

        for config in configs {
            tor_rtmock::MockRuntime::test_with_various(|rt| {
                let config = config.clone();
                async move {
                    // begin with an empty token bucket
                    let mut tb = TokenBucket::new(&config, rt.now());
                    tb.drain(tb.max()).unwrap();
                    assert!(tb.is_empty());

                    let mut sink = Vec::new();
                    let mut writer = RateLimitedWriter::from_token_bucket(
                        &mut sink,
                        tb,
                        NonZero::new(2).unwrap(),
                        rt.clone(),
                    );

                    // in the background, move time from 0 to 10_000 ms in steps of 100 ms
                    let clock = rt.clone();
                    rt.spawn(async move {
                        for _ in 0..100 {
                            clock.progress_until_stalled().await;
                            clock.advance_by(Duration::from_millis(100)).await;
                        }
                    })
                    .unwrap();

                    // a write must return `Pending`
                    let first = writer.write(&[0; 60]).now_or_never().map(Result::unwrap);
                    assert_eq!(None, first);

                    // let 5 seconds pass
                    rt.sleep(Duration::from_millis(5000)).await;

                    // a write must still return `Pending`
                    let second = writer.write(&[0; 60]).now_or_never().map(Result::unwrap);
                    assert_eq!(None, second);
                }
            });
        }
    }
}