1
//! Helpers for retrying a fallible operation according to a backoff schedule.
2
//!
3
//! [`Runner::run`] retries the specified operation according to the [`BackoffSchedule`] of the
4
//! [`Runner`]. Users can customize the backoff behavior by implementing [`BackoffSchedule`].
5

            
6
// TODO: this is a (somewhat) general-purpose utility, so it should probably be factored out of
7
// tor-hsservice
8

            
9
use std::pin::Pin;
10

            
11
use futures::future::FusedFuture;
12

            
13
use tor_rtcompat::TimeoutError;
14

            
15
use super::*;
16

            
17
/// A runner for a fallible operation, which retries on failure according to a [`BackoffSchedule`].
18
pub(super) struct Runner<B: BackoffSchedule, R: Runtime> {
19
    /// A description of the operation we are trying to do.
20
    doing: String,
21
    /// The backoff schedule.
22
    schedule: B,
23
    /// The runtime.
24
    runtime: R,
25
}
26

            
27
impl<B: BackoffSchedule, R: Runtime> Runner<B, R> {
28
    /// Create a new `Runner`.
29
144
    pub(super) fn new(doing: String, schedule: B, runtime: R) -> Self {
30
144
        Self {
31
144
            doing,
32
144
            schedule,
33
144
            runtime,
34
144
        }
35
144
    }
36

            
37
    /// Run `fallible_fn`, retrying according to the [`BackoffSchedule`] of this `Runner`.
38
    ///
39
    /// If `fallible_fn` eventually returns `Ok(_)`, return that output. Otherwise,
40
    /// keep retrying until either `fallible_fn` has failed too many times, or until
41
    /// a fatal error occurs.
42
    #[allow(clippy::cognitive_complexity)] // TODO: Refactor
43
152
    pub(super) async fn run<T, E, F>(
44
152
        mut self,
45
152
        mut fallible_fn: impl FnMut() -> F,
46
152
    ) -> Result<T, BackoffError<E>>
47
152
    where
48
152
        E: RetriableError,
49
152
        F: Future<Output = Result<T, E>> + Send,
50
152
    {
51
152
        let mut retry_count = 0;
52
152
        let mut errors = RetryError::in_attempt_to(self.doing.clone());
53

            
54
        // When this timeout elapses, the `Runner` will stop retrying the fallible operation.
55
        //
56
        // A `overall_timeout` of `None` means there is no time limit for the retries.
57
152
        let mut overall_timeout = match self.schedule.overall_timeout() {
58
146
            Some(timeout) => Either::Left(Box::pin(self.runtime.sleep(timeout))),
59
6
            None => Either::Right(future::pending()),
60
        }
61
152
        .fuse();
62

            
63
        loop {
64
            // Bail if we've exceeded the number of allowed retries.
65
230
            if matches!(self.schedule.max_retries(), Some(max_retry_count) if retry_count >= max_retry_count)
66
            {
67
4
                return Err(BackoffError::MaxRetryCountExceeded(errors));
68
226
            }
69

            
70
226
            let mut fallible_op = optionally_timeout(
71
226
                &self.runtime,
72
226
                fallible_fn(),
73
226
                self.schedule.single_attempt_timeout(),
74
            );
75

            
76
226
            trace!(attempt = (retry_count + 1), "{}", self.doing);
77

            
78
226
            select_biased! {
79
                () = overall_timeout => {
80
                    // The timeout has elapsed, so stop retrying and return the errors
81
                    // accumulated so far.
82
2
                    return Err(BackoffError::Timeout(errors))
83
                }
84
224
                res = fallible_op => {
85
                    // TODO: the error branches in the match below have different error types,
86
                    // so we must compute should_retry and delay separately, on each branch.
87
                    //
88
                    // We could refactor this to extract the error using
89
                    // let err = match res { ... } and call err.should_retry()
90
                    // and next_delay() after the match, but this will involve
91
                    // rethinking the BackoffSchedule trait and/or RetriableError
92
                    // (currently RetriableError is Clone, so it's not object safe).
93
214
                    let (should_retry, delay) = match res {
94
112
                        Ok(Ok(res)) => return Ok(res),
95
102
                        Ok(Err(e)) => {
96
                            // The operation failed: check if we can retry it.
97
102
                            let should_retry = e.should_retry();
98

            
99
102
                            debug!(
100
                                attempt=(retry_count + 1), can_retry=should_retry,
101
                                "failed to {}: {e}", self.doing
102
                            );
103

            
104
102
                            errors.push_timed(e.clone(), self.runtime.now(), None);
105
102
                            (e.should_retry(), self.schedule.next_delay(&e))
106
                        }
107
10
                        Err(e) => {
108
10
                            trace!("fallible operation timed out; retrying");
109
10
                            (e.should_retry(), self.schedule.next_delay(&e))
110
                        },
111
                    };
112

            
113
112
                    if should_retry {
114
110
                        retry_count += 1;
115

            
116
110
                        let Some(delay) = delay else {
117
                            return Err(BackoffError::ExplicitStop(errors));
118
                        };
119

            
120
                        // Introduce the specified delay between retries
121
110
                        let () = self.runtime.sleep(delay).await;
122

            
123
                        // Try again unless the entire operation has timed out.
124
78
                        continue;
125
2
                    }
126

            
127
2
                    return Err(BackoffError::FatalError(errors));
128
                },
129
            }
130
        }
131
120
    }
132
}
133

            
134
/// Wrap a [`Future`] with an optional timeout.
135
///
136
/// If `timeout` is `Some`, returns a [`Timeout`](tor_rtcompat::Timeout)
137
/// that resolves to the value of `future` if the future completes within `timeout`,
138
/// or a [`TimeoutError`] if it does not.
139
/// If `timeout` is `None`, returns a new future which maps the specified `future`'s
140
/// output type to a `Result::Ok`.
141
226
fn optionally_timeout<'f, R, F>(
142
226
    runtime: &R,
143
226
    future: F,
144
226
    timeout: Option<Duration>,
145
226
) -> Pin<Box<dyn FusedFuture<Output = Result<F::Output, TimeoutError>> + Send + 'f>>
146
226
where
147
226
    R: Runtime,
148
226
    F: Future + Send + 'f,
149
{
150
226
    match timeout {
151
186
        Some(timeout) => Box::pin(runtime.timeout(timeout, future).fuse()),
152
40
        None => Box::pin(future.map(Ok)),
153
    }
154
226
}
155

            
156
/// A trait that specifies the parameters for retrying a fallible operation.
157
pub(super) trait BackoffSchedule {
158
    /// The maximum number of retries.
159
    ///
160
    /// A return value of `None` indicates is no upper limit for the number of retries, and that
161
    /// the operation should be retried until [`BackoffSchedule::overall_timeout`] time elapses (or
162
    /// indefinitely, if [`BackoffSchedule::overall_timeout`] returns `None`).
163
    fn max_retries(&self) -> Option<usize>;
164

            
165
    /// The total amount of time allowed for the retriable operation.
166
    ///
167
    /// A return value of `None` indicates the operation should be retried until
168
    /// [`BackoffSchedule::max_retries`] number of retries are exceeded (or indefinitely, if
169
    /// [`BackoffSchedule::max_retries`] returns `None`).
170
    fn overall_timeout(&self) -> Option<Duration>;
171

            
172
    /// The total amount of time allowed for a single operation.
173
    fn single_attempt_timeout(&self) -> Option<Duration>;
174

            
175
    /// Return the delay to introduce before the next retry.
176
    ///
177
    /// The `error` parameter contains the error returned by the fallible operation. This enables
178
    /// implementors to (optionally) implement adaptive backoff. For example, if the operation is
179
    /// sending an HTTP request, and the error is a 429 (Too Many Requests) HTTP response with a
180
    /// `Retry-After` header, the implementor can implement a backoff schedule where the next retry
181
    /// is delayed by the value specified in the `Retry-After` header.
182
    fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration>;
183
}
184

            
185
/// The type of error encountered while running a fallible operation.
186
#[derive(Clone, Debug, thiserror::Error)]
187
pub(super) enum BackoffError<E> {
188
    /// A fatal (non-transient) error occurred.
189
    #[error("A fatal (non-transient) error occurred")]
190
    FatalError(RetryError<E>),
191

            
192
    /// Ran out of retries.
193
    #[error("Ran out of retries")]
194
    MaxRetryCountExceeded(RetryError<E>),
195

            
196
    /// Exceeded the maximum allowed time.
197
    #[error("Timeout exceeded")]
198
    Timeout(RetryError<E>),
199

            
200
    /// The [`BackoffSchedule`] told us to stop retrying.
201
    #[error("Stopped retrying as requested by BackoffSchedule")]
202
    ExplicitStop(RetryError<E>),
203
}
204

            
205
impl<E> From<BackoffError<E>> for RetryError<E> {
206
    fn from(e: BackoffError<E>) -> Self {
207
        match e {
208
            BackoffError::FatalError(e)
209
            | BackoffError::MaxRetryCountExceeded(e)
210
            | BackoffError::Timeout(e)
211
            | BackoffError::ExplicitStop(e) => e,
212
        }
213
    }
214
}
215

            
216
/// A trait for representing retriable errors.
217
pub(super) trait RetriableError: StdError + Clone {
218
    /// Whether this error is transient.
219
    fn should_retry(&self) -> bool;
220
}
221

            
222
/// An attempt that ran out of time is always treated as transient, so it may be retried.
impl RetriableError for TimeoutError {
    fn should_retry(&self) -> bool {
        // Unconditionally retriable.
        true
    }
}
227

            
228
#[cfg(test)]
mod tests {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use std::sync::Arc;

    use std::iter;
    use std::sync::RwLock;

    use oneshot_fused_workaround as oneshot;
    use tor_rtcompat::{SleepProvider, ToplevelBlockOn};
    use tor_rtmock::MockRuntime;

    const SHORT_DELAY: Duration = Duration::from_millis(10);
    const TIMEOUT: Duration = Duration::from_millis(100);
    const SINGLE_TIMEOUT: Duration = Duration::from_millis(50);
    const MAX_RETRIES: usize = 5;

    /// Implement [`BackoffSchedule`] for `$name`, with each method returning the given
    /// expression.
    macro_rules! impl_backoff_sched {
        ($name:ty, $max_retries:expr, $timeout:expr, $single_timeout:expr, $next_delay:expr) => {
            impl BackoffSchedule for $name {
                fn max_retries(&self) -> Option<usize> {
                    $max_retries
                }

                fn overall_timeout(&self) -> Option<Duration> {
                    $timeout
                }

                fn single_attempt_timeout(&self) -> Option<Duration> {
                    $single_timeout
                }

                #[allow(unused_variables)]
                fn next_delay<E: RetriableError>(&mut self, error: &E) -> Option<Duration> {
                    $next_delay
                }
            }
        };
    }

    struct BackoffWithMaxRetries;

    impl_backoff_sched!(
        BackoffWithMaxRetries,
        Some(MAX_RETRIES),
        None,
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithTimeout;

    impl_backoff_sched!(
        BackoffWithTimeout,
        None,
        Some(TIMEOUT),
        None,
        Some(SHORT_DELAY)
    );

    struct BackoffWithSingleTimeout;

    impl_backoff_sched!(
        BackoffWithSingleTimeout,
        Some(MAX_RETRIES),
        None,
        Some(SINGLE_TIMEOUT),
        Some(SHORT_DELAY)
    );

    /// A schedule without any limits whose `next_delay` always asks the runner to stop.
    struct BackoffWithExplicitStop;

    impl_backoff_sched!(BackoffWithExplicitStop, None, None, None, None);

    /// A potentially retriable error.
    #[derive(Debug, Copy, Clone, thiserror::Error)]
    enum TestError {
        /// A fatal error
        #[error("A fatal test error")]
        Fatal,
        /// A transient error
        #[error("A transient test error")]
        Transient,
    }

    impl RetriableError for TestError {
        fn should_retry(&self) -> bool {
            match self {
                Self::Fatal => false,
                Self::Transient => true,
            }
        }
    }

    /// Run a single [`Runner`] test.
    ///
    /// The fallible operation sleeps for `sleep_for` (if `Some`) and then fails with the next
    /// error from `errors`; the runner is expected to fail overall. The mock clock is advanced
    /// by `SHORT_DELAY` (plus the attempt's sleep, capped at `SINGLE_TIMEOUT`) for each of the
    /// `expected_run_count` expected attempts. The test then checks that the operation ran
    /// exactly `expected_run_count` times, and that the elapsed mock time is within
    /// `SHORT_DELAY` of `expected_duration`.
    fn run_test<E: RetriableError + Send + Sync + 'static>(
        sleep_for: Option<Duration>,
        schedule: impl BackoffSchedule + Send + 'static,
        errors: impl Iterator<Item = E> + Send + Sync + 'static,
        expected_run_count: usize,
        description: &'static str,
        expected_duration: Duration,
    ) {
        let runtime = MockRuntime::new();

        runtime.clone().block_on(async move {
            let runner = Runner {
                doing: description.into(),
                schedule,
                runtime: runtime.clone(),
            };

            let retry_count = Arc::new(RwLock::new(0));
            let (tx, rx) = oneshot::channel();

            let start = runtime.now();
            runtime
                .mock_task()
                .spawn_identified(format!("retry runner task: {description}"), {
                    let retry_count = Arc::clone(&retry_count);
                    let errors = Arc::new(RwLock::new(errors));
                    let runtime = runtime.clone();
                    async move {
                        if let Ok(()) = runner
                            .run(|| async {
                                *retry_count.write().unwrap() += 1;

                                if let Some(dur) = sleep_for {
                                    runtime.sleep(dur).await;
                                }

                                Err::<(), _>(errors.write().unwrap().next().unwrap())
                            })
                            .await
                        {
                            unreachable!();
                        }

                        let () = tx.send(()).unwrap();
                    }
                });

            // The expected retry count may be unknown (for example, if we set a timeout but no
            // upper limit for the number of retries, it's impossible to tell exactly how many
            // times the operation will be retried)
            for i in 1..=expected_run_count {
                runtime.mock_task().progress_until_stalled().await;
                // If our fallible_op is sleeping, advance the time until after it times out or
                // finishes sleeping.
                if let Some(sleep_for) = sleep_for {
                    runtime
                        .mock_sleep()
                        .advance(std::cmp::min(SINGLE_TIMEOUT, sleep_for));
                }
                runtime.mock_task().progress_until_stalled().await;
                runtime.mock_sleep().advance(SHORT_DELAY);
                assert_eq!(*retry_count.read().unwrap(), i);
            }

            let () = rx.await.unwrap();
            let end = runtime.now();

            assert_eq!(*retry_count.read().unwrap(), expected_run_count);
            assert!(duration_close_to(end - start, expected_duration));
        });
    }

    /// Return true if `d1` is in the range `[d2, d2 + SHORT_DELAY]`.
    ///
    /// TODO: lifted from tor-circmgr
    fn duration_close_to(d1: Duration, d2: Duration) -> bool {
        d1 >= d2 && d1 <= d2 + SHORT_DELAY
    }

    #[test]
    fn max_retries() {
        run_test(
            None,
            BackoffWithMaxRetries,
            iter::repeat(TestError::Transient),
            MAX_RETRIES,
            "backoff with max_retries and no timeout (transient errors)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * MAX_RETRIES as u64),
        );
    }

    #[test]
    fn max_retries_fatal() {
        use TestError::*;

        /// The number of transient errors that happen before the final, fatal error.
        const RETRIES_UNTIL_FATAL: usize = 3;
        /// The total number of times we expect the fallible function to be called.
        /// The first RETRIES_UNTIL_FATAL times, a transient error is returned.
        /// The last call corresponds to the fatal error
        const EXPECTED_TOTAL_RUNS: usize = RETRIES_UNTIL_FATAL + 1;

        run_test(
            None,
            BackoffWithMaxRetries,
            std::iter::repeat_n(Transient, RETRIES_UNTIL_FATAL)
                .chain([Fatal])
                .chain(iter::repeat(Transient)),
            EXPECTED_TOTAL_RUNS,
            "backoff with max_retries and no timeout (transient errors followed by a fatal error)",
            Duration::from_millis(SHORT_DELAY.as_millis() as u64 * EXPECTED_TOTAL_RUNS as u64),
        );
    }

    #[test]
    fn timeout() {
        use TestError::*;

        let expected_run_count = TIMEOUT.as_millis() / SHORT_DELAY.as_millis();

        run_test(
            None,
            BackoffWithTimeout,
            iter::repeat(Transient),
            expected_run_count as usize,
            "backoff with timeout and no max_retries (transient errors)",
            TIMEOUT,
        );
    }

    #[test]
    fn single_timeout() {
        use TestError::*;

        // Each attempt will time out after SINGLE_TIMEOUT time units,
        // and the backoff runner sleeps for SHORT_DELAY units in between retries
        let expected_duration = Duration::from_millis(
            (SHORT_DELAY.as_millis() + SINGLE_TIMEOUT.as_millis()) as u64 * MAX_RETRIES as u64,
        );

        run_test(
            // Sleep for more than SINGLE_TIMEOUT units
            // to trigger the single_attempt_timeout() timeout
            Some(SINGLE_TIMEOUT * 2),
            BackoffWithSingleTimeout,
            iter::repeat(Transient),
            MAX_RETRIES,
            "backoff with single timeout and max_retries and no overall timeout",
            expected_duration,
        );
    }

    #[test]
    fn explicit_stop() {
        use TestError::*;

        // `next_delay` returns `None`, so the runner must stop after the first (transient)
        // failure instead of retrying. `run_test` advances the mock clock by SHORT_DELAY per
        // expected run, so that is the measured duration.
        run_test(
            None,
            BackoffWithExplicitStop,
            iter::repeat(Transient),
            1,
            "backoff where the schedule explicitly stops retrying",
            SHORT_DELAY,
        );
    }

    // TODO (#1120): needs tests for the remaining corner cases
}