1
//! [`StreamUnobtrusivePeeker`]
2
//!
3
//! The memory tracker needs a way to look at the next item of a stream
4
//! (if there is one, or there can immediately be one),
5
//! *without* getting involved with the async tasks.
6

            
7
use educe::Educe;
8
use futures::Stream;
9
use futures::stream::FusedStream;
10
use pin_project::pin_project;
11

            
12
use crate::peekable_stream::{PeekableStream, UnobtrusivePeekableStream};
13

            
14
use std::fmt::Debug;
15
use std::future::Future;
16
use std::pin::Pin;
17
use std::task::{Context, Poll, Poll::*, Waker};
18

            
19
/// Wraps [`Stream`] and provides `\[poll_]peek` and `unobtrusive_peek`
///
/// [`unobtrusive_peek`](StreamUnobtrusivePeeker::unobtrusive_peek)
/// is callable in sync contexts, outside the reading task.
///
/// Like [`futures::stream::Peekable`],
/// this has an async `peek` method, and `poll_peek`,
/// for use from the task that is also reading (via the [`Stream`] impl).
/// But, that type doesn't have `unobtrusive_peek`.
///
/// One way to conceptualise this is that `StreamUnobtrusivePeeker` is dual-ported:
/// the two sets of APIs, while provided on the same type,
/// are typically called from different contexts.
//
// It wasn't particularly easy to think of a good name for this type.
// We intend, probably:
//     struct StreamUnobtrusivePeeker
//     trait StreamUnobtrusivePeekable
//     trait StreamPeekable (impl for StreamUnobtrusivePeeker and futures::stream::Peekable)
//
// Searching a thesaurus produced these suggested words:
//     unobtrusive subtle discreet inconspicuous cautious furtive
// Asking in MR review also suggested
//     quick
//
// It's awkward because "peek" already has significant connotations of not disturbing things.
// That's why it was used in Iterator::peek.
//
// But when we translate this into async context,
// we have the poll_peek method on futures::stream::Peekable,
// which doesn't remove items from the stream,
// but *does* *wait* for items and therefore engages with the async context,
// and therefore involves *mutating* the Peekable (to store the new waker).
//
// Now we end up needing a word for an *even less disturbing* kind of interaction.
//
// `quick` (and synonyms) isn't quite right either because it's not necessarily faster,
// and certainly not more performant.
#[derive(Debug)]
#[pin_project(project = PeekerProj)]
pub struct StreamUnobtrusivePeeker<S: Stream> {
    /// An item that we have peeked, but not yet yielded via `poll_next`.
    ///
    /// (If we peeked EOF, that's represented by `inner` being `None`,
    /// not by anything stored here.)
    buffered: Option<S::Item>,

    /// The `Waker` from the last time we were polled and returned `Pending`
    ///
    /// "polled" includes any of our `poll_` methods
    /// but *not* `unobtrusive_peek`.
    ///
    /// `None` if we haven't been polled, or the last poll returned `Ready`.
    poll_waker: Option<Waker>,

    /// The inner stream
    ///
    /// `None` if it has yielded `None`, meaning EOF.  We don't require `S: FusedStream`,
    /// so once EOF is seen we drop the stream rather than ever polling it again.
    #[pin]
    inner: Option<S>,
}
79

            
80
impl<S: Stream> StreamUnobtrusivePeeker<S> {
    /// Create a new `StreamUnobtrusivePeeker` wrapping the `Stream` `inner`
    ///
    /// The new peeker has nothing buffered and no stored waker.
    pub fn new(inner: S) -> Self {
        Self {
            inner: Some(inner),
            poll_waker: None,
            buffered: None,
        }
    }
}
90

            
91
impl<S: Stream> UnobtrusivePeekableStream for StreamUnobtrusivePeeker<S> {
92
4310
    fn unobtrusive_peek_mut<'s>(mut self: Pin<&'s mut Self>) -> Option<&'s mut S::Item> {
93
        #[allow(clippy::question_mark)] // We use explicit control flow here for clarity
94
4310
        if self.as_mut().project().buffered.is_none() {
95
            // We don't have a buffered item, but the stream may have an item available.
96
            // We must poll it to find out.
97
            //
98
            // We need to pass a Context to poll_next.
99
            // inner may store this context, replacing one provided via poll_*.
100
            //
101
            // Despite that, we need to make sure that wakeups will happen as expected.
102
            // To achieve this we have retained a copy of the caller's Waker.
103
            //
104
            // When a future or stream returns Pending, it proposes to wake `waker`
105
            // when it wants to be polled again.
106
            //
107
            // We uphold that promise by
108
            // - only returning Pending from our poll methods if inner also returned Pending
109
            // - when one of our poll methods returns Pending, saving the caller-supplied
110
            //   waker, so that we can make the intermediate poll call here.
111
            //
112
            // If the inner poll returns Ready, inner no longer guarantees to wake anyone.
113
            // In principle, if our user is waiting (we returned Pending),
114
            // then inner ought to have called `wake` on the caller's `Waker`.
115
            // But I don't think we can guarantee that an executor won't defer a wakeup,
116
            // and respond to a dropped Waker by cancelling that wakeup;
117
            // or to put it another way, the wakeup might be "in flight" on entry,
118
            // but the call to inner's poll_next returning Ready
119
            // might somehow "cancel" the wakeup.
120
            //
121
            // So just to be sure, if we get a Ready here, we wake the stored waker.
122

            
123
130
            let mut self_ = self.as_mut().project();
124

            
125
130
            let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
126
58
                return None;
127
            };
128

            
129
72
            let waker = if let Some(waker) = self_.poll_waker.as_ref() {
130
                waker
131
            } else {
132
72
                Waker::noop()
133
            };
134

            
135
72
            match inner.poll_next(&mut Context::from_waker(waker)) {
136
24
                Pending => {}
137
48
                Ready(item_or_eof) => {
138
48
                    if let Some(waker) = self_.poll_waker.take() {
139
                        waker.wake();
140
48
                    }
141
48
                    match item_or_eof {
142
                        None => self_.inner.set(None),
143
48
                        Some(item) => *self_.buffered = Some(item),
144
                    }
145
                }
146
            };
147
4180
        }
148

            
149
4252
        self.project().buffered.as_mut()
150
4310
    }
151
}
152

            
153
impl<S: Stream> PeekableStream for StreamUnobtrusivePeeker<S> {
154
44
    fn poll_peek<'s>(self: Pin<&'s mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
155
44
        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_ref())
156
44
    }
157

            
158
12778
    fn poll_peek_mut<'s>(
159
12778
        self: Pin<&'s mut Self>,
160
12778
        cx: &mut Context<'_>,
161
12778
    ) -> Poll<Option<&'s mut S::Item>> {
162
12778
        self.impl_poll_next_or_peek(cx, |buffered| buffered.as_mut())
163
12778
    }
164
}
165

            
166
impl<S: Stream> StreamUnobtrusivePeeker<S> {
167
    /// Implementation of `poll_{peek,next}`
168
    ///
169
    /// This takes care of
170
    ///   * examining the state of our buffer, and polling inner if needed
171
    ///   * ensuring that we store a waker, if needed
172
    ///   * dealing with some borrowck awkwardness
173
    ///
174
    /// The `Ready` value is always calculated from `buffer`.
175
    /// `return_value_obtainer` is called only if we are going to return `Ready`.
176
    /// It's given `buffer` and should either:
177
    ///   * [`take`](Option::take) the contained value (for `poll_next`)
178
    ///   * return a reference using [`Option::as_ref`] (for `poll_peek`)
179
40902
    fn impl_poll_next_or_peek<'s, R: 's>(
180
40902
        self: Pin<&'s mut Self>,
181
40902
        cx: &mut Context<'_>,
182
40902
        return_value_obtainer: impl FnOnce(&'s mut Option<S::Item>) -> Option<R>,
183
40902
    ) -> Poll<Option<R>> {
184
40902
        let mut self_ = self.project();
185
40902
        let r = Self::next_or_peek_inner(&mut self_, cx);
186
40902
        let r = r.map(|()| return_value_obtainer(self_.buffered));
187
40902
        Self::return_from_poll(self_.poll_waker, cx, r)
188
40902
    }
189

            
190
    /// Try to populate `buffer`, and calculate if we're `Ready`
191
    ///
192
    /// Returns `Ready` iff `poll_next` or `poll_peek` should return `Ready`.
193
    /// The actual `Ready` value (an `Option`) will be calculated later.
194
40902
    fn next_or_peek_inner(self_: &mut PeekerProj<S>, cx: &mut Context<'_>) -> Poll<()> {
195
40902
        if let Some(_item) = self_.buffered.as_ref() {
196
            // `return_value_obtainer` will find `Some` in `buffered`;
197
            // overall, we'll return `Ready(Some(..))`.
198
12612
            return Ready(());
199
28290
        }
200
28290
        let Some(inner) = self_.inner.as_mut().as_pin_mut() else {
201
            // `return_value_obtainer` will find `None` in `buffered`;
202
            // overall, we'll return `Ready(None)`, ie EOF.
203
198
            return Ready(());
204
        };
205
28092
        match inner.poll_next(cx) {
206
            Ready(None) => {
207
3432
                self_.inner.set(None);
208
                // `buffered` is `None`, still.
209
                // overall, we'll return `Ready(None)`, ie EOF.
210
3432
                Ready(())
211
            }
212
14218
            Ready(Some(item)) => {
213
14218
                *self_.buffered = Some(item);
214
                // return_value_obtainer` will find `Some` in `buffered`
215
14218
                Ready(())
216
            }
217
            Pending => {
218
                // `return_value_obtainer` won't be called.
219
                // overall, we'll return Pending
220
10442
                Pending
221
            }
222
        }
223
40902
    }
224

            
225
    /// Wait for an item to be ready, and then inspect it
226
    ///
227
    /// Equivalent to [`futures::stream::Peekable::peek`].
228
    ///
229
    /// # Tasks, waking, and calling context
230
    ///
231
    /// This should be called by the task that is reading from the stream.
232
    /// If it is called by another task, the reading task would miss notifications.
233
    //
234
    // This ^ docs section is triplicated for poll_peek, poll_peek_mut, and peek
235
    //
236
    // TODO this should be a method on the `PeekableStream` trait? Or a
237
    // `PeekableStreamExt` trait?
238
    // TODO should there be peek_mut ?
239
    #[allow(dead_code)] // TODO remove this allow if and when we make this module public
240
16
    pub fn peek(self: Pin<&mut Self>) -> PeekFuture<Self> {
241
16
        PeekFuture { peeker: Some(self) }
242
16
    }
243

            
244
    /// Return from a `poll_*` function, setting the stored waker appropriately
245
    ///
246
    /// Our `poll` functions always use this.
247
    /// The rule is that if a future returns `Pending`, it has stored the waker.
248
40902
    fn return_from_poll<R>(
249
40902
        poll_waker: &mut Option<Waker>,
250
40902
        cx: &mut Context<'_>,
251
40902
        r: Poll<R>,
252
40902
    ) -> Poll<R> {
253
40902
        *poll_waker = match &r {
254
            Ready(_) => {
255
                // No need to wake this task up any more.
256
30460
                None
257
            }
258
            Pending => {
259
                // try_peek must use the same waker to poll later
260
10442
                Some(cx.waker().clone())
261
            }
262
        };
263
40902
        r
264
40902
    }
265

            
266
    /// Obtain a raw reference to the inner stream
267
    ///
268
    /// ### Correctness!
269
    ///
270
    /// This method must be used with care!
271
    /// Whatever you do mustn't interfere with polling and peeking.
272
    /// Careless use can result in wrong behaviour including deadlocks.
273
3374
    pub fn as_raw_inner_pin_mut<'s>(self: Pin<&'s mut Self>) -> Option<Pin<&'s mut S>> {
274
3374
        self.project().inner.as_pin_mut()
275
3374
    }
276
}
277

            
278
impl<S: Stream> Stream for StreamUnobtrusivePeeker<S> {
279
    type Item = S::Item;
280

            
281
28080
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
282
28080
        self.impl_poll_next_or_peek(cx, |buffered| buffered.take())
283
28080
    }
284

            
285
    fn size_hint(&self) -> (usize, Option<usize>) {
286
        let buf = self.buffered.iter().count();
287
        let (imin, imax) = match &self.inner {
288
            Some(inner) => inner.size_hint(),
289
            None => (0, Some(0)),
290
        };
291
        (imin + buf, imax.and_then(|imap| imap.checked_add(buf)))
292
    }
293
}
294

            
295
impl<S: Stream> FusedStream for StreamUnobtrusivePeeker<S> {
    /// Terminated once `inner` has yielded EOF and no peeked item is left to hand out
    fn is_terminated(&self) -> bool {
        matches!((&self.buffered, &self.inner), (None, None))
    }
}
300

            
301
/// Future from [`StreamUnobtrusivePeeker::peek`]
302
// TODO: Move to tor_async_utils::peekable_stream.
303
#[derive(Educe)]
304
#[educe(Debug(bound("S: Debug")))]
305
#[must_use = "peek() return a Future, which does nothing unless awaited"]
306
pub struct PeekFuture<'s, S> {
307
    /// The underlying stream.
308
    ///
309
    /// `Some` until we have returned `Ready`, then `None`.
310
    /// See comment in `poll`.
311
    peeker: Option<Pin<&'s mut S>>,
312
}
313

            
314
impl<'s, S: PeekableStream> PeekFuture<'s, S> {
315
    /// Create a new `PeekFuture`.
316
    // TODO: replace with a trait method.
317
    pub fn new(stream: Pin<&'s mut S>) -> Self {
318
        Self {
319
            peeker: Some(stream),
320
        }
321
    }
322
}
323

            
324
impl<'s, S: PeekableStream> Future for PeekFuture<'s, S> {
325
    type Output = Option<&'s S::Item>;
326
28
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<&'s S::Item>> {
327
28
        let self_ = self.get_mut();
328
28
        let peeker = self_
329
28
            .peeker
330
28
            .as_mut()
331
28
            .expect("PeekFuture polled after Ready");
332
28
        match peeker.as_mut().poll_peek(cx) {
333
12
            Pending => return Pending,
334
16
            Ready(_y) => {
335
16
                // Ideally we would have returned `y` here, but it's borrowed from PeekFuture
336
16
                // not from the original StreamUnobtrusivePeeker, and there's no way
337
16
                // to get a value with the right lifetime.  (In non-async code,
338
16
                // this is usually handled by the special magic for reborrowing &mut.)
339
16
                //
340
16
                // So we must redo the poll, but this time consuming `peeker`,
341
16
                // which gets us the right lifetime.  That's why it has to be `Option`.
342
16
                // Because we own &mut ... Self, we know that repeating the poll
343
16
                // gives the same answer.
344
16
            }
345
        }
346
16
        let peeker = self_.peeker.take().expect("it was Some before!");
347
16
        let r = peeker.poll_peek(cx);
348
16
        assert!(r.is_ready(), "it was Ready before!");
349
16
        r
350
28
    }
351
}
352

            
353
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use futures::channel::mpsc;
    use futures::{SinkExt as _, StreamExt as _};
    use std::pin::pin;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tor_rtcompat::SleepProvider as _;
    use tor_rtmock::MockRuntime;

    /// Shorthand: a `Duration` of `ms` milliseconds.
    fn ms(ms: u64) -> Duration {
        Duration::from_millis(ms)
    }

    /// Interleaves `unobtrusive_peek_mut` and `next().await` on the receiving side,
    /// against a sender with an irregular schedule, and checks that items arrive
    /// in order and that the receiver still gets woken for EOF.
    #[test]
    fn wakeups() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set to `true` by the receiver task once it has observed EOF.
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();

                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);

                    // The value we expect to receive next.
                    let mut next = 0;
                    loop {
                        // Unobtrusive peek: may or may not find an item yet,
                        // but if it does, it must be the next one in sequence.
                        rt.sleep(ms(50)).await;
                        eprintln!("rx peek... ");
                        let peeked = rx.as_mut().unobtrusive_peek_mut();
                        eprintln!("rx peeked {peeked:?}");

                        if let Some(peeked) = peeked {
                            assert_eq!(*peeked, next);
                        }

                        // Ordinary async read; `None` means the sender was dropped.
                        rt.sleep(ms(50)).await;
                        eprintln!("rx next... ");
                        let eaten = rx.next().await;
                        eprintln!("rx eaten {eaten:?}");
                        if let Some(eaten) = eaten {
                            assert_eq!(eaten, next);
                            next += 1;
                        } else {
                            break;
                        }
                    }

                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    // Send 0, 1, 2, ... each after the corresponding delay (in ms).
                    let mut numbers = 0..;
                    for wait in [125, 1, 125, 45, 1, 1, 1, 1000, 20, 1, 125, 125, 1000] {
                        eprintln!("tx sleep {wait}");
                        rt.sleep(ms(wait)).await;
                        let num = numbers.next().unwrap();
                        eprintln!("tx sending {num}");
                        tx.send(num).await.unwrap();
                    }

                    // This schedule arranges that, when we send EOF, the rx task
                    // has *peeked* rather than *polled* most recently,
                    // demonstrating that we can wake up the subsequent poll on EOF too.
                    eprintln!("tx final #1");
                    rt.sleep(ms(75)).await;
                    eprintln!("tx EOF");
                    drop(tx);
                    eprintln!("tx final #2");
                    rt.sleep(ms(10)).await;
                    // The receiver is still sleeping, so it can't have seen EOF yet...
                    assert!(!*ended.lock().unwrap());
                    eprintln!("tx final #3");
                    rt.sleep(ms(50)).await;
                    eprintln!("tx final #4");
                    // ...but by now it must have been woken and observed EOF.
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }

    /// Exercises `peek().await` followed by `next().await`, checking that the
    /// peeked value equals the value then consumed, through to EOF.
    #[test]
    fn poll_peek_paths() {
        MockRuntime::test_with_various(|rt| async move {
            let (mut tx, rx) = mpsc::unbounded();
            // Set to `true` by the receiver task once `peek` reports EOF.
            let ended = Arc::new(Mutex::new(false));

            rt.spawn_identified("rxr", {
                let rt = rt.clone();
                let ended = ended.clone();

                async move {
                    let rx = StreamUnobtrusivePeeker::new(rx);
                    let mut rx = pin!(rx);

                    // Loop ends when `peek` resolves to `None`, ie EOF.
                    while let Some(peeked) = rx.as_mut().peek().await.copied() {
                        eprintln!("rx peeked {peeked}");
                        let eaten = rx.next().await.unwrap();
                        eprintln!("rx eaten  {eaten}");
                        assert_eq!(peeked, eaten);
                        rt.sleep(ms(10)).await;
                        eprintln!("rx slept, peeking");
                    }
                    *ended.lock().unwrap() = true;
                    eprintln!("rx ended");
                }
            });

            rt.spawn_identified("tx", {
                let rt = rt.clone();

                async move {
                    let mut numbers = 0..;

                    // macro because we don't have proper async closures
                    macro_rules! send { {} => {
                        let num = numbers.next().unwrap();
                        eprintln!("tx send   {num}");
                        tx.send(num).await.unwrap();
                    } }

                    // One item on its own, then two back-to-back, then EOF.
                    eprintln!("tx starting");
                    rt.sleep(ms(100)).await;
                    send!();
                    rt.sleep(ms(100)).await;
                    send!();
                    send!();
                    rt.sleep(ms(100)).await;
                    eprintln!("tx dropping");
                    drop(tx);
                    rt.sleep(ms(5)).await;
                    eprintln!("tx ending");
                    // The receiver must have been woken for EOF promptly.
                    assert!(*ended.lock().unwrap());
                }
            });

            rt.advance_until_stalled().await;
        });
    }
}