//! [`PollAll`]: a helper future that drives several futures in lockstep.
2

            
3
use futures::FutureExt as _;
4
use smallvec::{SmallVec, smallvec};
5

            
6
use std::future::Future;
7
use std::pin::Pin;
8
use std::task::{Context, Poll};
9

            
10
/// The element type stored by a [`PollAll`]: a heap-allocated, pinned,
/// `Send` future that may borrow data for `'a` and yields a `T`.
type BoxedFut<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
12

            
13
/// Helper for driving multiple futures in lockstep.
14
///
15
/// When `.await`ed, a [`PollAll`] will unconditionally poll *all* of its
16
/// underlying futures, in the order they were [`push`](PollAll::push)ed,
17
/// until one or more of them resolves.
18
/// Any remaining unresolved futures will be dropped.
19
/// An empty `PollAll` will resolve immediately, yielding an empty list.
20
///
21
/// `PollAll` resolves to an *ordered* list of results, obtained from polling
22
/// the futures in insertion order. Because some of the futures may not
23
/// get a chance to resolve, the number of results will always
24
/// be less than or equal to the number of inserted futures.
25
///
26
/// Because `PollAll` drives the futures in lockstep,
27
/// if one future becomes ready, all of the futures will get polled,
28
/// even if they didn't generate a wakeup notification.
29
///
30
/// ### Invariants
31
///
32
/// All of the futures inserted into this set **must** be cancellation safe.
33
#[derive(Default)]
34
pub(crate) struct PollAll<'a, const N: usize, T> {
35
    /// The futures to drive in lockstep.
36
    inner: SmallVec<[BoxedFut<'a, T>; N]>,
37
}
38

            
39
impl<'a, const N: usize, T> PollAll<'a, N, T> {
40
    /// Create an empty [`PollAll`].
41
12650
    pub(crate) fn new() -> Self {
42
12650
        Self { inner: smallvec![] }
43
12650
    }
44

            
45
    /// Add a future to this [`PollAll`].
46
21158
    pub(crate) fn push<S: Future<Output = T> + Send + 'a>(&mut self, item: S) {
47
21158
        self.inner.push(Box::pin(item));
48
21158
    }
49
}
50

            
51
impl<'a, const N: usize, T> Future for PollAll<'a, N, T> {
52
    type Output = SmallVec<[T; N]>;
53

            
54
16088
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
55
16088
        let mut results = smallvec![];
56

            
57
16088
        if self.inner.is_empty() {
58
            // Nothing to do.
59
4
            return Poll::Ready(results);
60
16084
        }
61

            
62
26984
        for fut in self.inner.iter_mut() {
63
26984
            match fut.poll_unpin(cx) {
64
9502
                Poll::Ready(res) => results.push(res),
65
17482
                Poll::Pending => continue,
66
            }
67
        }
68

            
69
16084
        if results.is_empty() {
70
6630
            return Poll::Pending;
71
9454
        }
72

            
73
9454
        Poll::Ready(results)
74
16088
    }
75
}
76

            
77
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    use tor_rtmock::MockRuntime;

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Inline capacity used for the `PollAll` instances in these tests.
    const RES_COUNT: usize = 5;

    /// Wraps a future and records, in a shared counter, how often it is polled.
    struct PollCounter<F> {
        /// Number of `poll()` calls so far; `new` hands out a second handle to it.
        count: Arc<AtomicUsize>,
        /// The wrapped future.
        inner: F,
    }

    /// A future that becomes ready on exactly its `resolve_after`-th poll,
    /// and panics if it is polled again after that.
    struct ResolveAfter {
        /// The (1-based) poll on which this future resolves.
        resolve_after: usize,
        /// How many times `poll()` has been called so far.
        poll_count: usize,
    }

    impl ResolveAfter {
        fn new(resolve_after: usize) -> Self {
            Self {
                resolve_after,
                poll_count: 0,
            }
        }
    }

    impl<F> PollCounter<F> {
        /// Wrap `inner`, returning the wrapper and a handle to its poll counter.
        fn new(inner: F) -> (Self, Arc<AtomicUsize>) {
            let shared = Arc::new(AtomicUsize::new(0));
            let handle = Arc::clone(&shared);
            (
                Self {
                    count: shared,
                    inner,
                },
                handle,
            )
        }
    }

    impl<F: Future + Unpin> Future for PollCounter<F> {
        type Output = F::Output;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let _ = self.count.fetch_add(1, Ordering::Relaxed);
            self.inner.poll_unpin(cx)
        }
    }

    impl Future for ResolveAfter {
        type Output = usize;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_count += 1;
            let (polls, target) = (self.poll_count, self.resolve_after);

            // TODO MSRV 1.87: Remove this allow.
            #[allow(
                clippy::comparison_chain,
                reason = "This is more readable than a match, and the lint is
                moved to clippy::pedantic in 1.87."
            )]
            if polls < target {
                // Not done yet: ask to be polled again straight away.
                cx.waker().wake_by_ref();
                Poll::Pending
            } else if polls == target {
                Poll::Ready(target)
            } else {
                panic!("future polled after completion?!");
            }
        }
    }

    #[test]
    fn poll_none() {
        MockRuntime::test_with_various(|_| async move {
            let empty = PollAll::<RES_COUNT, ()>::new();
            assert!(empty.await.is_empty());
        });
    }

    #[test]
    fn poll_multiple() {
        MockRuntime::test_with_various(|_| async move {
            let mut poll_all = PollAll::<RES_COUNT, usize>::new();

            // A future that never resolves is still polled in lockstep.
            let (never_fut, never_count) = PollCounter::new(futures::future::pending::<usize>());
            poll_all.push(never_fut);

            // The first two resolve on their 5th poll; the last two would need
            // more polls than that, so they never get a chance to resolve.
            let mut counters = Vec::new();
            for resolve_after in [5, 5, 8, 9] {
                let (fut, counter) = PollCounter::new(ResolveAfter::new(resolve_after));
                poll_all.push(fut);
                counters.push(counter);
            }

            let res = poll_all.await;
            assert_eq!(&res[..], &[5, 5]);

            // Every future, including the never-resolving one, was polled 5 times.
            assert_eq!(never_count.load(Ordering::Relaxed), 5);
            for counter in &counters {
                assert_eq!(counter.load(Ordering::Relaxed), 5);
            }
        });
    }
}