1
//! Cancellable futures.
2

            
3
use std::{
4
    pin::Pin,
5
    sync::{Arc, Mutex},
6
    task::{Context, Poll, Waker},
7
};
8

            
9
use futures::Future;
10
use pin_project::pin_project;
11

            
12
/// A cancellable future type, loosely influenced by `RemoteHandle`.
13
///
14
/// This type is useful for cases when we can't cancel a future simply by
15
/// dropping it, because the future is owned by some other object (like a
16
/// `FuturesUnordered`) that won't give it up.
17
///
18
/// # Limitations
19
///
20
/// Do not try to cancel a future from inside a cancellable future,
21
/// including the future itself:
22
/// this may cause a panic or deadlock.
23
///
24
/// In `arti-rpcserver`, we prevent this happening by ensuring that
25
/// every method that calls `cancel()` is itself uncancellable.
26
///
27
// TODO: We should probably fix this limitation somehow before exposing
28
// this code outside of this crate.  But see comments inside `Cancel::poll`
29
// for why we might not want to just drop the lock while polling.
30
//
31
// Also: We could use `tokio_util`'s cancellable futures instead here, but I don't
32
// think we want an unconditional tokio_util dependency.
33
#[pin_project]
34
pub(crate) struct Cancel<F> {
35
    /// Shared state between the `Cancel` and the `CancelHandle`.
36
    //
37
    // It would be nice not to have to stick this behind a mutex, but that would
38
    // make it a bit tricky to manage the Waker.
39
    inner: Arc<Mutex<Inner>>,
40
    /// The inner future.
41
    ///
42
    /// TODO: Possibly we should move this into `inner`,
43
    /// so that we can make sure that we don't execute the future without holding the lock,
44
    /// and so we can drop the future immediately when it's cancelled.
45
    /// But that would take some fairly tricky type erasure, so maybe it isn't worth it?
46
    #[pin]
47
    fut: F,
48
}
49

            
50
/// Lifecycle state of a `Cancel` future.
#[derive(Clone, Copy, Debug)]
enum Status {
    /// Still running: it has not completed, and nobody has cancelled it.
    Pending,
    /// Completed; cancellation is no longer possible.
    Finished,
    /// Cancelled; the inner future must not be polled any more.
    Cancelled,
}
60

            
61
/// Inner state shared between `Cancel` and the `CancelHandle.
62
struct Inner {
63
    /// Current status of the future.
64
    status: Status,
65
    /// A waker to use in telling this future that it's cancelled.
66
    waker: Option<Waker>,
67
}
68

            
69
/// An object that can be used to cancel a future.
70
#[derive(Clone)]
71
pub(crate) struct CancelHandle {
72
    /// The shared state for the cancellable future between `Cancel` and
73
    /// `CancelHandle`.
74
    inner: Arc<Mutex<Inner>>,
75
}
76

            
77
impl<F> Cancel<F> {
78
    /// Wrap `fut` in a new future that can be cancelled.
79
    ///
80
    /// Returns a handle to cancel the future, and the cancellable future.
81
4102
    pub(crate) fn new(fut: F) -> (CancelHandle, Cancel<F>) {
82
4102
        let inner = Arc::new(Mutex::new(Inner {
83
4102
            status: Status::Pending,
84
4102
            waker: None,
85
4102
        }));
86
4102
        let handle = CancelHandle {
87
4102
            inner: inner.clone(),
88
4102
        };
89
4102
        let future = Cancel { inner, fut };
90
4102
        (handle, future)
91
4102
    }
92
}
93

            
94
impl CancelHandle {
95
    /// Cancel the associated future, if it has not already finished.
96
    ///
97
    /// # Limitations
98
    ///
99
    /// This function may panic or deadlock if you call it from inside a `Cancel<F>`
100
    /// future.  See discussion in [`Cancel`] documentation.
101
4100
    pub(crate) fn cancel(&self) -> Result<(), CannotCancel> {
102
4100
        let mut inner = self.inner.lock().expect("poisoned lock");
103
4100
        match inner.status {
104
2016
            Status::Pending => inner.status = Status::Cancelled,
105
2084
            Status::Finished => return Err(CannotCancel::Finished),
106
            Status::Cancelled => return Err(CannotCancel::Cancelled),
107
        }
108
2016
        if let Some(waker) = inner.waker.take() {
109
4
            drop(inner); // release lock.
110
4
            waker.wake();
111
2012
        }
112
2016
        Ok(())
113
4100
    }
114
}
115

            
116
/// An error returned from a `Cancel` future if it is cancelled.
117
#[derive(thiserror::Error, Clone, Debug)]
118
#[error("Future was cancelled")]
119
pub(crate) struct Cancelled;
120

            
121
/// An error returned when we cannot cancel a future.
122
#[derive(thiserror::Error, Clone, Debug)]
123
pub(crate) enum CannotCancel {
124
    /// This future was already cancelled, and can't be cancelled again.
125
    #[error("Already cancelled")]
126
    Cancelled,
127

            
128
    /// This future has already completed, and can't be cancelled.
129
    #[error("Already finished")]
130
    Finished,
131
}
132

            
133
impl<F: Future> Future for Cancel<F> {
134
    type Output = Result<F::Output, Cancelled>;
135

            
136
4106
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137
4106
        let this = self.project();
138

            
139
4106
        let mut inner = this.inner.lock().expect("lock poisoned");
140
4106
        match inner.status {
141
2090
            Status::Pending => {}
142
            Status::Finished => {
143
                // Yes, we do intentionally allow a finished future to be polled again.
144
                // This does not violate our invariants.
145
                // If you want to prevent this, you need to use Fuse or a similar mechanism.
146
            }
147
2016
            Status::Cancelled => return Poll::Ready(Err(Cancelled)),
148
        }
149
        // Note that we're holding the mutex here while we poll the future.
150
        // This guarantees that the future can't make _any_ progress after it has been
151
        // cancelled.  If we someday decide we don't care about that, we could release the mutex
152
        // while polling, and pick it up again after we're done polling.
153
2090
        match this.fut.poll(cx) {
154
2086
            Poll::Ready(val) => {
155
2086
                inner.status = Status::Finished;
156
2086
                Poll::Ready(Ok(val))
157
            }
158
            Poll::Pending => {
159
4
                if let Some(existing_waker) = &mut inner.waker {
160
                    // If we already have a waker, we use clone_from here,
161
                    // since that function knows to use will_wake
162
                    // to avoid a needless clone.
163
                    existing_waker.clone_from(cx.waker());
164
4
                } else {
165
4
                    // Otherwise, we need to clone cx.waker().
166
4
                    inner.waker = Some(cx.waker().clone());
167
4
                }
168
4
                Poll::Pending
169
            }
170
        }
171
4106
    }
172
}
173

            
174
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::{future, time::Duration};

    use super::*;
    use futures::{FutureExt as _, StreamExt as _, stream::FuturesUnordered};
    use futures_await_test::async_test;
    use oneshot_fused_workaround as oneshot;
    use tor_basic_utils::RngExt;
    use tor_rtcompat::SleepProvider as _;

    /// Drive `fut` together (via `futures::join!`) with a block that cancels it
    /// through `handle`, and return what `fut` resolved to.
    ///
    /// Panics if the cancellation itself is refused.
    async fn cancel_while_waiting<F: Future>(
        handle: &CancelHandle,
        fut: Cancel<F>,
    ) -> Result<F::Output, Cancelled> {
        let (outcome, ()) = futures::join!(fut, async {
            handle.cancel().unwrap();
        });
        outcome
    }

    #[async_test]
    async fn not_cancelled() {
        let (_handle, wrapped) = Cancel::new(futures::future::ready("hello"));
        let output = wrapped.await.unwrap();
        assert_eq!(output, "hello");
    }

    #[async_test]
    async fn cancelled() {
        // A future that can never complete on its own.
        let (handle, never) = Cancel::new(futures::future::pending::<()>());
        let outcome = cancel_while_waiting(&handle, never).await;
        assert!(matches!(outcome, Err(Cancelled)));

        // A oneshot receiver whose sender stays alive, so it never completes either.
        let (_sender, receiver) = oneshot::channel::<()>();
        let (handle, waiting) = Cancel::new(receiver);
        let outcome = cancel_while_waiting(&handle, waiting).await;
        assert!(matches!(outcome, Err(Cancelled)));
    }

    #[test]
    fn cancelled_or_not() {
        // Launch many trivial jobs, and race each one against an attempt to cancel
        // it after a random delay.  Whichever side wins, exactly one of the two
        // must report success: either the job completed (so cancelling failed),
        // or the job was cancelled (so cancelling succeeded).
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            #[allow(deprecated)] // TODO #1885
            let rt = tor_rtmock::MockSleepRuntime::new(rt);

            const N_TRIES: usize = 1024;
            // The mock runtime uses virtual time, so this bound is arbitrary.
            const SLEEP_CEIL: Duration = Duration::from_millis(1);
            let work_succeeded = Arc::new(Mutex::new([None; N_TRIES]));
            let cancel_succeeded = Arc::new(Mutex::new([None; N_TRIES]));

            let mut futs = FuturesUnordered::new();
            for idx in 0..N_TRIES {
                let cancel_delay = rand::rng().gen_range_infallible(..=SLEEP_CEIL);
                let work_delay = rand::rng().gen_range_infallible(..=SLEEP_CEIL);

                let (handle, work) = Cancel::new(future::ready(()));

                let canceller = {
                    let rt = rt.clone();
                    let results = Arc::clone(&cancel_succeeded);
                    async move {
                        rt.sleep(cancel_delay).await;
                        let r = handle.cancel();
                        results.lock().unwrap()[idx] = Some(r.is_ok());
                    }
                };
                let worker = {
                    let rt = rt.clone();
                    let results = Arc::clone(&work_succeeded);
                    async move {
                        rt.sleep(work_delay).await;
                        let r = work.await;
                        results.lock().unwrap()[idx] = Some(r.is_ok());
                    }
                };

                futs.push(canceller.boxed());
                futs.push(worker.boxed());
            }

            rt.wait_for(async {
                while futs.next().await.is_some() {}
            })
            .await;

            let work_results = work_succeeded.lock().unwrap();
            let cancel_results = cancel_succeeded.lock().unwrap();
            let pairs = work_results
                .iter()
                .copied()
                .zip(cancel_results.iter().copied());
            for (idx, (ws, cs)) in pairs.enumerate() {
                match (ws, cs) {
                    // Exactly one of the two sides must have succeeded.
                    (Some(true), Some(false)) | (Some(false), Some(true)) => {}
                    _ => panic!("incorrect values {:?}", (idx, ws, cs)),
                }
            }
        });
    }
}