1
//! Helpers for iterating over error sources.
2

            
3
use std::{io, sync::Arc};
4

            
5
/// An iterator over the lower-level error sources, and possibly their wrapped errors, of
/// an [`std::error::Error`].
///
/// One of the main reasons why you might want to use this instead of calling [`std::error::Error::source`]
/// repeatedly is because the `source` implementation of [`io::Error`] doesn't return wrapped errors unless
/// you call `get_ref` on them (see: <https://github.com/rust-lang/rust/pull/124536>). You can think of this
/// iterator as walking down the chain of how an error was constructed. However, this iterator shouldn't be
/// used to display or format errors. Doing so could result in displaying the same error twice (due to the
/// wrapping behavior of `io::Error`).
///
/// Each call to [`Iterator::next`] will attempt to peel off the outer layer of the error.
///
/// The first item returned is always the original error. Subsequent items are generated by calling:
///   * [`io::Error::get_ref`] if the last error could be downcast to an [`io::Error`] or
///     [`Arc<io::Error>`], or
///   * [`std::error::Error::source`] in all other cases
///
/// # Limitations
///
/// This does not currently handle [`io::Error`]s that are wrapped in other containers, such as `Box`.
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ErrorSources<'a> {
    /// The error that the next call to `next` will yield.
    ///
    /// Initially this is the error passed in via [`Self::new`]; afterwards it is whatever was
    /// obtained via `get_ref` or `source` from the previously yielded error. It is `None` once
    /// the chain has been exhausted.
    error: Option<&'a (dyn std::error::Error + 'static)>,
}
31

            
32
impl<'a> ErrorSources<'a> {
33
    /// Create an iterator over the lower-level sources of this error.
34
947
    pub fn new(error: &'a (dyn std::error::Error + 'static)) -> Self {
35
947
        Self { error: Some(error) }
36
947
    }
37
}
38

            
39
impl<'a> Iterator for ErrorSources<'a> {
40
    type Item = &'a (dyn std::error::Error + 'static);
41

            
42
2038
    fn next(&mut self) -> Option<Self::Item> {
43
2038
        let error = self.error.take()?;
44

            
45
1161
        if let Some(io_error) = error.downcast_ref::<io::Error>() {
46
            // This match is necessary to cast from `&dyn Error + Send + Sync` to `&dyn Error` :/
47
            //
48
            // The use of `get_ref` here is intentional because we want to save the error that
49
            // this `io::Error` is wrapping. If we used `source` that would give us the source of
50
            // the error that's being wrapped.
51
3
            self.error = io_error.get_ref().map(|e| e as _);
52
1159
        } else if let Some(io_error) = error.downcast_ref::<Arc<io::Error>>() {
53
3
            self.error = io_error.get_ref().map(|e| e as _);
54
1157
        } else {
55
1157
            self.error = error.source();
56
1157
        }
57

            
58
1161
        Some(error)
59
2038
    }
60
}
61

            
62
// ----------------------------------------------------------------------
63

            
64
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    #[derive(thiserror::Error, Debug)]
    #[error("my error")]
    struct MyError;

    // An error whose `source()` is a `MyError`, used to build plain (non-`io::Error`) chains.
    #[derive(thiserror::Error, Debug)]
    #[error("outer error")]
    struct OuterError(#[source] MyError);

    // Pop the next error from `$errors` and downcast it to `$ty`, panicking if either fails.
    macro_rules! downcast_next {
        ($errors:expr, $ty:ty) => {
            $errors.next().unwrap().downcast_ref::<$ty>().unwrap()
        };
    }

    #[test]
    fn error_sources() {
        let wrapped_error = io::Error::new(
            io::ErrorKind::ConnectionReset,
            Arc::new(io::Error::new(io::ErrorKind::ConnectionReset, MyError)),
        );
        let mut errors = ErrorSources::new(&wrapped_error);

        downcast_next!(errors, io::Error);
        downcast_next!(errors, Arc<io::Error>);
        downcast_next!(errors, MyError);
        assert!(errors.next().is_none());
    }

    /// A plain `source` chain (no `io::Error` involved) is followed to its end.
    #[test]
    fn plain_source_chain() {
        let outer = OuterError(MyError);
        let mut errors = ErrorSources::new(&outer);

        downcast_next!(errors, OuterError);
        downcast_next!(errors, MyError);
        assert!(errors.next().is_none());
    }

    /// An `io::Error` yields the error it wraps, not that error's own source.
    #[test]
    fn io_error_yields_wrapped_error() {
        let wrapped_error = io::Error::new(io::ErrorKind::InvalidData, OuterError(MyError));
        let mut errors = ErrorSources::new(&wrapped_error);

        downcast_next!(errors, io::Error);
        downcast_next!(errors, OuterError);
        downcast_next!(errors, MyError);
        assert!(errors.next().is_none());
    }

    /// An `io::Error` that wraps nothing ends the chain.
    #[test]
    fn bare_io_error() {
        let bare = io::Error::from(io::ErrorKind::NotFound);
        let mut errors = ErrorSources::new(&bare);

        downcast_next!(errors, io::Error);
        assert!(errors.next().is_none());
    }
}