1
//! Provides [`KeyedFuturesUnordered`]
2

            
3
// So that we can declare these things as if they were in their own crate.
4
#![allow(unreachable_pub)]
5

            
6
use std::{
7
    collections::{HashMap, hash_map},
8
    hash::Hash,
9
    pin::Pin,
10
    sync::Arc,
11
    task::Poll,
12
};
13

            
14
use futures::future::FutureExt;
15
use futures::{
16
    Future,
17
    channel::mpsc::{UnboundedReceiver, UnboundedSender},
18
};
19
use pin_project::pin_project;
20

            
21
/// Waker for internal use in [`KeyedFuturesUnordered`]
22
///
23
/// When woken, it notifies the parent [`KeyedFuturesUnordered`] that the future
24
/// for a corresponding key is ready to be polled.
25
struct KeyedWaker<K> {
26
    /// The key associated with this waker.
27
    key: K,
28
    /// Sender cloned from the parent [`KeyedFuturesUnordered`].
29
    sender: UnboundedSender<K>,
30
}
31

            
32
impl<K> std::task::Wake for KeyedWaker<K>
33
where
34
    K: Clone,
35
{
36
236
    fn wake(self: Arc<Self>) {
37
236
        self.sender
38
236
            .unbounded_send(self.key.clone())
39
236
            .unwrap_or_else(|e| {
40
4
                if e.is_disconnected() {
41
                    // Other side has disappeared. Can safely ignore.
42
4
                    return;
43
                }
44
                // Shouldn't happen, but probably no need to `panic`.
45
                tracing::error!("Bug: Unexpected send error: {e:?}");
46
4
            });
47
236
    }
48
}
49

            
50
/// Efficiently manages a dynamic set of futures as per
51
/// [`futures::stream::FuturesUnordered`]. Unlike `FuturesUnordered`, each future
52
/// has an associated key. This key is returned along with the future's output,
53
/// and can be used to cancel and *remove* a future from the set.
54
///
55
/// Implements [`futures::Stream`], producing a stream of completed futures and
56
/// their associated keys.
57
///
58
/// # Stream behavior
59
///
60
/// `Stream::poll_next` returns:
61
/// * `Poll::Ready(None)` if there are no futures managed by this object.
62
/// * `Poll::Ready(Some((key, output)))` with the key and output of a ready
63
///    future when there is one.
64
/// * `Poll::Pending` when there are futures managed by this object, but none
65
///    are currently ready.
66
///
67
/// Unlike for a generic `Stream`, it *is* permitted to call `poll_next` again
68
/// after having received `Poll::Ready(None)`. It will still behave as above
69
/// (i.e. returning `Pending` or `Ready` if futures have since been inserted).
70
#[derive(Debug)]
71
#[pin_project]
72
pub struct KeyedFuturesUnordered<K, F>
73
where
74
    F: Future,
75
{
76
    /// Receiver on which we're notified of keys that are ready to be polled.
77
    #[pin]
78
    notification_receiver: UnboundedReceiver<K>,
79
    /// Sender on which to notify `notifications_receiver` that keys are ready
80
    /// to be polled.
81
    // In particular, keys are sent here:
82
    // * When a future is inserted.
83
    // * In `KeyedWaker`, which is the `Waker` we register with futures when we
84
    //   poll them internally.
85
    notification_sender: UnboundedSender<K>,
86
    /// Map of pending futures.
87
    futures: HashMap<K, F>,
88
}
89

            
90
impl<K, F> KeyedFuturesUnordered<K, F>
91
where
92
    F: Future,
93
    K: Eq + Hash + Clone,
94
{
95
    /// Create an empty [`KeyedFuturesUnordered`].
96
1064
    pub fn new() -> Self {
97
1064
        let (send, recv) = futures::channel::mpsc::unbounded();
98
1064
        Self {
99
1064
            notification_sender: send,
100
1064
            notification_receiver: recv,
101
1064
            futures: Default::default(),
102
1064
        }
103
1064
    }
104

            
105
    /// Insert a future and associate it with `key`. Return an error if there is already an entry for `key`.
106
4748
    pub fn try_insert(&mut self, key: K, fut: F) -> Result<(), KeyAlreadyInsertedError<K, F>> {
107
4748
        let hash_map::Entry::Vacant(v) = self.futures.entry(key.clone()) else {
108
            // Key is already present.
109
            return Err(KeyAlreadyInsertedError { key, fut });
110
        };
111
4748
        v.insert(fut);
112
        // Immediately "notify" ourselves, to enqueue this key to be polled.
113
4748
        self.notification_sender
114
4748
            .unbounded_send(key)
115
            // * Since the sender is unbounded, can't fail due to fullness.
116
            // * Since we have our own copy of the receiver, can't be disconnected.
117
4748
            .expect("Unbounded send unexpectedly failed");
118
4748
        Ok(())
119
4748
    }
120

            
121
    /// Remove the entry for `key`, if any, and return the corresponding future.
122
118
    pub fn remove(&mut self, key: &K) -> Option<(K, F)> {
123
118
        self.futures.remove_entry(key)
124
118
    }
125

            
126
    /// Get the future corresponding to `key`, if any.
127
    ///
128
    /// As for [`Self::get_mut`], removing or replacing its [`std::task::Waker`]
129
    /// without waking it (e.g. using internal mutability) results in
130
    /// unspecified (but sound) behavior.
131
    #[allow(dead_code)]
132
16
    pub fn get<'a>(&'a self, key: &K) -> Option<&'a F> {
133
16
        self.futures.get(key)
134
16
    }
135

            
136
    /// Get the future corresponding to `key`, if any.
137
    ///
138
    /// The future should not be `poll`d, nor its registered
139
    /// [`std::task::Waker`] otherwise removed or replaced (unless it is also
140
    /// woken; see below). The result of doing either is unspecified (but
141
    /// sound).
142
    ///
143
    /// This method is useful primarily when the future has other functionality
144
    /// or data bundled with it besides its implementation of the `Future`
145
    /// trait, though it *is* permitted to mutate the object in a way that
146
    /// causes it to become ready (i.e. wakes and discards its registered
147
    /// [`std::task::Waker`]`), or become unready (cause its next poll result to
148
    /// be `Poll::Pending` when it otherwise would have been `Poll::Ready` and
149
    /// may have already woken its registered `Waker`).
150
    //
151
    // More specifically:
152
    // * If the waker is lost without being woken, we'll never
153
    //   poll this future again.
154
    // * If our waker is woken *and* the caller polls the future to completion,
155
    //   we could end up polling it again after completion,
156
    //   breaking the `Future` contract.
157
    #[allow(dead_code)]
158
8710
    pub fn get_mut<'a>(&'a mut self, key: &K) -> Option<&'a mut F> {
159
8710
        self.futures.get_mut(key)
160
8710
    }
161
}
162

            
163
impl<K, F> futures::Stream for KeyedFuturesUnordered<K, F>
164
where
165
    F: Future + Unpin,
166
    K: Clone + Hash + Eq + Send + Sync + 'static,
167
{
168
    type Item = (K, F::Output);
169

            
170
23622
    fn poll_next(
171
23622
        self: Pin<&mut Self>,
172
23622
        cx: &mut std::task::Context<'_>,
173
23622
    ) -> Poll<Option<Self::Item>> {
174
23622
        if self.futures.is_empty() {
175
            // Follow precedent of `FuturesUnordered` of returning None in this case.
176
            // TODO: Consider breaking this precedent? This behavior is a bit
177
            // odd, since the documentation of the Stream trait indicates that a
178
            // stream shouldn't be polled again after returning None.
179
18758
            return Poll::Ready(None);
180
4864
        }
181
4864
        let mut self_ = self.project();
182
        loop {
183
            // Get the next pollable future, registering the caller's waker.
184
5116
            let key = match self_.notification_receiver.as_mut().poll_next(cx) {
185
4662
                Poll::Ready(key) => key.expect("Unexpected end of stream"),
186
                Poll::Pending => {
187
                    // No more keys to try.
188
454
                    return Poll::Pending;
189
                }
190
            };
191
4662
            let Some(fut) = self_.futures.get_mut(&key) else {
192
                // No future for this key. Presumably because it was removed
193
                // from the map. Try the next key.
194
                continue;
195
            };
196
            // Poll the future itself, using our own waker that will notify us
197
            // that this key is ready.
198
4662
            let waker = std::task::Waker::from(Arc::new(KeyedWaker {
199
4662
                key: key.clone(),
200
4662
                sender: self_.notification_sender.clone(),
201
4662
            }));
202
4662
            match fut.poll_unpin(&mut std::task::Context::from_waker(&waker)) {
203
4410
                Poll::Ready(o) => {
204
                    // Remove and drop the future itself.
205
                    // We *could* return it along with the item, but this would
206
                    // be a departure from the interface of `FuturesUnordered`,
207
                    // and most futures are designed to be discarded after
208
                    // completion.
209
4410
                    self_.futures.remove(&key);
210

            
211
4410
                    return Poll::Ready(Some((key, o)));
212
                }
213
252
                Poll::Pending => {
214
252
                    // This future wasn't actually ready.
215
252
                    //
216
252
                    // This can happen, e.g. because:
217
252
                    // * This is our first time actually polling this future.
218
252
                    // * The futures waker was called spuriously.
219
252
                    // * This was actually a reused key, and we received the notification from
220
252
                    //   a waker for a previous future registered with this key.
221
252
                    //
222
252
                    // Move on to the next key.
223
252
                }
224
            }
225
        }
226
23622
    }
227
}
228

            
229
/// Error returned by [`KeyedFuturesUnordered::try_insert`].
230
#[derive(Debug, thiserror::Error)]
231
#[allow(clippy::exhaustive_structs)]
232
pub struct KeyAlreadyInsertedError<K, F> {
233
    /// Key that caller tried to insert.
234
    #[allow(dead_code)]
235
    pub key: K,
236
    /// Future that caller tried to insert.
237
    #[allow(dead_code)]
238
    pub fut: F,
239
}
240

            
241
#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::task::Waker;

    use futures::{StreamExt as _, executor::block_on, future::poll_fn};
    use oneshot_fused_workaround as oneshot;
    use tor_rtmock::MockRuntime;

    use super::*;

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    struct Key(u64);

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
    struct Value(u64);

    /// Run `body` exactly once on the current thread, handing it a real task
    /// `Context` to poll with.
    fn with_cx(body: impl FnOnce(&mut std::task::Context<'_>)) {
        let mut body = Some(body);
        block_on(poll_fn(|cx| {
            if let Some(body) = body.take() {
                body(cx);
            }
            Poll::Ready(())
        }));
    }

    /// Minimal future for tests: comparable, and can be flipped to ready by hand.
    #[derive(Debug, Clone)]
    struct ValueFut<V> {
        /// Output handed out once ready; taken on completion.
        value: Option<V>,
        /// Readiness flag.
        // Kept separate from `value` (rather than using `None`) so that two
        // pending instances holding different values still compare unequal.
        ready: bool,
        /// Waker registered by the most recent pending poll, if any.
        waker: Option<Waker>,
    }

    impl<V: PartialEq> PartialEq for ValueFut<V> {
        fn eq(&self, other: &Self) -> bool {
            // `Waker` can't be compared, so it is left out.
            (&self.value, self.ready) == (&other.value, other.ready)
        }
    }

    impl<V: Eq> Eq for ValueFut<V> {}

    impl<V> ValueFut<V> {
        /// Build an instance holding `value`, with the given readiness.
        fn with_readiness(value: V, ready: bool) -> Self {
            Self {
                value: Some(value),
                ready,
                waker: None,
            }
        }

        fn ready(value: V) -> Self {
            Self::with_readiness(value, true)
        }

        fn pending(value: V) -> Self {
            Self::with_readiness(value, false)
        }

        /// Mark as ready, waking whoever last polled us.
        fn make_ready(&mut self) {
            self.ready = true;
            if let Some(registered) = self.waker.take() {
                registered.wake();
            }
        }
    }

    impl<V> Future for ValueFut<V>
    where
        V: Unpin,
    {
        type Output = V;

        fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
            if self.ready {
                let value = self.value.take().expect("Polled future after it was ready");
                return Poll::Ready(value);
            }
            self.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }

    #[test]
    fn test_empty() {
        with_cx(|cx| {
            let mut set = KeyedFuturesUnordered::<Key, ValueFut<Value>>::new();

            // An empty set (nothing ready, nothing pending) yields
            // `Poll::Ready(None)`, just like `FuturesUnordered`.
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));

            // Lookups find nothing.
            assert_eq!(set.get(&Key(0)), None);
            assert_eq!(set.get_mut(&Key(0)), None);
        });
    }

    #[test]
    fn test_one_pending_future() {
        with_cx(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();

            // Futures present but none ready: `Poll::Pending`, as for
            // `FuturesUnordered`. Polling a second time changes nothing.
            for _ in 0..2 {
                assert_eq!(set.poll_next_unpin(cx), Poll::Pending);
            }

            // The future can still be looked up.
            let expected = ValueFut::pending(Value(0));
            assert_eq!(set.get(&Key(0)), Some(&expected));
            assert_eq!(set.get_mut(&Key(0)), Some(&mut expected.clone()));
        });
    }

    #[test]
    fn test_one_ready_future() {
        with_cx(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();

            // Retrievable before it has been polled.
            assert_eq!(set.get(&Key(0)), Some(&ValueFut::ready(Value(1))));
            assert_eq!(set.get_mut(&Key(0)), Some(&mut ValueFut::ready(Value(1))));

            // A ready future is handed back together with its key...
            let first = set.poll_next_unpin(cx);
            assert_eq!(first, Poll::Ready(Some((Key(0), Value(1)))));

            // ...after which the set is empty again.
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));
            assert!(set.get(&Key(0)).is_none());
            assert!(set.get_mut(&Key(0)).is_none());
        });
    }

    #[test]
    fn test_one_pending_then_ready_future() {
        with_cx(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            let (tx, rx) = oneshot::channel::<Value>();
            set.try_insert(Key(0), rx).unwrap();

            // Nothing has been sent, so nothing is ready; the entry is still there.
            assert_eq!(set.poll_next_unpin(cx), Poll::Pending);
            assert!(set.get(&Key(0)).is_some());
            assert!(set.get_mut(&Key(0)).is_some());

            tx.send(Value(1)).unwrap();

            // Now the oneshot receiver completes.
            assert_eq!(
                set.poll_next_unpin(cx),
                Poll::Ready(Some((Key(0), Ok(Value(1)))))
            );

            // And the set is empty once more.
            assert!(set.get(&Key(0)).is_none());
            assert!(set.get_mut(&Key(0)).is_none());
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));
        });
    }

    #[test]
    fn test_remove_pending() {
        with_cx(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::pending(Value(0))).unwrap();

            let removed = set.remove(&Key(0));
            assert_eq!(removed, Some((Key(0), ValueFut::pending(Value(0)))));

            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));
        });
    }

    #[test]
    fn test_remove_ready() {
        with_cx(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();

            let removed = set.remove(&Key(0));
            assert_eq!(removed, Some((Key(0), ValueFut::ready(Value(1)))));

            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));
        });
    }

    #[test]
    fn test_remove_and_reuse_ready() {
        with_cx(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::ready(Value(1))).unwrap();

            let removed = set.remove(&Key(0));
            assert_eq!(removed, Some((Key(0), ValueFut::ready(Value(1)))));

            set.try_insert(Key(0), ValueFut::ready(Value(2))).unwrap();

            // Only the replacement's value may come out.
            let first = set.poll_next_unpin(cx);
            assert_eq!(first, Poll::Ready(Some((Key(0), Value(2)))));
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));
        });
    }

    #[test]
    fn test_remove_and_reuse_pending_then_ready() {
        with_cx(|cx| {
            let mut set = KeyedFuturesUnordered::new();
            set.try_insert(Key(0), ValueFut::pending(Value(1))).unwrap();
            let (_, mut stale) = set.remove(&Key(0)).unwrap();
            set.try_insert(Key(0), ValueFut::pending(Value(2))).unwrap();

            // Readying the *removed* future before polling may cause an internal
            // spurious wakeup, but that must not be visible to the user.
            stale.make_ready();
            assert_eq!(set.poll_next_unpin(cx), Poll::Pending);

            // Now ready the replacement future.
            set.get_mut(&Key(0)).unwrap().make_ready();

            // Only the replacement's value may come out.
            let first = set.poll_next_unpin(cx);
            assert_eq!(first, Poll::Ready(Some((Key(0), Value(2)))));
            assert_eq!(set.poll_next_unpin(cx), Poll::Ready(None));
        });
    }

    #[test]
    fn test_async() {
        MockRuntime::test_with_various(|rt| async move {
            let mut set = KeyedFuturesUnordered::new();

            for i in 0..10 {
                let (tx, rx) = oneshot::channel();
                set.try_insert(Key(i), rx).unwrap();
                rt.spawn_identified(format!("sender-{i}"), async move {
                    tx.send(Value(i)).unwrap();
                });
            }

            let mut received = set
                .map(|(k, v)| (k, v.unwrap()))
                .collect::<Vec<_>>()
                .await;
            received.sort();

            let expected: Vec<_> = (0..10).map(|i| (Key(i), Value(i))).collect();
            assert_eq!(received, expected);
        });
    }
}