1
//! Helper type for making configurations mutable.
2

            
3
use std::sync::{Arc, RwLock};
4

            
5
/// A mutable configuration object.
///
/// Internally, this is just a `RwLock<Arc<T>>`; this type just defines some
/// convenience wrappers for it.
///
/// Readers take a snapshot of the configuration with [`MutCfg::get`], which
/// hands back a clone of the inner `Arc<T>`.  Writers never mutate the `T` in
/// place: they install a brand-new `Arc<T>`, so snapshots taken earlier keep
/// pointing at the value they were taken from.
///
/// # Lock poisoning
///
/// The only mutation ever performed under the lock is the assignment of a
/// complete `Arc<T>`, so the stored value is valid even if some thread
/// panicked while holding the lock (for example, inside the closure given to
/// [`MutCfg::map_and_replace`]).  For that reason a poisoned lock is not
/// treated as an error: the methods on this type keep working and simply see
/// whichever configuration was most recently installed.
#[derive(Debug, Default)]
pub struct MutCfg<T> {
    /// The interior configuration object.
    cfg: RwLock<Arc<T>>,
}

impl<T> MutCfg<T> {
    /// Return a new MutCfg with the provided value.
    pub fn new(config: T) -> Self {
        Self {
            cfg: RwLock::new(Arc::new(config)),
        }
    }

    /// Acquire the read lock, recovering the guard if the lock is poisoned.
    ///
    /// See the type-level documentation for why ignoring poison is sound here.
    fn read_guard(&self) -> std::sync::RwLockReadGuard<'_, Arc<T>> {
        self.cfg
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquire the write lock, recovering the guard if the lock is poisoned.
    ///
    /// See the type-level documentation for why ignoring poison is sound here.
    fn write_guard(&self) -> std::sync::RwLockWriteGuard<'_, Arc<T>> {
        self.cfg
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Return the current configuration
    ///
    /// The returned `Arc` is a snapshot: it is unaffected by later calls that
    /// replace the configuration.
    pub fn get(&self) -> Arc<T> {
        let current = self.read_guard();
        Arc::clone(&current)
    }

    /// If this configuration object is still the same pointer as `old_config`,
    /// replace it with `new_config`.
    ///
    /// The comparison is by pointer identity (`Arc::ptr_eq`), not by value: an
    /// `Arc` holding an equal `T` that was not obtained from this object does
    /// not match.
    ///
    /// Returns `true` if it was in fact replaced.
    pub fn check_and_replace(&self, old_config: &Arc<T>, new_config: T) -> bool {
        // Hold the write lock across both the comparison and the store so that
        // no other writer can slip in between them.
        let mut cfg = self.write_guard();
        if Arc::ptr_eq(&cfg, old_config) {
            *cfg = Arc::new(new_config);
            true
        } else {
            false
        }
    }

    /// Replace this configuration with `new_config`.
    pub fn replace(&self, new_config: T) {
        *self.write_guard() = Arc::new(new_config);
    }

    /// Replace the current configuration with the results of evaluating `func` on it.
    ///
    /// `func` runs while the write lock is held, so the update is atomic with
    /// respect to other readers and writers.  If `func` panics, the
    /// configuration is left unchanged and this object remains usable.
    pub fn map_and_replace<F>(&self, func: F)
    where
        F: FnOnce(&Arc<T>) -> T,
    {
        let mut cfg = self.write_guard();
        // Compute the new value first: if `func` unwinds, the assignment below
        // never happens and the old configuration stays in place.
        let new_cfg = func(&cfg);
        *cfg = Arc::new(new_cfg);
    }
}

impl<T> From<T> for MutCfg<T> {
    /// Wrap `config` in a new `MutCfg`; equivalent to [`MutCfg::new`].
    fn from(config: T) -> MutCfg<T> {
        MutCfg::new(config)
    }
}
63

            
64
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    // Each way of building a `MutCfg` must expose the value it was built from.
    #[test]
    fn basic_constructors() {
        let via_new = MutCfg::new(7_u32);
        assert_eq!(*via_new.get(), 7);

        let via_default: MutCfg<u32> = MutCfg::default();
        assert_eq!(*via_default.get(), 0);

        let via_into: MutCfg<u32> = 100.into();
        assert_eq!(*via_into.get(), 100);
    }

    // A snapshot taken before `replace` keeps the old value; a fresh `get`
    // sees the new one.
    #[test]
    fn mutate_with_existing_ref() {
        let cfg = MutCfg::new(100_u32);
        let snapshot = cfg.get();

        cfg.replace(101);

        assert_eq!(*snapshot, 100);
        assert_eq!(*cfg.get(), 101);
    }

    // `check_and_replace` compares by pointer, not by value.
    #[test]
    fn check_and_replace() {
        let cfg = MutCfg::new(100_u32);

        // Equal contents but a separate allocation: must be rejected.
        let lookalike = Arc::new(100_u32);
        let swapped = cfg.check_and_replace(&lookalike, 200);
        assert!(!swapped);

        // The genuine current pointer: must be accepted.
        let genuine = cfg.get();
        assert_eq!(*genuine, 100);
        let swapped = cfg.check_and_replace(&genuine, 200);
        assert!(swapped);
        assert_eq!(*cfg.get(), 200);
    }

    // `map_and_replace` installs the closure's result and leaves earlier
    // snapshots untouched.
    #[test]
    fn map_and_replace() {
        let cfg = MutCfg::new(100_u32);
        let before = cfg.get();

        cfg.map_and_replace(|current| **current * 20);

        assert_eq!(*cfg.get(), 2000);
        assert_eq!(*before, 100);
    }
}