1
//! Provides [`KeyedFuturesUnordered`]
2

            
3
// So that we can declare these things as if they were in their own crate.
4
#![allow(unreachable_pub)]
5

            
6
use std::{
7
    collections::{HashMap, hash_map},
8
    hash::Hash,
9
    pin::Pin,
10
    sync::Arc,
11
    task::Poll,
12
};
13

            
14
use futures::future::FutureExt;
15
use futures::{
16
    Future,
17
    channel::mpsc::{UnboundedReceiver, UnboundedSender},
18
};
19
use pin_project::pin_project;
20

            
21
/// Waker for internal use in [`KeyedFuturesUnordered`]
22
///
23
/// When woken, it notifies the parent [`KeyedFuturesUnordered`] that the future
24
/// for a corresponding key is ready to be polled.
25
struct KeyedWaker<K> {
26
    /// The key associated with this waker.
27
    key: K,
28
    /// Sender cloned from the parent [`KeyedFuturesUnordered`].
29
    sender: UnboundedSender<K>,
30
}
31

            
32
impl<K> std::task::Wake for KeyedWaker<K>
33
where
34
    K: Clone,
35
{
36
226
    fn wake(self: Arc<Self>) {
37
226
        self.sender
38
226
            .unbounded_send(self.key.clone())
39
226
            .unwrap_or_else(|e| {
40
4
                if e.is_disconnected() {
41
                    // Other side has disappeared. Can safely ignore.
42
4
                    return;
43
                }
44
                // Shouldn't happen, but probably no need to `panic`.
45
                tracing::error!("Bug: Unexpected send error: {e:?}");
46
4
            });
47
226
    }
48
}
49

            
50
/// Efficiently manages a dynamic set of futures as per
51
/// [`futures::stream::FuturesUnordered`]. Unlike `FuturesUnordered`, each future
52
/// has an associated key. This key is returned along with the future's output,
53
/// and can be used to cancel and *remove* a future from the set.
54
///
55
/// Implements [`futures::Stream`], producing a stream of completed futures and
56
/// their associated keys.
57
///
58
/// # Stream behavior
59
///
60
/// `Stream::poll_next` returns:
61
/// * `Poll::Ready(None)` if there are no futures managed by this object.
62
/// * `Poll::Ready(Some((key, output)))` with the key and output of a ready
63
///    future when there is one.
64
/// * `Poll::Pending` when there are futures managed by this object, but none
65
///    are currently ready.
66
///
67
/// Unlike for a generic `Stream`, it *is* permitted to call `poll_next` again
68
/// after having received `Poll::Ready(None)`. It will still behave as above
69
/// (i.e. returning `Pending` or `Ready` if futures have since been inserted).
70
#[derive(Debug)]
71
#[pin_project]
72
pub struct KeyedFuturesUnordered<K, F>
73
where
74
    F: Future,
75
{
76
    /// Receiver on which we're notified of keys that are ready to be polled.
77
    #[pin]
78
    notification_receiver: UnboundedReceiver<K>,
79
    /// Sender on which to notify `notifications_receiver` that keys are ready
80
    /// to be polled.
81
    // In particular, keys are sent here:
82
    // * When a future is inserted.
83
    // * In `KeyedWaker`, which is the `Waker` we register with futures when we
84
    //   poll them internally.
85
    notification_sender: UnboundedSender<K>,
86
    /// Map of pending futures.
87
    futures: HashMap<K, F>,
88
}
89

            
90
impl<K, F> KeyedFuturesUnordered<K, F>
91
where
92
    F: Future,
93
    K: Eq + Hash + Clone,
94
{
95
    /// Create an empty [`KeyedFuturesUnordered`].
96
1068
    pub fn new() -> Self {
97
1068
        let (send, recv) = futures::channel::mpsc::unbounded();
98
1068
        Self {
99
1068
            notification_sender: send,
100
1068
            notification_receiver: recv,
101
1068
            futures: Default::default(),
102
1068
        }
103
1068
    }
104

            
105
    /// Insert a future and associate it with `key`. Return an error if there is already an entry for `key`.
106
4756
    pub fn try_insert(&mut self, key: K, fut: F) -> Result<(), KeyAlreadyInsertedError<K, F>> {
107
4756
        let hash_map::Entry::Vacant(v) = self.futures.entry(key.clone()) else {
108
            // Key is already present.
109
            return Err(KeyAlreadyInsertedError { key, fut });
110
        };
111
4756
        v.insert(fut);
112
        // Immediately "notify" ourselves, to enqueue this key to be polled.
113
4756
        self.notification_sender
114
4756
            .unbounded_send(key)
115
            // * Since the sender is unbounded, can't fail due to fullness.
116
            // * Since we have our own copy of the receiver, can't be disconnected.
117
4756
            .expect("Unbounded send unexpectedly failed");
118
4756
        Ok(())
119
4756
    }
120

            
121
    /// Remove the entry for `key`, if any, and return the corresponding future.
122
116
    pub fn remove(&mut self, key: &K) -> Option<(K, F)> {
123
116
        self.futures.remove_entry(key)
124
116
    }
125

            
126
    /// Get the future corresponding to `key`, if any.
127
    ///
128
    /// As for [`Self::get_mut`], removing or replacing its [`std::task::Waker`]
129
    /// without waking it (e.g. using internal mutability) results in
130
    /// unspecified (but sound) behavior.
131
    #[allow(dead_code)]
132
16
    pub fn get<'a>(&'a self, key: &K) -> Option<&'a F> {
133
16
        self.futures.get(key)
134
16
    }
135

            
136
    /// Get the future corresponding to `key`, if any.
137
    ///
138
    /// The future should not be `poll`d, nor its registered
139
    /// [`std::task::Waker`] otherwise removed or replaced (unless it is also
140
    /// woken; see below). The result of doing either is unspecified (but
141
    /// sound).
142
    ///
143
    /// This method is useful primarily when the future has other functionality
144
    /// or data bundled with it besides its implementation of the `Future`
145
    /// trait, though it *is* permitted to mutate the object in a way that
146
    /// causes it to become ready (i.e. wakes and discards its registered
147
    /// [`std::task::Waker`]`), or become unready (cause its next poll result to
148
    /// be `Poll::Pending` when it otherwise would have been `Poll::Ready` and
149
    /// may have already woken its registered `Waker`).
150
    //
151
    // More specifically:
152
    // * If the waker is lost without being woken, we'll never
153
    //   poll this future again.
154
    // * If our waker is woken *and* the caller polls the future to completion,
155
    //   we could end up polling it again after completion,
156
    //   breaking the `Future` contract.
157
    #[allow(dead_code)]
158
8722
    pub fn get_mut<'a>(&'a mut self, key: &K) -> Option<&'a mut F> {
159
8722
        self.futures.get_mut(key)
160
8722
    }
161
}
162

            
163
impl<K, F> futures::Stream for KeyedFuturesUnordered<K, F>
164
where
165
    F: Future + Unpin,
166
    K: Clone + Hash + Eq + Send + Sync + 'static,
167
{
168
    type Item = (K, F::Output);
169

            
170
23588
    fn poll_next(
171
23588
        self: Pin<&mut Self>,
172
23588
        cx: &mut std::task::Context<'_>,
173
23588
    ) -> Poll<Option<Self::Item>> {
174
23588
        if self.futures.is_empty() {
175
            // Follow precedent of `FuturesUnordered` of returning None in this case.
176
            // TODO: Consider breaking this precedent? This behavior is a bit
177
            // odd, since the documentation of the Stream trait indicates that a
178
            // stream shouldn't be polled again after returning None.
179
18744
            return Poll::Ready(None);
180
4844
        }
181
4844
        let mut self_ = self.project();
182
        loop {
183
            // Get the next pollable future, registering the caller's waker.
184
5082
            let key = match self_.notification_receiver.as_mut().poll_next(cx) {
185
4650
                Poll::Ready(key) => key.expect("Unexpected end of stream"),
186
                Poll::Pending => {
187
                    // No more keys to try.
188
432
                    return Poll::Pending;
189
                }
190
            };
191
4650
            let Some(fut) = self_.futures.get_mut(&key) else {
192
                // No future for this key. Presumably because it was removed
193
                // from the map. Try the next key.
194
                continue;
195
            };
196
            // Poll the future itself, using our own waker that will notify us
197
            // that this key is ready.
198
4650
            let waker = std::task::Waker::from(Arc::new(KeyedWaker {
199
4650
                key: key.clone(),
200
4650
                sender: self_.notification_sender.clone(),
201
4650
            }));
202
4650
            match fut.poll_unpin(&mut std::task::Context::from_waker(&waker)) {
203
4412
                Poll::Ready(o) => {
204
                    // Remove and drop the future itself.
205
                    // We *could* return it along with the item, but this would
206
                    // be a departure from the interface of `FuturesUnordered`,
207
                    // and most futures are designed to be discarded after
208
                    // completion.
209
4412
                    self_.futures.remove(&key);
210

            
211
4412
                    return Poll::Ready(Some((key, o)));
212
                }
213
238
                Poll::Pending => {
214
238
                    // This future wasn't actually ready.
215
238
                    //
216
238
                    // This can happen, e.g. because:
217
238
                    // * This is our first time actually polling this future.
218
238
                    // * The futures waker was called spuriously.
219
238
                    // * This was actually a reused key, and we received the notification from
220
238
                    //   a waker for a previous future registered with this key.
221
238
                    //
222
238
                    // Move on to the next key.
223
238
                }
224
            }
225
        }
226
23588
    }
227
}
228

            
229
/// Error returned by [`KeyedFuturesUnordered::try_insert`].
230
#[derive(Debug, thiserror::Error)]
231
#[allow(clippy::exhaustive_structs)]
232
pub struct KeyAlreadyInsertedError<K, F> {
233
    /// Key that caller tried to insert.
234
    #[allow(dead_code)]
235
    pub key: K,
236
    /// Future that caller tried to insert.
237
    #[allow(dead_code)]
238
    pub fut: F,
239
}
240

            
241
#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::task::Waker;

    use futures::{StreamExt as _, executor::block_on, future::poll_fn};
    use oneshot_fused_workaround as oneshot;
    use tor_rtmock::MockRuntime;

    use super::*;

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Key(u64);

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
    struct Value(u64);

    /// Simple future for testing. Supports comparison, and can be mutated directly to become ready.
    #[derive(Debug, Clone)]
    struct ValueFut<V> {
        /// Value that will be produced when ready.
        value: Option<V>,
        /// Whether this is ready.
        // We use a distinct flag here instead of a None value so that pending
        // instances are still unequal if they have different values.
        ready: bool,
        /// Waker from the most recent poll that returned `Pending`, if any.
        /// Taken and woken by `make_ready`.
        waker: Option<Waker>,
    }

    impl<V> std::cmp::PartialEq for ValueFut<V>
    where
        V: std::cmp::PartialEq,
    {
        fn eq(&self, other: &Self) -> bool {
            // Ignores the waker, which isn't comparable
            self.value == other.value && self.ready == other.ready
        }
    }

    impl<V> std::cmp::Eq for ValueFut<V> where V: std::cmp::Eq {}

    impl<V> ValueFut<V> {
        fn ready(value: V) -> Self {
            Self {
                value: Some(value),
                ready: true,
                waker: None,
            }
        }
        fn pending(value: V) -> Self {
            Self {
                value: Some(value),
                ready: false,
                waker: None,
            }
        }
        fn make_ready(&mut self) {
            self.ready = true;
            if let Some(waker) = self.waker.take() {
                waker.wake();
            }
        }
    }

    impl<V> Future for ValueFut<V>
    where
        V: Unpin,
    {
        type Output = V;

        fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
            if !self.ready {
                self.waker.replace(cx.waker().clone());
                Poll::Pending
            } else {
                Poll::Ready(self.value.take().expect("Polled future after it was ready"))
            }
        }
    }

    #[test]
    fn test_empty() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::<Key, ValueFut<Value>>::new();

            // When there are no futures in the set (ready or pending), returns
            // `Poll::Ready(None)` as for `FuturesUnordered`.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            // Nothing to get.
            assert_eq!(kfu.get(&Key(0)), None);
            assert_eq!(kfu.get_mut(&Key(0)), None);

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_pending_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();

            kfu.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();

            // When there are futures in the set, but none are ready, returns
            // `Poll::Pending`, as for `FuturesUnordered`
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // State should be unchanged; same result if we poll again.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // We should be able to get the future.
            assert_eq!(kfu.get(&Key(0)), Some(&ValueFut::pending(Value(0))));
            assert_eq!(kfu.get_mut(&Key(0)), Some(&mut ValueFut::pending(Value(0))));

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_ready_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();

            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();

            // Should be able to get the future before it's polled.
            assert_eq!(kfu.get(&Key(0)), Some(&ValueFut::ready(Value(1))));
            assert_eq!(kfu.get_mut(&Key(0)), Some(&mut ValueFut::ready(Value(1))));

            // When there is a ready future, returns it.
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(1))))
            );

            // After having returned the ready future, should be empty again.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            assert_eq!(kfu.get(&Key(0)), None);
            assert_eq!(kfu.get_mut(&Key(0)), None);

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_one_pending_then_ready_future() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            let (send, recv) = oneshot::channel::<Value>();
            kfu.try_insert(Key(0), recv).unwrap();

            // Nothing ready yet.
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // Should be able to get it.
            assert!(kfu.get(&Key(0)).is_some());
            assert!(kfu.get_mut(&Key(0)).is_some());

            send.send(Value(1)).unwrap();

            // oneshot future should be ready.
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Ok(Value(1)))))
            );

            // Empty again.
            assert!(kfu.get(&Key(0)).is_none());
            assert!(kfu.get_mut(&Key(0)).is_none());
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_insert_duplicate_key() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();

            // Inserting again under the same key fails, handing back the
            // rejected key and future.
            let err = kfu
                .try_insert(Key(0), ValueFut::pending(Value(1)))
                .unwrap_err();
            assert_eq!(err.key, Key(0));
            assert_eq!(err.fut, ValueFut::pending(Value(1)));

            // The original entry is untouched.
            assert_eq!(kfu.get(&Key(0)), Some(&ValueFut::pending(Value(0))));
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // Only the original future's output is ever produced.
            kfu.get_mut(&Key(0)).unwrap().make_ready();
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(0))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_pending() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::pending(Value(0))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::ready(Value(1))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));
            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_and_reuse_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();
            assert_eq!(
                kfu.remove(&Key(0)),
                Some((Key(0), ValueFut::ready(Value(1))))
            );
            kfu.try_insert(Key(0), ValueFut::ready(Value(2))).unwrap();

            // We should get back *only* the second value.
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(2))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_remove_and_reuse_pending_then_ready() {
        block_on(poll_fn(|cx| {
            let mut kfu = KeyedFuturesUnordered::new();
            kfu.try_insert(Key(0), ValueFut::pending(Value(1))).unwrap();
            let (_key, mut removed_value) = kfu.remove(&Key(0)).unwrap();
            kfu.try_insert(Key(0), ValueFut::pending(Value(2))).unwrap();

            // Make the *removed* future ready before polling again. This should
            // cause an internal spurious wakeup, but not be visible from the
            // user's perspective.
            removed_value.make_ready();
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Pending);

            // Make the future that we replaced it with become ready.
            kfu.get_mut(&Key(0)).unwrap().make_ready();

            // We should now get back *only* the second value.
            assert_eq!(
                kfu.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Value(2))))
            );
            assert_eq!(kfu.poll_next_unpin(cx), Poll::Ready(None));

            Poll::Ready(())
        }));
    }

    #[test]
    fn test_async() {
        MockRuntime::test_with_various(|rt| async move {
            let mut kfu = KeyedFuturesUnordered::new();

            for i in 0..10 {
                let (send, recv) = oneshot::channel();
                kfu.try_insert(Key(i), recv).unwrap();
                rt.spawn_identified(format!("sender-{i}"), async move {
                    send.send(Value(i)).unwrap();
                });
            }

            let values = kfu.collect::<Vec<_>>().await;
            let mut values = values
                .into_iter()
                .map(|(k, v)| (k, v.unwrap()))
                .collect::<Vec<_>>();
            values.sort();

            let expected_values = (0..10).map(|i| (Key(i), Value(i))).collect::<Vec<_>>();
            assert_eq!(values, expected_values);
        });
    }
}