1
//! A facility for an MPSC channel that counts the number of outstanding entries on the channel.
2
//
3
// (Tokio's channels make this possible out of the box, but we don't want to require tokio.
// Crossbeam channels also allow this, but they aren't async, and they're MPMC. If a future
// version of the `futures` crate adds this functionality, we can use that instead.)
6

            
7
use std::{
8
    pin::Pin,
9
    sync::{
10
        Arc,
11
        atomic::{AtomicUsize, Ordering},
12
    },
13
    task::ready,
14
    task::{Context, Poll},
15
};
16

            
17
use futures::{Stream, sink::Sink, stream::FusedStream};
18
use pin_project::pin_project;
19

            
20
/// A wrapper around an arbitrary [`Sink`], to count the items inserted.
21
#[derive(Clone, Debug)]
22
#[pin_project]
23
pub struct CountingSink<S> {
24
    /// The inner sink whose items we're counting.
25
    #[pin]
26
    inner: S,
27
    /// A shared counter for items inserted into the channel
28
    ///
29
    /// We add 1 every time we enqueue an item.
30
    count: Arc<AtomicUsize>,
31
}
32

            
33
/// A wrapper around an arbitrary [`Stream`], to count the items inserted.
34
#[derive(Clone, Debug)]
35
#[pin_project]
36
pub struct CountingStream<S> {
37
    /// The inner stream whose items we're counting.
38
    #[pin]
39
    inner: S,
40
    /// A shared counter for items inserted into the channel.
41
    ///
42
    /// We remove 1 every time we dequeue an item.
43
    count: Arc<AtomicUsize>,
44
}
45

            
46
impl<T, S: Sink<T>> Sink<T> for CountingSink<S> {
47
    type Error = S::Error;
48

            
49
98958
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50
98958
        self.project().inner.poll_ready(cx)
51
98958
    }
52

            
53
84754
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
54
84754
        let self_ = self.project();
55
84754
        let r = self_.inner.start_send(item);
56
84754
        if r.is_ok() {
57
84754
            // We successfully sent an item, so we increment the counter.
58
84754
            //
59
84754
            // Using `Relaxed` ensures that the operation is atomic, but does not guarantee its
60
84754
            // order with respect to operations on other locations.
61
84754
            self_.count.fetch_add(1, Ordering::Relaxed);
62
84754
        }
63
84754
        r
64
84754
    }
65

            
66
80164
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67
80164
        self.project().inner.poll_flush(cx)
68
80164
    }
69

            
70
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
71
        self.project().inner.poll_close(cx)
72
    }
73
}
74

            
75
impl<S: Stream> Stream for CountingStream<S> {
76
    type Item = S::Item;
77

            
78
85590
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
79
85590
        let self_ = self.project();
80
85590
        let next = ready!(self_.inner.poll_next(cx));
81
84458
        if next.is_some() {
82
84382
            // We got an item, so we'll decrement the counter.
83
84382
            //
84
84382
            // See note above about "Relaxed" ordering.
85
84382
            self_.count.fetch_sub(1, Ordering::Relaxed);
86
84384
        }
87
84458
        Poll::Ready(next)
88
85590
    }
89
}
90

            
91
impl<S: FusedStream> FusedStream for CountingStream<S> {
92
5582
    fn is_terminated(&self) -> bool {
93
5582
        self.inner.is_terminated()
94
5582
    }
95
}
96

            
97
impl<S> CountingStream<S> {
98
    /// Return an approximate count of the number of items currently on this channel.
99
    ///
100
    /// This count is necessarily approximate because the count can be changed by any of this
101
    /// channel's Senders or Receivers between when the caller
102
    /// gets the count and when the caller uses the count.
103
80044
    pub fn approx_count(&self) -> usize {
104
80044
        self.count.load(Ordering::Relaxed)
105
80044
    }
106

            
107
    /// Return a reference to the inner stream.
108
    ///
109
    /// If the stream has interior mutability, the caller must take care
110
    /// not to do anything with the stream that would invalidate the current counter.
111
    pub fn inner(&self) -> &S {
112
        &self.inner
113
    }
114

            
115
    /// Return a mutable reference to the inner stream.
116
    ///
117
    /// If the stream has interior mutability, the caller must take care
118
    /// not to do anything with the stream that would invalidate the current counter.
119
    pub fn inner_mut(&mut self) -> &mut S {
120
        &mut self.inner
121
    }
122
}
123

            
124
impl<S> CountingSink<S> {
125
    /// Return an approximate count of the number of items currently on this channel.
126
    ///
127
    /// This count is necessarily approximate because the count can be changed by any of this
128
    /// channel's Senders or Receivers between when the caller
129
    /// gets the count and when the caller uses the count.
130
116
    pub fn approx_count(&self) -> usize {
131
116
        self.count.load(Ordering::Relaxed)
132
116
    }
133

            
134
    /// Return a reference to the inner sink.
135
    ///
136
    /// If the sink has interior mutability, the caller must take care
137
    /// not to do anything with the sink that would invalidate the current counter.
138
460
    pub fn inner(&self) -> &S {
139
460
        &self.inner
140
460
    }
141

            
142
    /// Return a mutable reference to the inner sink.
143
    ///
144
    /// If the sink has interior mutability, the caller must take care
145
    /// not to do anything with the sink that would invalidate the current counter.
146
    pub fn inner_mut(&mut self) -> &mut S {
147
        &mut self.inner
148
    }
149
}
150

            
151
/// Wrap a [`Sink`]/[`Stream`] pair into a [`CountingSink`] and [`CountingStream`] pair.
152
///
153
/// # Correctness
154
///
155
/// The sink and the stream should match and form a channel:
156
/// items sent on the sink should be received from the stream.
157
///
158
/// There should be no other handles in use for adding or removing items from the channel.
159
///
160
/// If these requirements aren't met, then the counts returned by the sink and stream
161
/// will not be accurate.
162
1706
pub fn channel<T, U>(tx: T, rx: U) -> (CountingSink<T>, CountingStream<U>) {
163
1706
    let count = Arc::new(AtomicUsize::new(0));
164
1706
    let new_tx = CountingSink {
165
1706
        inner: tx,
166
1706
        count: Arc::clone(&count),
167
1706
    };
168
1706
    let new_rx = CountingStream { inner: rx, count };
169
1706
    (new_tx, new_rx)
170
1706
}
171

            
172
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::pin::Pin;

    use futures::{SinkExt as _, StreamExt as _, sink::Sink as _};

    #[test]
    fn send_only_onetask() {
        tor_rtmock::MockRuntime::test_with_various(|_rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, rx) = super::channel(tx, rx);
            for n in 1..10 {
                tx.send(n).await.unwrap();
                assert_eq!(tx.approx_count(), n);
                assert_eq!(rx.approx_count(), n);
            }
        });
    }

    #[test]
    fn send_only_twotasks() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, rx) = super::channel(tx, rx);

            let mut tx2 = tx.clone();
            let j1 = rt.spawn_join("thread1", async move {
                for n in 1..=10 {
                    tx.send(n).await.unwrap();
                    assert!(tx.approx_count() >= n);
                }
            });

            let j2 = rt.spawn_join("thread2", async move {
                for n in 1..=10 {
                    tx2.send(n).await.unwrap();
                    assert!(tx2.approx_count() >= n);
                }
            });
            j1.await;
            j2.await;
            assert_eq!(rx.approx_count(), 20);
        });
    }

    #[test]
    fn send_and_receive() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
            let (mut tx, mut rx) = super::channel(tx, rx);
            const MAX: usize = 10000;

            let mut tx2 = tx.clone();
            let j1 = rt.spawn_join("thread1", async move {
                for n in 1..=MAX {
                    tx.send(n).await.unwrap();
                }
            });

            let j2 = rt.spawn_join("thread2", async move {
                for n in 1..=MAX {
                    tx2.send(n).await.unwrap();
                }
            });

            let j3 = rt.spawn_join("receiver", async move {
                let mut total = 0;
                while let Some(x) = rx.next().await {
                    total += x; // spot check
                    let count = rx.approx_count();
                    assert!(count <= MAX * 2);
                }
                assert_eq!(total, MAX * (MAX + 1)); // two senders, so no "/2".
                rx
            });

            j1.await;
            j2.await;
            let rx = j3.await;
            assert_eq!(rx.approx_count(), 0);
        });
    }

    /// A send that the inner sink rejects must leave the count unchanged.
    #[test]
    fn failed_send_not_counted() {
        let (tx, rx) = futures::channel::mpsc::unbounded::<usize>();
        let (mut tx, mut rx) = super::channel(tx, rx);
        // Close the receiving side so that further sends are rejected.
        rx.inner_mut().close();
        // Call `start_send` directly so that the failure comes from it (rather than from
        // `poll_ready`), exercising `CountingSink::start_send`'s error path.
        assert!(Pin::new(&mut tx).start_send(7).is_err());
        assert_eq!(tx.approx_count(), 0);
        assert_eq!(rx.approx_count(), 0);
    }
}