1
//! Extension trait for using [`Sink`] more safely.
2

            
3
use std::future::Future;
4
use std::marker::PhantomData;
5
use std::pin::Pin;
6
use std::task::{Context, Poll};
7

            
8
use futures::Sink;
9
use futures::future::FusedFuture;
10
use futures::ready;
11
use pin_project::pin_project;
12

            
13
/// Debug-trace hook used throughout this module.
///
/// As configured it is a no-op: the format literal and any trailing tokens are
/// discarded and the invocation evaluates to `()`.
/// To get debugging output on stderr, replace it with the commented-out
/// definition directly below.
macro_rules! dprintln {
    ($f:literal $($a:tt)*) => {
        ()
    };
}
// macro_rules! dprintln { ($f:literal $($a:tt)*) => { eprintln!(concat!("    ", $f) $($a)*) }; }
16

            
17
/// Extension trait for [`Sink`] to add a method for cancel-safe usage.
18
pub trait SinkPrepareExt<'w, OS, OM>
19
where
20
    OS: Sink<OM>,
21
{
22
    /// For processing an item obtained from a future, avoiding async cancel lossage
23
    ///
24
    /// ```
25
    /// # use futures::channel::mpsc;
26
    /// # use tor_async_utils::SinkPrepareExt as _;
27
    /// #
28
    /// # #[tokio::main]
29
    /// # async fn main() -> Result<(),mpsc::SendError> {
30
    /// #   let (mut sink, sink_r) = mpsc::unbounded::<usize>();
31
    /// #   let message_generator_future = futures::future::ready(42);
32
    /// #   let process_message = |m| Ok::<_,mpsc::SendError>(m);
33
    ///     let (message, sendable) = sink.prepare_send_from(
34
    ///         message_generator_future
35
    ///     ).await?;
36
    ///     let message = process_message(message)?;
37
    ///     sendable.send(message);
38
    /// #   Ok(())
39
    /// # }
40
    /// ```
41
    ///
42
    /// Prepares to send a output message[^terminology] `OM` to an output sink `OS` (`self`),
43
    /// where the `OM` is made from an input message `IM`,
44
    /// and the `IM` is obtained from a future, `generator: IF`.
45
    ///
46
    /// [^terminology]: We sometimes use slightly inconsistent terminology,
47
    /// "item" vs "message".
48
    /// This avoids having to have the generic parameters by named `OI` and `II`
49
    /// where `I` is sometimes "item" and sometimes "input".
50
    ///
51
    /// When successfully run, `prepare_send_from` gives `(IM, SinkSendable)`.
52
    ///
53
    /// After processing `IM` into `OM`,
54
    /// use the [`SinkSendable`] to [`send`](SinkSendable::send) the `OM` to `OS`.
55
    ///
56
    /// # Why use this
57
    ///
58
    /// This avoids the an async cancellation hazard
59
    /// which exists with naive use of `select!`
60
    /// followed by `OS.send().await`.  You might write this:
61
    ///
62
    /// ```rust,ignore
63
    /// select!{
64
    ///     message = input_stream.next() => {
65
    ///         if let Some(message) = message {
66
    ///             let message = do_our_processing(message);
67
    ///             output_sink(message).await; // <---**BUG**
68
    ///         }
69
    ///     }
70
    ///     control = something_else() => { .. }
71
    /// }
72
    /// ```
73
    ///
74
    /// If, when we reach `BUG`, the output sink is not ready to receive the message,
75
    /// the future for that particular `select!` branch will be suspended.
76
    /// But when `select!` finds that *any one* of the branches is ready,
77
    /// it *drops* the futures for the other branches.
78
    /// That drops all the local variables, including possibly `message`, losing it.
79
    ///
80
    /// For more about cancellation safety, see
81
    /// [Rust for the Polyglot Programmer](https://www.chiark.greenend.org.uk/~ianmdlvl/rust-polyglot/async.html#cancellation-safety)
82
    /// which has a general summary, and
83
    /// Matthias Einwag's
84
    /// [extensive discussion in his gist](https://gist.github.com/Matthias247/ffc0f189742abf6aa41a226fe07398a8#cancellation-in-async-rust)
85
    /// with comparisons to other languages.
86
    ///
87
    /// ## Alternatives
88
    ///
89
    /// Unbounded mpsc channels, and certain other primitives,
90
    /// do not suffer from this problem because they do not block.
91
    /// `UnboundedSender` offers
92
    /// [`unbounded_send`](futures::channel::mpsc::UnboundedSender::unbounded_send)
93
    /// but only as an inherent method, so this does not compose with `Sink` combinators.
94
    /// And of course unbounded channels do not implement any backpressure.
95
    ///
96
    /// The problem can otherwise be avoided by completely eschewing use of `select!`
97
    /// and writing manual implementations of `Future`, `Sink`, and so on,
98
    /// However, such code is typically considerably more complex and involves
99
    /// entangling the primary logic with future machinery.
100
    /// It is normally better to write primary functionality in `async { }`
101
    /// using utilities (often "futures combinators") such as this one.
102
    ///
103
    // Personal note from @Diziet:
104
    // IMO it is generally accepted in the Rust community that
105
    // it is not good practice to write principal code at the manual futues level.
106
    // However, I have not been able to find very clear support for this proposition.
107
    // There are endless articles explaining how futures work internally,
108
    // often by describing how to reimplement standard combinators such as `map`.
109
    // ISTM that these exist to help understanding,
110
    // but it seems to be only rarely stated that doing this is not generally a good idea.
111
    //
112
    // I did find the following:
113
    //
114
    //  https://dev.to/mindflavor/rust-futures-an-uneducated-short-and-hopefully-not-boring-tutorial---part-4---a-real-future-from-scratch-734#conclusion
115
    //
116
    //    Of course you generally do not write a future manually. You use the ones provided by
117
    //    libraries and compose them as needed. It's important to understand how they work
118
    //    nevertheless.
119
    //
120
    // And of curse the existence of the `futures` crate is indicative:
121
    // it consists almost entirely of combinators and utilities
122
    // whose purpose is to allow you to write many structures in async code
123
    // without needing to resort to manual future impls.
124
    //
125
    /// # Example
126
    ///
127
    /// This comprehensive example demonstrates how to read from possibly multiple sources
128
    /// and also be able to process other events:
129
    ///
130
    /// ```
131
    /// # #[tokio::main]
132
    /// # async fn main() {
133
    /// use futures::select;
134
    /// use futures::{SinkExt as _, StreamExt as _};
135
    /// use tor_async_utils::SinkPrepareExt as _;
136
    ///
137
    /// let (mut input_w, mut input_r) = futures::channel::mpsc::unbounded::<usize>();
138
    /// let (mut output_w, mut output_r) = futures::channel::mpsc::unbounded::<String>();
139
    /// input_w.send(42).await;
140
    /// select!{
141
    ///     ret = output_w.prepare_send_from(async {
142
    ///         select!{
143
    ///             got_input = input_r.next() => got_input.expect("input stream ended!"),
144
    ///             () = futures::future::pending() => panic!(), // other branches are OK here
145
    ///         }
146
    ///     }) => {
147
    ///         let (input_msg, sendable) = ret.unwrap();
148
    ///         let output_msg = input_msg.to_string();
149
    ///         let () = sendable.send(output_msg).unwrap();
150
    ///     },
151
    ///     () = futures::future::pending() => panic!(), // other branches are OK here
152
    /// }
153
    ///
154
    /// assert_eq!(output_r.next().await.unwrap(), "42");
155
    /// # }
156
    /// ```
157
    ///
158
    /// # Formally
159
    ///
160
    /// [`prepare_send_from`](SinkPrepareExt::prepare_send_from)
161
    /// returns a [`SinkPrepareSendFuture`] which, when awaited:
162
    ///
163
    ///  * Waits for `OS` to be ready to receive an item.
164
    ///  * Runs `message_generator` to obtain a `IM`.
165
    ///  * Returns the `IM` (for processing), and a [`SinkSendable`].
166
    ///
167
    /// The caller should then:
168
    ///
169
    ///  * Check the error from `prepare_send_from`
170
    ///    (which came from the *output* sink).
171
    ///  * Process the `IM`, making an `OM` out of it.
172
    ///  * Call [`sendable.send()`](SinkSendable::send) (and check its error).
173
    ///
174
    /// # Flushing
175
    ///
176
    /// `prepare_send_from` will (when awaited)
177
    /// [`flush`](futures::SinkExt::flush) the output sink
178
    /// when it finds the input is not ready yet.
179
    /// Until then items may be buffered
180
    /// (as if they had been written with [`feed`](futures::SinkExt::feed)).
181
    ///
182
    /// # Errors
183
    ///
184
    /// ## Output sink errors
185
    ///
186
    /// The call site can experience output sink errors in two places,
187
    /// [`prepare_send_from()`](SinkPrepareExt::prepare_send_from) and [`SinkSendable::send()`].
188
    /// The caller should typically handle them the same way regardless of when they occurred.
189
    ///
190
    /// If the error happens at [`SinkSendable::send()`],
191
    /// the call site will usually be forced to discard the item being processed.
192
    /// This will only occur if the sink is actually broken.
193
    ///
194
    /// ## Errors specific to the call site: faillible input, and fallible processing
195
    ///
196
    /// At some call sites, the input future may yield errors
197
    /// (perhaps it is reading from a `Stream` of [`Result`]s).
198
    /// in that case the value from the input future will be a [`Result`].
199
    /// Then `IM` is a `Result`, and is provided in the `.0` element
200
    /// of the "successful" return from `prepare_send_from`.
201
    ///
202
    /// And, at some call sites, the processing of an `IM` into an `OM` is fallible.
203
    ///
204
    /// Handling these latter two error caess is up to the caller,
205
    /// in the code which processes `IM`.
206
    /// The call site will often want to deal with such an error
207
    /// without sending anything into the output sink,
208
    /// and can then just drop the [`SinkSendable`].
209
    ///
210
    /// # Implementations
211
    ///
212
    /// This is an extension trait and you are not expected to need to implement it.
213
    ///
214
    /// There are provided implementations for `Pin<&mut impl Sink>`
215
    /// and `&mut impl Sink + Unpin`, for your convenience.
216
    fn prepare_send_from<IF, IM>(
217
        self,
218
        message_generator: IF,
219
    ) -> SinkPrepareSendFuture<'w, IF, OS, OM>
220
    where
221
        IF: Future<Output = IM>;
222
}
223

            
224
impl<'w, OS, OM> SinkPrepareExt<'w, OS, OM> for Pin<&'w mut OS>
225
where
226
    OS: Sink<OM>,
227
{
228
5290
    fn prepare_send_from<'r, IF, IM>(
229
5290
        self,
230
5290
        message_generator: IF,
231
5290
    ) -> SinkPrepareSendFuture<'w, IF, OS, OM>
232
5290
    where
233
5290
        IF: Future<Output = IM>,
234
    {
235
5290
        SinkPrepareSendFuture {
236
5290
            output: Some(self),
237
5290
            generator: message_generator,
238
5290
            tw: PhantomData,
239
5290
        }
240
5290
    }
241
}
242

            
243
impl<'w, OS, OM> SinkPrepareExt<'w, OS, OM> for &'w mut OS
244
where
245
    OS: Sink<OM> + Unpin,
246
{
247
5290
    fn prepare_send_from<'r, IF, IM>(
248
5290
        self,
249
5290
        message_generator: IF,
250
5290
    ) -> SinkPrepareSendFuture<'w, IF, OS, OM>
251
5290
    where
252
5290
        IF: Future<Output = IM>,
253
    {
254
5290
        Pin::new(self).prepare_send_from(message_generator)
255
5290
    }
256
}
257

            
258
/// Future for `SinkPrepareExt::prepare_send_from`
259
#[pin_project]
260
#[must_use]
261
pub struct SinkPrepareSendFuture<'w, IF, OS, OM> {
262
    /// Underlying future that will yield a message.
263
    #[pin]
264
    generator: IF,
265

            
266
    /// This Option exists because otherwise SinkPrepareSendFuture::poll()
267
    /// can't move `output` out of this struct to put it into the `SinkSendable`.
268
    /// (The poll() impl cannot borrow from SinkPrepareSendFuture.)
269
    output: Option<Pin<&'w mut OS>>,
270

            
271
    /// `fn(OM)` gives contravariance in OM.
272
    ///
273
    /// Variance is confusing.
274
    /// Loosely, a SinkPrepareSendFuture<..OM> consumes an OM.
275
    /// Actually, we don't really need to add any variance restricions wrt OM,
276
    /// because the &mut OS already implies the correct variance,
277
    /// so we could have used the PhantomData<fn(*const OM)> trick.
278
    /// Happily there is no unsafe anywhere nearby, so it is not possible for us to write
279
    /// a bug due to getting the variance wrong - only to erroneously prevent some use
280
    /// case.
281
    tw: PhantomData<fn(OM)>,
282
}
283

            
284
/// A [`Sink`] which is ready to receive an item
///
/// Produced by [`SinkPrepareExt::prepare_send_from`].  See there for the overview docs.
///
/// This references an output sink `OS`.
/// It offers the ability to write into the sink without blocking
/// (and constitutes a proof token that the sink has declared itself ready for that).
///
/// The only useful method is [`send`](SinkSendable::send).
///
/// `SinkSendable` has no drop glue and can be freely dropped,
/// for example if you prepare to send a message and then
/// encounter an error when producing the output message.
#[must_use]
pub struct SinkSendable<'w, OS, OM> {
    /// Reference to underlying output sink.
    output: Pin<&'w mut OS>,
    /// Marker to ensure that `OM` is used.
    tw: PhantomData<fn(OM)>,
}
304

            
305
impl<'w, IF, OS, IM, OM> Future for SinkPrepareSendFuture<'w, IF, OS, OM>
306
where
307
    IF: Future<Output = IM>,
308
    OS: Sink<OM>,
309
{
310
    type Output = Result<(IM, SinkSendable<'w, OS, OM>), OS::Error>;
311

            
312
5752
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
313
5752
        let mut self_ = self.project();
314

            
315
        /// returns `&mut Pin<&'w mut OS>` from self_.output
316
        //
317
        // macro because the closure's type parameters would be unnameable.
318
        macro_rules! get_output {
319
            ($self_:expr) => {
320
                $self_.output.as_mut().expect(BAD_POLL_MSG).as_mut()
321
            };
322
        }
323
        /// Message to give when panicking because of improper extra poll.
324
        const BAD_POLL_MSG: &str = "future from SinkPrepareExt::prepare_send_from (SinkPrepareSendFuture) \
325
                 polled after returning Ready(Ok)";
326

            
327
5752
        let () = match ready!(get_output!(self_).poll_ready(cx)) {
328
100
            Err(e) => {
329
100
                dprintln!("poll: output poll = IF.Err    SO  IF.Err");
330
                // Deliberately don't fuse by `take`ing output.  If we did that, we would expose
331
                // our caller to an additional panic risk.  There is no harm in polling the output
332
                // sink again: although `Sink` documents that a sink that returns errors will
333
                // probably continue to do so, it is not forbidden to try it and see.  This is in
334
                // any case better than definitely crashing if the `SinkPrepareSendFuture` is
335
                // polled after it gave Ready.
336
100
                return Poll::Ready(Err(e));
337
            }
338
5556
            Ok(()) => {
339
5556
                dprintln!("poll: output poll = IF.Ok     calling generator");
340
5556
            }
341
        };
342

            
343
5556
        let value = match self_.generator.as_mut().poll(cx) {
344
            Poll::Pending => {
345
                // We defer flushing the output until the input stops yielding.
346
                // This allows our caller (which is typically a loop) to transfer multiple
347
                // items from their input to their output between flushes.
348
                //
349
                // But we must not return `Pending` without flushing, or the caller could block
350
                // without flushing output, leading to untimely delivery of buffered data.
351
1116
                dprintln!("poll: generator = Pending     calling output flush");
352
1116
                let flushed = get_output!(self_).poll_flush(cx);
353
1114
                return match flushed {
354
2
                    Poll::Ready(Err(e)) => {
355
2
                        dprintln!("poll: output flush = IF.Err   SO  IF.Err");
356
2
                        Poll::Ready(Err(e))
357
                    }
358
                    Poll::Ready(Ok(())) => {
359
1112
                        dprintln!("poll: output flush = IF.Ok    SO  Pending");
360
1112
                        Poll::Pending
361
                    }
362
                    Poll::Pending => {
363
2
                        dprintln!("poll: output flush = Pending  SO  Pending");
364
2
                        Poll::Pending
365
                    }
366
                };
367
            }
368
4440
            Poll::Ready(v) => {
369
4440
                dprintln!("poll: generator = Ready       SO  IF.Ok");
370
4440
                v
371
            }
372
        };
373

            
374
4440
        let sendable = SinkSendable {
375
4440
            output: self_.output.take().expect(BAD_POLL_MSG),
376
4440
            tw: PhantomData,
377
4440
        };
378

            
379
4440
        Poll::Ready(Ok((value, sendable)))
380
5752
    }
381
}
382

            
383
impl<'w, IF, OS, IM, OM> FusedFuture for SinkPrepareSendFuture<'w, IF, OS, OM>
384
where
385
    IF: Future<Output = IM>,
386
    OS: Sink<OM>,
387
{
388
5744
    fn is_terminated(&self) -> bool {
389
5744
        let r = self.output.is_none();
390
5744
        dprintln!("is_terminated = {}", r);
391
5744
        r
392
5744
    }
393
}
394

            
395
impl<'w, OS, OM> SinkSendable<'w, OS, OM>
396
where
397
    OS: Sink<OM>,
398
{
399
    /// Synchronously send an item into `OS`, which is a [`Sink`]
400
    ///
401
    /// Can fail if the sink `OS` reports an error.
402
    ///
403
    /// (However, the existence of the `SinkSendable` demonstrates that
404
    /// the sink reported itself ready for sending,
405
    /// so this call is synchronous, avoiding cancellation hazards.)
406
4376
    pub fn send(self, item: OM) -> Result<(), OS::Error> {
407
4376
        dprintln!("send ...");
408
4376
        let r = self.output.start_send(item);
409
4376
        dprintln!("send: {:?}", r.as_ref().map_err(|_| (())));
410
4376
        r
411
4376
    }
412
}
413

            
414
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use futures::SinkExt as _;
    use futures::channel::mpsc;
    use futures::future::poll_fn;
    use futures::select_biased;
    use futures_await_test::async_test;
    use std::convert::Infallible;
    use std::sync::Arc;
    use std::sync::Mutex;

    #[async_test]
    async fn prepare_send() {
        // Early versions of this used unfold quite a lot more, but it is not really
        // convenient for testing.  It buffers one item internally, and is also buggy:
        //   https://github.com/rust-lang/futures-rs/issues/2600
        // So we use mpsc channels, which (perhaps with buffering) are quite controllable.

        // The eprintln!("FOR ...") calls correspond to the dprintln!() calls in the impl.
        // They can be used to check that each code path in the implementation is exercised,
        // by switching `dprintln!` to its debugging version and running with `--nocapture`.
        {
            eprintln!("-- disconnected ---");
            eprintln!("FOR poll: output poll = IF.Err    SO  IF.Err");
            let (mut w, r) = mpsc::unbounded::<usize>();
            drop(r);
            let ret = w.prepare_send_from(async { Ok::<_, Infallible>(12) }).await;
            assert!(ret.map(|_| ()).unwrap_err().is_disconnected());
        }

        {
            eprintln!("-- buffered late disconnect --");
            eprintln!("FOR poll: output poll = IF.Ok     calling generator");
            eprintln!("FOR poll: output flush = IF.Err   SO  IF.Err");
            let (w, r) = mpsc::unbounded::<usize>();
            let mut w = w.buffer(10);
            let mut r = Some(r);
            w.feed(66).await.unwrap();
            let ret = w
                .prepare_send_from(poll_fn(move |_cx| {
                    drop(r.take());
                    Poll::Pending::<usize>
                }))
                .await;
            assert!(ret.map(|_| ()).unwrap_err().is_disconnected());
        }

        {
            eprintln!("-- flushing before wait --");
            eprintln!("FOR poll: output flush = IF.Ok    SO  Pending");
            let (mut w, _r) = mpsc::unbounded::<usize>();
            let () = select_biased! {
                _ = w.prepare_send_from(poll_fn(
                    move |_cx| {
                        Poll::Pending::<usize>
                    }
                )) => panic!(),
                _ = futures::future::ready(()) => { },
            };
        }

        {
            eprintln!("-- flush before wait is pending --");
            eprintln!("FOR poll: output flush = Pending  SO  Pending");
            let (mut w, _r) = mpsc::channel::<usize>(0);
            let () = w.feed(77).await.unwrap();
            let mut w = w.buffer(10);
            let () = select_biased! {
                _ = w.prepare_send_from(poll_fn(
                    move |_cx| {
                        Poll::Pending::<usize>
                    }
                )) => panic!(),
                _ = futures::future::ready(()) => { },
            };
        }

        {
            eprintln!("-- buffered sends, then flush before wait --");
            eprintln!("FOR poll: generator = Ready       SO  IF.Ok");
            eprintln!("FOR send ...");
            eprintln!("ALSO check that buffering works as expected");

            let sunk = Arc::new(Mutex::new(vec![]));
            let unfold = futures::sink::unfold((), |(), v| {
                let sunk = sunk.clone();
                async move {
                    dbg!();
                    sunk.lock().unwrap().push(v);
                    Ok::<_, Infallible>(())
                }
            });
            let mut unfold = Box::pin(unfold.buffer(10));
            for v in [42, 43] {
                // We can only do two here because that's how many we can actually buffer in Buffer
                // and Unfold.  Because our closure is always ready, the buffering isn't actually
                // as copious as all that.  This is fine, because the point of this test is to test
                // *flushing*.
                dbg!(v);
                let ret = unfold
                    .prepare_send_from(async move { Ok::<_, Infallible>(v) })
                    .await;
                let (msg, sendable) = ret.unwrap();
                let msg = msg.unwrap();
                assert_eq!(msg, v);
                let () = sendable.send(msg).unwrap();
                let expect: &[u8] = &[];
                assert_eq!(*sunk.lock().unwrap(), expect); // It's still buffered
            }
            select_biased! {
                _ = unfold.prepare_send_from(futures::future::pending::<()>()) => panic!(),
                _ = futures::future::ready(()) => { },
            };
            assert_eq!(*sunk.lock().unwrap(), &[42, 43]);
        }
    }
}