1
//! An [`AsyncWrite`] rate limiter which receives rate limit changes from a [`FusedStream`].
2

            
3
use std::pin::Pin;
4
use std::task::{Context, Poll};
5

            
6
use futures::AsyncWrite;
7
use futures::io::Error;
8
use futures::stream::FusedStream;
9
use tor_rtcompat::SleepProvider;
10

            
11
use super::writer::{RateLimitedWriter, RateLimitedWriterConfig};
12

            
13
/// A rate-limited async [writer](AsyncWrite).
///
/// This wraps a [`RateLimitedWriter`] and watches a stream for configuration changes (such as rate
/// limit changes).
///
/// The stream `S` is expected to yield [`RateLimitedWriterConfig`] values; updates are drained
/// from it each time the writer is polled for writing (see the [`AsyncWrite`] impl).
#[derive(educe::Educe)]
#[educe(Debug)]
#[pin_project::pin_project]
pub(crate) struct DynamicRateLimitedWriter<W: AsyncWrite, S, P: SleepProvider> {
    /// The rate-limited writer.
    ///
    /// Structurally pinned so `poll_*` calls can be forwarded to it.
    #[pin]
    writer: RateLimitedWriter<W, P>,
    /// A stream that provides configuration updates, including rate limit updates.
    // Ignored by the derived `Debug` impl since `S` is not bounded by `Debug`.
    #[educe(Debug(ignore))]
    #[pin]
    updates: S,
}
29

            
30
impl<W, S, P> DynamicRateLimitedWriter<W, S, P>
31
where
32
    W: AsyncWrite,
33
    P: SleepProvider,
34
{
35
    /// Create a new [`DynamicRateLimitedWriter`].
36
    ///
37
    /// This wraps the `writer` and watches for configuration changes from the `updates` stream.
38
124
    pub(crate) fn new(writer: RateLimitedWriter<W, P>, updates: S) -> Self {
39
124
        Self { writer, updates }
40
124
    }
41

            
42
    /// Access the inner [`AsyncWrite`] writer of the [`RateLimitedWriter`].
43
    pub(crate) fn inner(&self) -> &W {
44
        self.writer.inner()
45
    }
46
}
47

            
48
impl<W, S, P> AsyncWrite for DynamicRateLimitedWriter<W, S, P>
where
    W: AsyncWrite,
    S: FusedStream<Item = RateLimitedWriterConfig>,
    P: SleepProvider,
{
    /// Write `buf` to the rate-limited writer, first applying any pending configuration
    /// updates from the `updates` stream.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, Error>> {
        // Re-borrow (`as_mut`) so `self` isn't consumed by the projection.
        let mut self_ = self.as_mut().project();

        // Try getting any update to the rate limit and burst.
        //
        // We loop until we receive `Ready(None)` or `Pending`. The former indicates that we
        // shouldn't receive any more updates. The latter indicates that there aren't currently more
        // to read, and that we've registered the waker with the stream so that we'll wake when the
        // rate limit is later updated.
        //
        // Since `S` is a `FusedStream`, it's fine to call `poll_next()` even if `Ready(None)` was
        // returned in the past.
        let mut iters = 0;
        while let Poll::Ready(Some(config)) = self_.updates.as_mut().poll_next(cx) {
            // update the writer's configuration, effective as of the provider's current time
            let now = self_.writer.sleep_provider().now();
            self_.writer.adjust(now, &config);

            // It's possible that `DynamicRateLimitedWriter` was constructed with a stream where an
            // infinite number of items will be immediately ready, for example with
            // `futures::stream::repeat()`. We escape the possible infinite loop by returning an
            // error.
            iters += 1;
            if iters > 100_000 {
                const MSG: &str =
                    "possible infinite loop in `DynamicRateLimitedWriter::poll_write`";
                tracing::debug!(MSG);
                return Poll::Ready(Err(Error::other(MSG)));
            }
        }

        // Try writing the bytes. This also registers the waker with the `RateLimitedWriter`.
        self_.writer.poll_write(cx, buf)
    }

    // Flushing and closing delegate directly to the inner writer; configuration updates are
    // only drained in `poll_write` above.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.project().writer.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.project().writer.poll_close(cx)
    }
}
101

            
102
/// A module to make it easier to implement tokio traits without putting `cfg()` conditionals
/// everywhere.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    use std::io::Result as IoResult;

    /// Forward tokio's `AsyncWrite` to the `futures::AsyncWrite` impl above, via a `Compat`
    /// adapter built fresh for each call.
    impl<W, S, P> TokioAsyncWrite for DynamicRateLimitedWriter<W, S, P>
    where
        W: AsyncWrite,
        S: FusedStream<Item = RateLimitedWriterConfig>,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            // `compat_write()` consumes the `Pin<&mut Self>` and yields an adapter
            // implementing tokio's `AsyncWrite` in terms of our `futures` impl.
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_write(Pin::new(&mut adapter), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_flush(Pin::new(&mut adapter), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut adapter), cx)
        }
    }
}
136

            
137
#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use std::num::NonZero;
    use std::time::Duration;

    use futures::{AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};
    use tor_rtcompat::SpawnExt;

    #[cfg(feature = "tokio")]
    use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

    /// This test ensures that a [`DynamicRateLimitedWriter`] writes the expected number of bytes,
    /// as a background task alternates the rate limit between on/off once every second.
    #[cfg(feature = "tokio")]
    #[test]
    fn alternating_on_off() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            // drive time forward from 0 to 8_000 ms in 1 ms intervals
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for _ in 0..8_000 {
                    rt_clone.progress_until_stalled().await;
                    rt_clone.advance_by(Duration::from_millis(1)).await;
                }
            })
            .unwrap();

            // start with a rate limiter that doesn't allow any bytes
            let config = RateLimitedWriterConfig {
                rate: 0,
                burst: 0,
                // wake up the writer each time the rate limiter allows 10 bytes to be sent
                wake_when_bytes_available: NonZero::new(10).unwrap(),
            };

            // there are some other crates which allow you to make a data "pipe" without tokio, but
            // I don't think it's worth bringing in a new dev-dependency for this
            let (writer, reader) = tokio_crate::io::duplex(/* max_buf_size= */ 1000);
            let writer = writer.compat_write();
            let mut reader = reader.compat();

            let writer = RateLimitedWriter::new(writer, &config, rt.clone());

            // how we send rate updates to the rate-limited writer
            let (mut rate_tx, rate_rx) = futures::channel::mpsc::unbounded();

            // our rate-limited writer which can receive rate limit changes
            let mut writer = DynamicRateLimitedWriter::new(writer, rate_rx);

            /// Duration between updates. A prime number is used so that smaller intervals don't
            /// fall on this interval, which can cause issues with `MockRuntime::test_with_various`
            /// since the test becomes dependent on the order that tasks are woken.
            const UPDATE_INTERVAL: Duration = Duration::from_millis(841);

            // a background task which sends alternating on/off rate limits every 841 ms
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for rate in [100, 0, 200, 0, 400, 0] {
                    rt_clone.sleep(UPDATE_INTERVAL).await;

                    // update the rate/burst
                    let mut config = config.clone();
                    config.rate = rate;
                    config.burst = rate;

                    // we expect the send() to succeed immediately
                    rate_tx.send(config).now_or_never().unwrap().unwrap();
                }
            })
            .unwrap();

            // a background task which writes as much as possible
            rt.spawn(async move {
                // write until the receiving end goes away
                while writer.write(&[0; 100]).await.is_ok() {}
            })
            .unwrap();

            // helper to make the `assert_eq` a single line
            let res_unwrap = Result::unwrap;

            let mut buf = vec![0; 1000];
            let buf = &mut buf;

            // sleep for 1 ms so that our upcoming sleeps end 1 ms after the rate limit changes
            rt.sleep(Duration::from_millis(1)).await;

            // Rate is 0, so no bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 100 bytes/s, so 841/(1000/100) = 84 bytes expected.
            // Woken every `wake_when_bytes_available` = 10 bytes, so 80 bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(80), reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 0, so no bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 200 bytes/s, so 841/(1000/200) = 168 bytes expected.
            // Woken every `wake_when_bytes_available` = 10 bytes, so 160 bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(160), reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 0, so no bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 400 bytes/s, so 841/(1000/400) = 336 bytes expected.
            // Woken every `wake_when_bytes_available` = 10 bytes, so 330 bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(Some(330), reader.read(buf).now_or_never().map(res_unwrap));

            // Rate is 0, so no bytes expected.
            rt.sleep(UPDATE_INTERVAL).await;
            assert_eq!(None, reader.read(buf).now_or_never().map(res_unwrap));
        });
    }
}