1
//! [`PollAll`]: a helper for driving several futures in lockstep.
2

            
3
use futures::FutureExt as _;
4
use smallvec::{SmallVec, smallvec};
5

            
6
use std::future::Future;
7
use std::pin::Pin;
8
use std::task::{Context, Poll};
9

            
10
/// A boxed, pinned, `Send` future yielding `T`, as stored inside a [`PollAll`].
type BoxedFut<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
12

            
13
/// Helper for driving multiple futures in lockstep.
14
///
15
/// When `.await`ed, a [`PollAll`] will unconditionally poll *all* of its
16
/// underlying futures, in the order they were [`push`](PollAll::push)ed,
17
/// until one or more of them resolves.
18
/// Any remaining unresolved futures will be dropped.
19
/// An empty `PollAll` will resolve immediately, yielding an empty list.
20
///
21
/// `PollAll` resolves to an *ordered* list of results, obtained from polling
22
/// the futures in insertion order. Because some of the futures may not
23
/// get a chance to resolve, the number of results will always
24
/// be less than or equal to the number of inserted futures.
25
///
26
/// Because `PollAll` drives the futures in lockstep,
27
/// if one future becomes ready, all of the futures will get polled,
28
/// even if they didn't generate a wakeup notification.
29
///
30
/// ### Invariants
31
///
32
/// All of the futures inserted into this set **must** be cancellation safe.
33
#[derive(Default)]
34
pub(crate) struct PollAll<'a, const N: usize, T> {
35
    /// The futures to drive in lockstep.
36
    inner: SmallVec<[BoxedFut<'a, T>; N]>,
37
}
38

            
39
impl<'a, const N: usize, T> PollAll<'a, N, T> {
40
    /// Create an empty [`PollAll`].
41
12710
    pub(crate) fn new() -> Self {
42
12710
        Self { inner: smallvec![] }
43
12710
    }
44

            
45
    /// Add a future to this [`PollAll`].
46
21320
    pub(crate) fn push<S: Future<Output = T> + Send + 'a>(&mut self, item: S) {
47
21320
        self.inner.push(Box::pin(item));
48
21320
    }
49
}
50

            
51
impl<'a, const N: usize, T> Future for PollAll<'a, N, T> {
52
    type Output = SmallVec<[T; N]>;
53

            
54
16226
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
55
16226
        let mut results = smallvec![];
56

            
57
16226
        if self.inner.is_empty() {
58
            // Nothing to do.
59
4
            return Poll::Ready(results);
60
16222
        }
61

            
62
27356
        for fut in self.inner.iter_mut() {
63
27356
            match fut.poll_unpin(cx) {
64
9522
                Poll::Ready(res) => results.push(res),
65
17834
                Poll::Pending => continue,
66
            }
67
        }
68

            
69
16222
        if results.is_empty() {
70
6752
            return Poll::Pending;
71
9470
        }
72

            
73
9470
        Poll::Ready(results)
74
16226
    }
75
}
76

            
77
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    use tor_rtmock::MockRuntime;

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Inline capacity of the result `SmallVec` used throughout these tests.
    const RES_COUNT: usize = 5;

    /// Wraps a future and records, in a shared counter, how often it is polled.
    struct PollCounter<F> {
        /// Shared poll count; the test holds the other handle.
        count: Arc<AtomicUsize>,
        /// The wrapped future.
        inner: F,
    }

    /// A future that becomes ready on exactly its `resolve_after`-th poll.
    struct ResolveAfter {
        /// The (1-based) poll on which this future resolves.
        resolve_after: usize,
        /// How many times `poll()` has been called so far.
        poll_count: usize,
    }

    impl ResolveAfter {
        fn new(resolve_after: usize) -> Self {
            Self {
                resolve_after,
                poll_count: 0,
            }
        }
    }

    impl<F> PollCounter<F> {
        /// Wrap `inner`, returning the wrapper and a handle to its poll counter.
        fn new(inner: F) -> (Self, Arc<AtomicUsize>) {
            let count = Arc::new(AtomicUsize::new(0));
            let handle = Arc::clone(&count);
            (Self { count, inner }, handle)
        }
    }

    impl<F: Future + Unpin> Future for PollCounter<F> {
        type Output = F::Output;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let _ = self.count.fetch_add(1, Ordering::Relaxed);
            self.inner.poll_unpin(cx)
        }
    }

    impl Future for ResolveAfter {
        type Output = usize;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_count += 1;

            match self.poll_count.cmp(&self.resolve_after) {
                std::cmp::Ordering::Equal => Poll::Ready(self.resolve_after),
                std::cmp::Ordering::Greater => panic!("future polled after completion?!"),
                std::cmp::Ordering::Less => {
                    // Ask to be polled again right away.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }
        }
    }

    #[test]
    fn poll_none() {
        MockRuntime::test_with_various(|_| async move {
            let empty = PollAll::<RES_COUNT, ()>::new();
            let results = empty.await;
            assert!(results.is_empty());
        });
    }

    #[test]
    fn poll_multiple() {
        MockRuntime::test_with_various(|_| async move {
            let mut poll_all = PollAll::<RES_COUNT, usize>::new();

            // Never resolves, and never wakes the task by itself.
            let (never_fut, never_count) = PollCounter::new(futures::future::pending::<usize>());
            poll_all.push(never_fut);

            // The last two (8 and 9) won't get a chance to resolve.
            let mut counters = Vec::new();
            for resolve_after in [5, 5, 8, 9] {
                let (fut, count) = PollCounter::new(ResolveAfter::new(resolve_after));
                poll_all.push(fut);
                counters.push(count);
            }

            let res = poll_all.await;
            assert_eq!(&res[..], &[5, 5]);

            // Every future, resolved or not, was polled exactly 5 times.
            assert_eq!(never_count.load(Ordering::Relaxed), 5);
            for (idx, counter) in counters.iter().enumerate() {
                assert_eq!(counter.load(Ordering::Relaxed), 5, "future #{idx}");
            }
        });
    }
}