1
//! An [`AsyncWrite`] rate limiter.
2

            
3
use std::future::Future;
4
use std::num::NonZero;
5
use std::pin::Pin;
6
use std::task::{Context, Poll};
7
use web_time_compat::{Duration, Instant};
8

            
9
use futures::AsyncWrite;
10
use futures::io::Error;
11
use sync_wrapper::SyncFuture;
12
use tor_rtcompat::SleepProvider;
13

            
14
use super::bucket::{NeverEnoughTokensError, TokenBucket, TokenBucketConfig};
15

            
16
/// A rate-limited async [writer](AsyncWrite).
///
/// This can be used as a wrapper around an existing [`AsyncWrite`] writer.
///
/// Rate limiting is implemented with a [`TokenBucket`]: each byte written to the
/// inner writer consumes one token, and `poll_write` returns `Pending` (after
/// registering a sleep) when the bucket is empty.
#[derive(educe::Educe)]
#[educe(Debug)]
#[pin_project::pin_project]
pub(crate) struct RateLimitedWriter<W: AsyncWrite, P: SleepProvider> {
    /// The token bucket.
    ///
    /// One token corresponds to one byte written to `inner`.
    bucket: TokenBucket<Instant>,
    /// The sleep provider, for getting the current time and creating new sleep futures.
    ///
    /// While we use [`Instant`] for the time, we should always get the time from this
    /// [`SleepProvider`].
    /// For example, use [`SleepProvider::now()`],
    /// not [`Instant::now()`](std::time::Instant::now) or
    /// [`InstantExt::get`](web_time_compat::InstantExt::get).
    #[educe(Debug(ignore))]
    sleep_provider: P,
    /// See [`RateLimitedWriterConfig::wake_when_bytes_available`].
    wake_when_bytes_available: NonZero<u64>,
    /// The inner writer.
    #[educe(Debug(ignore))]
    #[pin]
    inner: W,
    /// We need to store the sleep future if [`AsyncWrite::poll_write()`] blocks.
    ///
    /// This is `None` initially; it is also set back to `None` when a write must
    /// block "forever" (for example when the bucket can never hold enough tokens).
    #[educe(Debug(ignore))]
    #[pin]
    sleep_fut: Option<SyncFuture<P::SleepFuture>>,
}
45

            
46
impl<W, P> RateLimitedWriter<W, P>
47
where
48
    W: AsyncWrite,
49
    P: SleepProvider,
50
{
51
    /// Create a new [`RateLimitedWriter`].
52
    // We take the rate and bucket max directly rather than a `TokenBucket` to ensure that the token
53
    // bucket only ever uses times from `sleep_provider`.
54
124
    pub(crate) fn new(writer: W, config: &RateLimitedWriterConfig, sleep_provider: P) -> Self {
55
124
        let bucket_config = TokenBucketConfig {
56
124
            rate: config.rate,
57
124
            bucket_max: config.burst,
58
124
        };
59
124
        Self::from_token_bucket(
60
124
            writer,
61
124
            TokenBucket::new(&bucket_config, sleep_provider.now()),
62
124
            config.wake_when_bytes_available,
63
124
            sleep_provider,
64
        )
65
124
    }
66

            
67
    /// Create a new [`RateLimitedWriter`] from a [`TokenBucket`].
68
    ///
69
    /// The token bucket must have only been used with times created by `sleep_provider`.
70
140
    #[cfg_attr(test, visibility::make(pub(super)))]
71
140
    fn from_token_bucket(
72
140
        writer: W,
73
140
        bucket: TokenBucket<Instant>,
74
140
        wake_when_bytes_available: NonZero<u64>,
75
140
        sleep_provider: P,
76
140
    ) -> Self {
77
140
        Self {
78
140
            bucket,
79
140
            sleep_provider,
80
140
            wake_when_bytes_available,
81
140
            inner: writer,
82
140
            sleep_fut: None,
83
140
        }
84
140
    }
85

            
86
    /// Access the inner [`AsyncWrite`] writer.
87
    pub(crate) fn inner(&self) -> &W {
88
        &self.inner
89
    }
90

            
91
    /// Adjust the refill rate and burst.
92
    ///
93
    /// A rate and/or burst of 0 is allowed.
94
84
    pub(crate) fn adjust(
95
84
        self: &mut Pin<&mut Self>,
96
84
        now: Instant,
97
84
        config: &RateLimitedWriterConfig,
98
84
    ) {
99
84
        let self_ = self.as_mut().project();
100

            
101
        // destructuring allows us to make sure we aren't forgetting to handle any fields
102
        let RateLimitedWriterConfig {
103
84
            rate,
104
84
            burst,
105
84
            wake_when_bytes_available,
106
84
        } = *config;
107

            
108
84
        let bucket_config = TokenBucketConfig {
109
84
            rate,
110
84
            bucket_max: burst,
111
84
        };
112

            
113
84
        self_.bucket.adjust(now, &bucket_config);
114
84
        *self_.wake_when_bytes_available = wake_when_bytes_available;
115
84
    }
116

            
117
    /// The sleep provider.
118
    ///
119
    /// We don't want this to be generally accessible, only to other token bucket-related modules
120
    /// like [`DynamicRateLimitedWriter`](super::dynamic_writer::DynamicRateLimitedWriter).
121
84
    pub(super) fn sleep_provider(&self) -> &P {
122
84
        &self.sleep_provider
123
84
    }
124

            
125
    /// Configure this writer to sleep for `duration`.
126
    ///
127
    /// A `duration` of `None` is interpreted as "forever".
128
    ///
129
    /// It's considered a bug if asked to sleep for `Duration::ZERO` time.
130
284
    fn register_sleep(
131
284
        sleep_fut: &mut Pin<&mut Option<SyncFuture<P::SleepFuture>>>,
132
284
        sleep_provider: &mut P,
133
284
        cx: &mut Context<'_>,
134
284
        duration: Option<Duration>,
135
284
    ) -> Poll<()> {
136
284
        match duration {
137
            None => {
138
40
                sleep_fut.as_mut().set(None);
139
40
                Poll::Pending
140
            }
141
244
            Some(duration) => {
142
244
                debug_assert_ne!(duration, Duration::ZERO, "asked to sleep for 0 time");
143
244
                sleep_fut
144
244
                    .as_mut()
145
244
                    .set(Some(SyncFuture::new(sleep_provider.sleep(duration))));
146
244
                sleep_fut
147
244
                    .as_mut()
148
244
                    .as_pin_mut()
149
244
                    .expect("but we just set it to `Some`?!")
150
244
                    .poll(cx)
151
            }
152
        }
153
284
    }
154
}
155

            
156
impl<W, P> AsyncWrite for RateLimitedWriter<W, P>
where
    W: AsyncWrite,
    P: SleepProvider,
{
    // Claims up to `buf.len()` tokens from the bucket, writes that many bytes to the
    // inner writer, and commits only the tokens for bytes actually written.
    // When the bucket is empty, registers a sleep and returns `Pending`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
        let mut self_ = self.as_mut().project();

        /// Convert a `usize` to a `u64`; panics if it doesn't fit (a bug on any
        /// platform where `usize` is wider than 64 bits).
        // this should be optimized to a no-op on at least x86-64
        fn to_u64(x: usize) -> u64 {
            x.try_into().expect("failed usize to u64 conversion")
        }

        // for an empty buffer, just defer to the inner writer's impl
        if buf.is_empty() {
            return self_.inner.poll_write(cx, buf);
        }

        let now = self_.sleep_provider.now();

        // refill the bucket and attempt to claim all of the bytes
        self_.bucket.refill(now);
        let claim = self_.bucket.claim(to_u64(buf.len()));

        let mut claim = match claim {
            // claim was successful
            Ok(x) => x,
            // not enough tokens, so let's use a smaller buffer
            Err(e) => {
                let available = e.available_tokens();

                // need to drop the old claim so that we can access the token bucket again
                drop(claim);

                // if no tokens in bucket, we must sleep
                if available == 0 {
                    // number of tokens we'll wait for
                    let wake_at_tokens = to_u64(buf.len());

                    // If the user wants to write X tokens, we don't necessarily want to sleep until
                    // we have room for X tokens. We also don't want to wake every time that a
                    // single byte can be written. We allow the user to configure this threshold
                    // with `RateLimitedWriterConfig::wake_when_bytes_available`.
                    let wake_at_tokens =
                        std::cmp::min(wake_at_tokens, self_.wake_when_bytes_available.get());

                    // max number of tokens the bucket can hold
                    let bucket_max = self_.bucket.max();

                    // how long to sleep for; `None` indicates to sleep forever
                    let sleep_for = if bucket_max == 0 {
                        // bucket can't hold any tokens, so sleep forever
                        None
                    } else {
                        // if the bucket has a max of X tokens, we should never try to wait for >X
                        // tokens
                        let wake_at_tokens = std::cmp::min(wake_at_tokens, bucket_max);

                        // if we asked for 0 tokens, we'd get a time of ~now, which is not what we
                        // want
                        debug_assert!(wake_at_tokens > 0);

                        let wake_at = self_.bucket.tokens_available_at(wake_at_tokens);
                        let sleep_for = wake_at.map(|x| x.saturating_duration_since(now));

                        match sleep_for {
                            Ok(x) => Some(x),
                            Err(NeverEnoughTokensError::ExceedsMaxTokens) => {
                                // unreachable: `wake_at_tokens` was clamped to `bucket_max` above
                                panic!(
                                    "exceeds max tokens, but we took the max into account above"
                                );
                            }
                            // we aren't refilling, so sleep forever
                            Err(NeverEnoughTokensError::ZeroRate) => None,
                            // too far in the future to be represented, so sleep forever
                            Err(NeverEnoughTokensError::InstantNotRepresentable) => None,
                        }
                    };

                    // configure the sleep future and poll it to register
                    let poll = Self::register_sleep(
                        &mut self_.sleep_fut,
                        self_.sleep_provider,
                        cx,
                        sleep_for,
                    );
                    return match poll {
                        // wait for the sleep to finish
                        Poll::Pending => Poll::Pending,
                        // The sleep is already ready?! A recursive call here isn't great, but
                        // there's not much else we can do here. Hopefully this second `poll_write`
                        // will succeed since we should now have enough tokens.
                        Poll::Ready(()) => self.poll_write(cx, buf),
                    };
                }

                /// Convert a `u64` to `usize`, saturating if size of `usize` is smaller than `u64`.
                // This is a separate function to ensure we don't accidentally try to convert a
                // signed integer into a `usize`, in which case `unwrap_or(MAX)` wouldn't make
                // sense.
                fn to_usize_saturating(x: u64) -> usize {
                    x.try_into().unwrap_or(usize::MAX)
                }

                // There are tokens, so try to write as many as are available.
                let available_usize = to_usize_saturating(available);
                buf = &buf[0..available_usize];
                // this smaller claim must succeed: we just saw `available` tokens in the bucket
                self_.bucket.claim(to_u64(buf.len())).unwrap_or_else(|_| {
                    panic!(
                        "bucket has {available} tokens available, but can't claim {}?",
                        buf.len(),
                    )
                })
            }
        };

        let rv = self_.inner.poll_write(cx, buf);

        match rv {
            // no bytes were written, so discard the claim
            Poll::Pending | Poll::Ready(Err(_)) => claim.discard(),
            // `x` bytes were written, so only commit those tokens
            Poll::Ready(Ok(x)) => {
                if x <= buf.len() {
                    claim
                        .reduce(to_u64(x))
                        .expect("can't commit fewer tokens?!");
                    claim.commit();
                } else {
                    // the inner writer reported more bytes written than we handed it
                    cfg_if::cfg_if! {
                        if #[cfg(debug_assertions)] {
                            panic!(
                                "Writer is claiming it wrote more bytes {x} than we gave it {}",
                                buf.len(),
                            );
                        } else {
                            // the best we can do is to just claim the original amount
                            claim.commit();
                        }
                    }
                }
            }
        };

        rv
    }

    // Flushing is not rate limited; defer to the inner writer.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        // some implementers of `AsyncWrite` (like `Vec`) don't do anything other than flush when
        // closed and will continue to accept bytes even after being closed, so we must continue to
        // apply rate limiting even after being closed
        self.project().inner.poll_close(cx)
    }
}
318

            
319
/// A module to make it easier to implement tokio traits without putting `cfg()` conditionals
/// everywhere.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    use std::io::Result as IoResult;

    /// Delegate tokio's `AsyncWrite` to our futures `AsyncWrite` impl via
    /// tokio-util's compat adapter.
    // NOTE(review): each method builds a fresh compat wrapper around `self` on
    // every poll; this assumes `compat_write` is a stateless adapter — verify
    // against the tokio-util documentation.
    impl<W, P> TokioAsyncWrite for RateLimitedWriter<W, P>
    where
        W: AsyncWrite,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            TokioAsyncWrite::poll_write(Pin::new(&mut self.compat_write()), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            TokioAsyncWrite::poll_flush(Pin::new(&mut self.compat_write()), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut self.compat_write()), cx)
        }
    }
}
352

            
353
/// The refill rate and burst for a [`RateLimitedWriter`].
#[derive(Clone, Debug)]
pub(crate) struct RateLimitedWriterConfig {
    /// The refill rate in bytes/second.
    ///
    /// A rate of 0 means the bucket never refills (writes may block forever).
    pub(crate) rate: u64,
    /// The "burst" in bytes.
    ///
    /// This is the maximum number of tokens (bytes) the bucket can hold.
    pub(crate) burst: u64,
    /// When polled, block until at most this many bytes are available.
    ///
    /// Or in other words, wake when we can write this many bytes, even if the provided buffer is
    /// larger.
    ///
    /// For example if a user attempts to write a large buffer, we usually don't want to block until
    /// the entire buffer can be written. We'd prefer several partial writes to a single large
    /// write. So instead of blocking until the entire buffer can be written, we only block until
    /// at most this many bytes are available.
    pub(crate) wake_when_bytes_available: NonZero<u64>,
}
371

            
372
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::{AsyncWriteExt, FutureExt};
    use tor_rtcompat::SpawnExt;

    /// A rate-limited write over an initially-empty bucket blocks until the wake
    /// threshold is reachable, then performs partial writes.
    #[test]
    fn writer() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let start = rt.now();

            // refills at 10 tokens/second (one every 100 ms)
            let bucket_config = TokenBucketConfig {
                rate: 10,
                bucket_max: 100,
            };
            let mut bucket = TokenBucket::new(&bucket_config, start);
            // begin with an empty bucket so every write must wait
            bucket.drain(100).unwrap();

            let wake_threshold = NonZero::new(15).unwrap();

            let mut sink = Vec::new();
            let mut writer = RateLimitedWriter::from_token_bucket(
                &mut sink,
                bucket,
                wake_threshold,
                rt.clone(),
            );

            // drive time forward from 0 to 20_000 ms in 50 ms intervals
            let time_driver = rt.clone();
            rt.spawn(async move {
                for _ in 0..400 {
                    time_driver.progress_until_stalled().await;
                    time_driver.advance_by(Duration::from_millis(50)).await;
                }
            })
            .unwrap();

            // try writing 60 bytes, which sleeps until we can write at least 15 of them
            assert_eq!(15, writer.write(&[0; 60]).await.unwrap());
            assert_eq!(1500, rt.now().duration_since(start).as_millis());

            // wait 2 seconds
            rt.sleep(Duration::from_millis(2000)).await;

            // ensure that we can write immediately, and that we can write
            // 2000 ms / (100 ms/token) = 20 bytes
            assert_eq!(
                Some(20),
                writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
            );
        });
    }

    /// Test that writing to a token bucket which has a rate and/or max of 0 works as expected.
    #[test]
    fn rate_burst_zero() {
        let configs = [
            // non-zero rate, zero max
            TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
            // zero rate, non-zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 10,
            },
            // zero rate, zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        ];
        for config in configs {
            tor_rtmock::MockRuntime::test_with_various(|rt| {
                let config = config.clone();
                async move {
                    // an empty token bucket
                    let mut bucket = TokenBucket::new(&config, rt.now());
                    bucket.drain(bucket.max()).unwrap();
                    assert!(bucket.is_empty());

                    let wake_threshold = NonZero::new(2).unwrap();

                    let mut sink = Vec::new();
                    let mut writer = RateLimitedWriter::from_token_bucket(
                        &mut sink,
                        bucket,
                        wake_threshold,
                        rt.clone(),
                    );

                    // drive time forward from 0 to 10_000 ms in 100 ms intervals
                    let time_driver = rt.clone();
                    rt.spawn(async move {
                        for _ in 0..100 {
                            time_driver.progress_until_stalled().await;
                            time_driver.advance_by(Duration::from_millis(100)).await;
                        }
                    })
                    .unwrap();

                    // ensure that a write returns `Pending`
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );

                    // wait 5 seconds
                    rt.sleep(Duration::from_millis(5000)).await;

                    // ensure that a write still returns `Pending`
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );
                }
            });
        }
    }
}