1
//! An [`AsyncWrite`] rate limiter.
2

            
3
use std::future::Future;
4
use std::num::NonZero;
5
use std::pin::Pin;
6
use std::task::{Context, Poll};
7
use std::time::{Duration, Instant};
8

            
9
use futures::AsyncWrite;
10
use futures::io::Error;
11
use sync_wrapper::SyncFuture;
12
use tor_rtcompat::SleepProvider;
13

            
14
use super::bucket::{NeverEnoughTokensError, TokenBucket, TokenBucketConfig};
15

            
16
/// A rate-limited async [writer](AsyncWrite).
17
///
18
/// This can be used as a wrapper around an existing [`AsyncWrite`] writer.
19
#[derive(educe::Educe)]
20
#[educe(Debug)]
21
#[pin_project::pin_project]
22
pub(crate) struct RateLimitedWriter<W: AsyncWrite, P: SleepProvider> {
23
    /// The token bucket.
24
    bucket: TokenBucket<Instant>,
25
    /// The sleep provider, for getting the current time and creating new sleep futures.
26
    ///
27
    /// While we use [`Instant`] for the time, we should always get the time from this
28
    /// [`SleepProvider`].
29
    /// For example, use [`SleepProvider::now()`], not [`Instant::now()`].
30
    #[educe(Debug(ignore))]
31
    sleep_provider: P,
32
    /// See [`RateLimitedWriterConfig::wake_when_bytes_available`].
33
    wake_when_bytes_available: NonZero<u64>,
34
    /// The inner writer.
35
    #[educe(Debug(ignore))]
36
    #[pin]
37
    inner: W,
38
    /// We need to store the sleep future if [`AsyncWrite::poll_write()`] blocks.
39
    #[educe(Debug(ignore))]
40
    #[pin]
41
    sleep_fut: Option<SyncFuture<P::SleepFuture>>,
42
}
43

            
44
impl<W, P> RateLimitedWriter<W, P>
45
where
46
    W: AsyncWrite,
47
    P: SleepProvider,
48
{
49
    /// Create a new [`RateLimitedWriter`].
50
    // We take the rate and bucket max directly rather than a `TokenBucket` to ensure that the token
51
    // bucket only ever uses times from `sleep_provider`.
52
124
    pub(crate) fn new(writer: W, config: &RateLimitedWriterConfig, sleep_provider: P) -> Self {
53
124
        let bucket_config = TokenBucketConfig {
54
124
            rate: config.rate,
55
124
            bucket_max: config.burst,
56
124
        };
57
124
        Self::from_token_bucket(
58
124
            writer,
59
124
            TokenBucket::new(&bucket_config, sleep_provider.now()),
60
124
            config.wake_when_bytes_available,
61
124
            sleep_provider,
62
        )
63
124
    }
64

            
65
    /// Create a new [`RateLimitedWriter`] from a [`TokenBucket`].
66
    ///
67
    /// The token bucket must have only been used with times created by `sleep_provider`.
68
140
    #[cfg_attr(test, visibility::make(pub(super)))]
69
140
    fn from_token_bucket(
70
140
        writer: W,
71
140
        bucket: TokenBucket<Instant>,
72
140
        wake_when_bytes_available: NonZero<u64>,
73
140
        sleep_provider: P,
74
140
    ) -> Self {
75
140
        Self {
76
140
            bucket,
77
140
            sleep_provider,
78
140
            wake_when_bytes_available,
79
140
            inner: writer,
80
140
            sleep_fut: None,
81
140
        }
82
140
    }
83

            
84
    /// Access the inner [`AsyncWrite`] writer.
85
    pub(crate) fn inner(&self) -> &W {
86
        &self.inner
87
    }
88

            
89
    /// Adjust the refill rate and burst.
90
    ///
91
    /// A rate and/or burst of 0 is allowed.
92
84
    pub(crate) fn adjust(
93
84
        self: &mut Pin<&mut Self>,
94
84
        now: Instant,
95
84
        config: &RateLimitedWriterConfig,
96
84
    ) {
97
84
        let self_ = self.as_mut().project();
98

            
99
        // destructuring allows us to make sure we aren't forgetting to handle any fields
100
        let RateLimitedWriterConfig {
101
84
            rate,
102
84
            burst,
103
84
            wake_when_bytes_available,
104
84
        } = *config;
105

            
106
84
        let bucket_config = TokenBucketConfig {
107
84
            rate,
108
84
            bucket_max: burst,
109
84
        };
110

            
111
84
        self_.bucket.adjust(now, &bucket_config);
112
84
        *self_.wake_when_bytes_available = wake_when_bytes_available;
113
84
    }
114

            
115
    /// The sleep provider.
116
    ///
117
    /// We don't want this to be generally accessible, only to other token bucket-related modules
118
    /// like [`DynamicRateLimitedWriter`](super::dynamic_writer::DynamicRateLimitedWriter).
119
84
    pub(super) fn sleep_provider(&self) -> &P {
120
84
        &self.sleep_provider
121
84
    }
122

            
123
    /// Configure this writer to sleep for `duration`.
124
    ///
125
    /// A `duration` of `None` is interpreted as "forever".
126
    ///
127
    /// It's considered a bug if asked to sleep for `Duration::ZERO` time.
128
284
    fn register_sleep(
129
284
        sleep_fut: &mut Pin<&mut Option<SyncFuture<P::SleepFuture>>>,
130
284
        sleep_provider: &mut P,
131
284
        cx: &mut Context<'_>,
132
284
        duration: Option<Duration>,
133
284
    ) -> Poll<()> {
134
284
        match duration {
135
            None => {
136
40
                sleep_fut.as_mut().set(None);
137
40
                Poll::Pending
138
            }
139
244
            Some(duration) => {
140
244
                debug_assert_ne!(duration, Duration::ZERO, "asked to sleep for 0 time");
141
244
                sleep_fut
142
244
                    .as_mut()
143
244
                    .set(Some(SyncFuture::new(sleep_provider.sleep(duration))));
144
244
                sleep_fut
145
244
                    .as_mut()
146
244
                    .as_pin_mut()
147
244
                    .expect("but we just set it to `Some`?!")
148
244
                    .poll(cx)
149
            }
150
        }
151
284
    }
152
}
153

            
154
impl<W, P> AsyncWrite for RateLimitedWriter<W, P>
155
where
156
    W: AsyncWrite,
157
    P: SleepProvider,
158
{
159
6060
    fn poll_write(
160
6060
        mut self: Pin<&mut Self>,
161
6060
        cx: &mut Context<'_>,
162
6060
        mut buf: &[u8],
163
6060
    ) -> Poll<Result<usize, Error>> {
164
6060
        let mut self_ = self.as_mut().project();
165

            
166
        // this should be optimized to a no-op on at least x86-64
167
12316
        fn to_u64(x: usize) -> u64 {
168
12316
            x.try_into().expect("failed usize to u64 conversion")
169
12316
        }
170

            
171
        // for an empty buffer, just defer to the inner writer's impl
172
6060
        if buf.is_empty() {
173
            return self_.inner.poll_write(cx, buf);
174
6060
        }
175

            
176
6060
        let now = self_.sleep_provider.now();
177

            
178
        // refill the bucket and attempt to claim all of the bytes
179
6060
        self_.bucket.refill(now);
180
6060
        let claim = self_.bucket.claim(to_u64(buf.len()));
181

            
182
6060
        let mut claim = match claim {
183
            // claim was successful
184
5540
            Ok(x) => x,
185
            // not enough tokens, so let's use a smaller buffer
186
520
            Err(e) => {
187
520
                let available = e.available_tokens();
188

            
189
                // need to drop the old claim so that we can access the token bucket again
190
520
                drop(claim);
191

            
192
                // if no tokens in bucket, we must sleep
193
520
                if available == 0 {
194
                    // number of tokens we'll wait for
195
284
                    let wake_at_tokens = to_u64(buf.len());
196

            
197
                    // If the user wants to write X tokens, we don't necessarily want to sleep until
198
                    // we have room for X tokens. We also don't want to wake every time that a
199
                    // single byte can be written. We allow the user to configure this threshold
200
                    // with `RateLimitedWriterConfig::wake_when_bytes_available`.
201
284
                    let wake_at_tokens =
202
284
                        std::cmp::min(wake_at_tokens, self_.wake_when_bytes_available.get());
203

            
204
                    // max number of tokens the bucket can hold
205
284
                    let bucket_max = self_.bucket.max();
206

            
207
                    // how long to sleep for; `None` indicates to sleep forever
208
284
                    let sleep_for = if bucket_max == 0 {
209
                        // bucket can't hold any tokens, so sleep forever
210
32
                        None
211
                    } else {
212
                        // if the bucket has a max of X tokens, we should never try to wait for >X
213
                        // tokens
214
252
                        let wake_at_tokens = std::cmp::min(wake_at_tokens, bucket_max);
215

            
216
                        // if we asked for 0 tokens, we'd get a time of ~now, which is not what we
217
                        // want
218
252
                        debug_assert!(wake_at_tokens > 0);
219

            
220
252
                        let wake_at = self_.bucket.tokens_available_at(wake_at_tokens);
221
252
                        let sleep_for = wake_at.map(|x| x.saturating_duration_since(now));
222

            
223
8
                        match sleep_for {
224
244
                            Ok(x) => Some(x),
225
                            Err(NeverEnoughTokensError::ExceedsMaxTokens) => {
226
                                panic!(
227
                                    "exceeds max tokens, but we took the max into account above"
228
                                );
229
                            }
230
                            // we aren't refilling, so sleep forever
231
8
                            Err(NeverEnoughTokensError::ZeroRate) => None,
232
                            // too far in the future to be represented, so sleep forever
233
                            Err(NeverEnoughTokensError::InstantNotRepresentable) => None,
234
                        }
235
                    };
236

            
237
                    // configure the sleep future and poll it to register
238
284
                    let poll = Self::register_sleep(
239
284
                        &mut self_.sleep_fut,
240
284
                        self_.sleep_provider,
241
284
                        cx,
242
284
                        sleep_for,
243
                    );
244
284
                    return match poll {
245
                        // wait for the sleep to finish
246
284
                        Poll::Pending => Poll::Pending,
247
                        // The sleep is already ready?! A recursive call here isn't great, but
248
                        // there's not much else we can do here. Hopefully this second `poll_write`
249
                        // will succeed since we should now have enough tokens.
250
                        Poll::Ready(()) => self.poll_write(cx, buf),
251
                    };
252
236
                }
253

            
254
                /// Convert a `u64` to `usize`, saturating if size of `usize` is smaller than `u64`.
255
                // This is a separate function to ensure we don't accidentally try to convert a
256
                // signed integer into a `usize`, in which case `unwrap_or(MAX)` wouldn't make
257
                // sense.
258
236
                fn to_usize_saturating(x: u64) -> usize {
259
236
                    x.try_into().unwrap_or(usize::MAX)
260
236
                }
261

            
262
                // There are tokens, so try to write as many as are available.
263
236
                let available_usize = to_usize_saturating(available);
264
236
                buf = &buf[0..available_usize];
265
236
                self_.bucket.claim(to_u64(buf.len())).unwrap_or_else(|_| {
266
                    panic!(
267
                        "bucket has {available} tokens available, but can't claim {}?",
268
                        buf.len(),
269
                    )
270
                })
271
            }
272
        };
273

            
274
5776
        let rv = self_.inner.poll_write(cx, buf);
275

            
276
5736
        match rv {
277
            // no bytes were written, so discard the claim
278
40
            Poll::Pending | Poll::Ready(Err(_)) => claim.discard(),
279
            // `x` bytes were written, so only commit those tokens
280
5736
            Poll::Ready(Ok(x)) => {
281
5736
                if x <= buf.len() {
282
5736
                    claim
283
5736
                        .reduce(to_u64(x))
284
5736
                        .expect("can't commit fewer tokens?!");
285
5736
                    claim.commit();
286
5736
                } else {
287
                    cfg_if::cfg_if! {
288
                        if #[cfg(debug_assertions)] {
289
                            panic!(
290
                                "Writer is claiming it wrote more bytes {x} than we gave it {}",
291
                                buf.len(),
292
                            );
293
                        } else {
294
                            // the best we can do is to just claim the original amount
295
                            claim.commit();
296
                        }
297
                    }
298
                }
299
            }
300
        };
301

            
302
5776
        rv
303
6060
    }
304

            
305
32
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
306
32
        self.project().inner.poll_flush(cx)
307
32
    }
308

            
309
16
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
310
        // some implementers of `AsyncWrite` (like `Vec`) don't do anything other than flush when
311
        // closed and will continue to accept bytes even after being closed, so we must continue to
312
        // apply rate limiting even after being closed
313
16
        self.project().inner.poll_close(cx)
314
16
    }
315
}
316

            
317
/// Tokio trait implementations, grouped in one module so the `tokio` feature gate only has to
/// appear once rather than as `cfg()` conditionals everywhere.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    use std::io::Result as IoResult;

    // Every method wraps the pinned writer in a short-lived `Compat` adapter and forwards the
    // call to the adapter's tokio implementation, which in turn drives our futures-based
    // `AsyncWrite` implementation (including its rate limiting).
    impl<W, P> TokioAsyncWrite for RateLimitedWriter<W, P>
    where
        W: AsyncWrite,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            let mut compat = self.compat_write();
            TokioAsyncWrite::poll_write(Pin::new(&mut compat), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut compat = self.compat_write();
            TokioAsyncWrite::poll_flush(Pin::new(&mut compat), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut compat = self.compat_write();
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut compat), cx)
        }
    }
}
350

            
351
/// Configuration for a [`RateLimitedWriter`].
///
/// It holds the writer's refill rate and burst size, plus the threshold of writable bytes at
/// which a blocked write is woken.
///
/// A rate and/or burst of 0 is allowed. With either one at 0, a write that finds the bucket
/// empty never gets more tokens, so it stays pending.
#[derive(Clone, Debug)]
pub(crate) struct RateLimitedWriterConfig {
    /// The refill rate in bytes/second (one token per byte).
    pub(crate) rate: u64,
    /// The "burst" in bytes: the maximum number of tokens the bucket can hold.
    pub(crate) burst: u64,
    /// How many writable bytes a blocked write waits for before it is woken.
    ///
    /// When a write finds no tokens available, it sleeps until
    /// `min(buf.len(), wake_when_bytes_available, burst)` bytes can be written.
    ///
    /// For example, if a user attempts to write a large buffer, we usually don't want to block
    /// until the entire buffer can be written. We'd prefer several partial writes to a single
    /// large write. We also don't want to wake every time a single byte becomes writable. This
    /// value sets the threshold between those two extremes.
    pub(crate) wake_when_bytes_available: NonZero<u64>,
}
369

            
370
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::{AsyncWriteExt, FutureExt};
    use tor_rtcompat::SpawnExt;

    /// Check two things about a writer whose bucket starts empty:
    ///
    /// - It sleeps until `wake_when_bytes_available` bytes can be written.
    /// - Tokens accumulated while idle allow an immediate partial write.
    #[test]
    fn writer() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let t0 = rt.now();

            // 10 tokens per second (one every 100 ms), holding at most 100
            let bucket_config = TokenBucketConfig {
                rate: 10,
                bucket_max: 100,
            };
            let mut bucket = TokenBucket::new(&bucket_config, t0);
            // start from an empty bucket
            bucket.drain(100).unwrap();

            let mut sink = Vec::new();
            let mut limited = RateLimitedWriter::from_token_bucket(
                &mut sink,
                bucket,
                NonZero::new(15).unwrap(),
                rt.clone(),
            );

            // in the background, advance the mock clock 50 ms at a time, 400 times (20_000 ms)
            let clock = rt.clone();
            rt.spawn(async move {
                for _ in 0..400 {
                    clock.progress_until_stalled().await;
                    clock.advance_by(Duration::from_millis(50)).await;
                }
            })
            .unwrap();

            // a 60-byte write must wait until 15 tokens have accumulated, i.e. 1500 ms
            let written = limited.write(&[0; 60]).await.unwrap();
            assert_eq!(15, written);
            assert_eq!(1500, rt.now().duration_since(t0).as_millis());

            // let 2000 ms worth of tokens accumulate
            rt.sleep(Duration::from_millis(2000)).await;

            // 2000 ms / (100 ms/token) = 20 bytes, writable without waiting
            let immediate = limited.write(&[0; 60]).now_or_never().map(Result::unwrap);
            assert_eq!(Some(20), immediate);
        });
    }

    /// Check that a writer whose bucket has a zero rate and/or a zero max never completes a
    /// write, even as time passes.
    #[test]
    fn rate_burst_zero() {
        let configs = [
            // tokens arrive, but the bucket can't hold any
            TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
            // the bucket could hold tokens, but none ever arrive
            TokenBucketConfig {
                rate: 0,
                bucket_max: 10,
            },
            // neither
            TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        ];
        for config in configs {
            tor_rtmock::MockRuntime::test_with_various(|rt| {
                let config = config.clone();
                async move {
                    // start from an empty bucket
                    let mut bucket = TokenBucket::new(&config, rt.now());
                    bucket.drain(bucket.max()).unwrap();
                    assert!(bucket.is_empty());

                    let mut sink = Vec::new();
                    let mut limited = RateLimitedWriter::from_token_bucket(
                        &mut sink,
                        bucket,
                        NonZero::new(2).unwrap(),
                        rt.clone(),
                    );

                    // advance the mock clock 100 ms at a time, 100 times (10_000 ms)
                    let clock = rt.clone();
                    rt.spawn(async move {
                        for _ in 0..100 {
                            clock.progress_until_stalled().await;
                            clock.advance_by(Duration::from_millis(100)).await;
                        }
                    })
                    .unwrap();

                    // no tokens are available, so the write stays pending
                    let first = limited.write(&[0; 60]).now_or_never().map(Result::unwrap);
                    assert_eq!(None, first);

                    rt.sleep(Duration::from_millis(5000)).await;

                    // ...and five seconds later it's still pending
                    let second = limited.write(&[0; 60]).now_or_never().map(Result::unwrap);
                    assert_eq!(None, second);
                }
            });
        }
    }
}