1
//! Code for turning safelogging on and off.
2
//!
3
//! By default, safelogging is on.  There are two ways to turn it off: globally
4
//! (with [`disable_safe_logging`]) and locally (with
5
//! [`with_safe_logging_suppressed`]).
6

            
7
use crate::{Error, Result};
8
use fluid_let::fluid_let;
9
use std::sync::atomic::{AtomicIsize, Ordering};
10

            
11
/// A global atomic that counts the live guards used to enforce or disable
/// safe-logging.
///
/// Each guard created by `enforce_safe_logging` adds 1 and each guard created
/// by `disable_safe_logging` subtracts 1.  So the value is less than 0 if we
/// have enabled unsafe logging, greater than 0 if we have enforced safe
/// logging, and 0 if nobody cares.  Because a guard of one kind cannot be
/// created while a guard of the other kind exists, the value never mixes
/// contributions from both kinds.
static LOGGING_STATE: AtomicIsize = AtomicIsize::new(0);
17

            
18
fluid_let! {
    /// Dynamic variable that, while set to `true`, turns safe-logging off for
    /// the current thread only.
    ///
    /// It is set by [`with_safe_logging_suppressed`] and consulted by
    /// [`unsafe_logging_enabled`].
    static SAFE_LOGGING_SUPPRESSED_IN_THREAD: bool
}
22

            
23
/// Returns true if we are displaying sensitive values, false otherwise.
24
#[doc(hidden)]
25
183420
pub fn unsafe_logging_enabled() -> bool {
26
183420
    LOGGING_STATE.load(Ordering::Relaxed) < 0
27
170257
        || SAFE_LOGGING_SUPPRESSED_IN_THREAD.get(|v| v == Some(&true))
28
183420
}
29

            
30
/// Run a given function with the regular `safelog` functionality suppressed.
31
///
32
/// The provided function, and everything it calls, will display
33
/// [`Sensitive`](crate::Sensitive) values as if they were not sensitive.
34
///
35
/// # Examples
36
///
37
/// ```
38
/// use safelog::{Sensitive, with_safe_logging_suppressed};
39
///
40
/// let string = Sensitive::new("swordfish");
41
///
42
/// // Ordinarily, the string isn't displayed as normal
43
/// assert_eq!(format!("The value is {}", string),
44
///            "The value is [scrubbed]");
45
///
46
/// // But you can override that:
47
/// assert_eq!(
48
///     with_safe_logging_suppressed(|| format!("The value is {}", string)),
49
///     "The value is swordfish"
50
/// );
51
/// ```
52
20120
pub fn with_safe_logging_suppressed<F, V>(func: F) -> V
53
20120
where
54
20120
    F: FnOnce() -> V,
55
{
56
    // This sets the value of the variable to Some(true) temporarily, for as
57
    // long as `func` is being called.  It uses thread-local variables
58
    // internally.
59
20120
    SAFE_LOGGING_SUPPRESSED_IN_THREAD.set(true, func)
60
20120
}
61

            
62
/// Enum to describe what kind of a [`Guard`] we've created.
#[derive(Debug, Copy, Clone)]
enum GuardKind {
    /// We are forcing safe-logging to stay enabled (via
    /// `enforce_safe_logging`), so that nobody can turn it off with
    /// `disable_safe_logging`.
    Safe,
    /// We are turning safe-logging off with `disable_safe_logging`.
    Unsafe,
}
71

            
72
/// A guard object used to enforce safe logging, or turn it off.
73
///
74
/// For as long as this object exists, the chosen behavior will be enforced.
75
//
76
// TODO: Should there be different types for "keep safe logging on" and "turn
77
// safe logging off"?  Having the same type makes it easier to write code that
78
// does stuff like this:
79
//
80
//     let g = if cfg.safe {
81
//         enforce_safe_logging()
82
//     } else {
83
//         disable_safe_logging()
84
//     };
85
#[derive(Debug)]
86
#[must_use = "If you drop the guard immediately, it won't do anything."]
87
pub struct Guard {
88
    /// What kind of guard is this?
89
    kind: GuardKind,
90
}
91

            
92
impl GuardKind {
93
    /// Return an error if `val` (as a value of `LOGGING_STATE`) indicates that
94
    /// intended kind of guard cannot be created.
95
103926
    fn check(&self, val: isize) -> Result<()> {
96
103926
        match self {
97
            GuardKind::Safe => {
98
29948
                if val < 0 {
99
19548
                    return Err(Error::AlreadyUnsafe);
100
10400
                }
101
            }
102
            GuardKind::Unsafe => {
103
73978
                if val > 0 {
104
20456
                    return Err(Error::AlreadySafe);
105
53522
                }
106
            }
107
        }
108
63922
        Ok(())
109
103926
    }
110
    /// Return the value by which `LOGGING_STATE` should change while a guard of
111
    /// this type exists.
112
154900
    fn increment(&self) -> isize {
113
154900
        match self {
114
37448
            GuardKind::Safe => 1,
115
117452
            GuardKind::Unsafe => -1,
116
        }
117
154900
    }
118
}
119

            
120
impl Guard {
121
    /// Helper: Create a guard of a given kind.
122
97452
    fn new(kind: GuardKind) -> Result<Self> {
123
97452
        let inc = kind.increment();
124
        loop {
125
            // Find the current value of LOGGING_STATE and see if this guard can
126
            // be created.
127
103926
            let old_val = LOGGING_STATE.load(Ordering::SeqCst);
128
            // Exit if this guard can't be created.
129
103926
            kind.check(old_val)?;
130
            // Otherwise, try changing LOGGING_STATE to the new value that it
131
            // _should_ have when this guard exists.
132
63922
            let new_val = match old_val.checked_add(inc) {
133
63922
                Some(v) => v,
134
                None => return Err(Error::Overflow),
135
            };
136
57448
            if let Ok(v) =
137
63922
                LOGGING_STATE.compare_exchange(old_val, new_val, Ordering::SeqCst, Ordering::SeqCst)
138
            {
139
                // Great, we set the value to what it should be; we're done.
140
57448
                debug_assert_eq!(v, old_val);
141
57448
                return Ok(Self { kind });
142
6474
            }
143
            // Otherwise, somebody else altered this value concurrently: try
144
            // again.
145
        }
146
97452
    }
147
}
148

            
149
impl Drop for Guard {
150
57448
    fn drop(&mut self) {
151
57448
        let inc = self.kind.increment();
152
57448
        LOGGING_STATE.fetch_sub(inc, Ordering::SeqCst);
153
57448
    }
154
}
155

            
156
/// Create a new [`Guard`] to prevent anyone else from disabling safe logging.
157
///
158
/// Until the resulting `Guard` is dropped, any attempts to call
159
/// `disable_safe_logging` will give an error.  This guard does _not_ affect
160
/// calls to [`with_safe_logging_suppressed`].
161
///
162
/// This call will return an error if safe logging is _already_ disabled.
163
///
164
/// Note that this function is called "enforce", not "enable", since safe
165
/// logging is enabled by default.  Its purpose is to make sure that nothing
166
/// _else_ has called disable_safe_logging().
167
28498
pub fn enforce_safe_logging() -> Result<Guard> {
168
28498
    Guard::new(GuardKind::Safe)
169
28498
}
170

            
171
/// Create a new [`Guard`] to disable safe logging.
172
///
173
/// Until the resulting `Guard` is dropped, all [`Sensitive`](crate::Sensitive)
174
/// values will be displayed as if they were not sensitive.
175
///
176
/// This call will return an error if safe logging has been enforced with
177
/// [`enforce_safe_logging`].
178
68954
pub fn disable_safe_logging() -> Result<Guard> {
179
68954
    Guard::new(GuardKind::Unsafe)
180
68954
}
181

            
182
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;
    // We use "serial_test" to make sure that our tests here run one at a time,
    // since they modify global state.
    use serial_test::serial;
    use std::thread::{spawn, yield_now};

    /// How many times each worker thread repeats its loop in the
    /// `interfere_*` tests.
    const ROUNDS: usize = 10_000;

    /// Spawn `left` and `right` on two new threads, then join both,
    /// panicking if either thread panicked.
    fn run_pair<L, R>(left: L, right: R)
    where
        L: FnOnce() + Send + 'static,
        R: FnOnce() + Send + 'static,
    {
        let left = spawn(left);
        let right = spawn(right);
        left.join().unwrap();
        right.join().unwrap();
    }

    #[test]
    #[serial]
    fn guards() {
        // Single-threaded walk through creating and dropping guards of both
        // kinds, checking which requests succeed and what state results.
        assert!(!unsafe_logging_enabled());
        let safe_a = enforce_safe_logging().unwrap();
        let safe_b = enforce_safe_logging().unwrap();
        assert!(!unsafe_logging_enabled());

        // Disabling must be refused while safe logging is enforced.
        let refused = disable_safe_logging();
        assert!(matches!(refused, Err(Error::AlreadySafe)));
        assert!(!unsafe_logging_enabled());

        drop(safe_a);
        drop(safe_b);

        // With nobody enforcing safe logging, it can now be disabled...
        let _unsafe_a = disable_safe_logging().unwrap();
        assert!(unsafe_logging_enabled());
        // ...after which enforcing it must be refused...
        let refused = enforce_safe_logging();
        assert!(matches!(refused, Err(Error::AlreadyUnsafe)));
        assert!(unsafe_logging_enabled());
        // ...while disabling it a second time is fine.
        let _unsafe_b = disable_safe_logging().unwrap();

        assert!(unsafe_logging_enabled());
    }

    #[test]
    #[serial]
    fn suppress() {
        // `with_safe_logging_suppressed` must reveal sensitive values inside
        // its closure, whatever global state we start from.
        let check_suppressed =
            || with_safe_logging_suppressed(|| assert!(unsafe_logging_enabled()));

        // Starting from "safe logging enforced".
        {
            let _guard = enforce_safe_logging().unwrap();
            check_suppressed();
            assert!(!unsafe_logging_enabled());
        }

        // Starting from the default state.
        {
            assert!(!unsafe_logging_enabled());
            check_suppressed();
            assert!(!unsafe_logging_enabled());
        }

        // Starting from "safe logging disabled".
        {
            let _guard = disable_safe_logging().unwrap();
            assert!(unsafe_logging_enabled());
            check_suppressed();
        }
    }

    #[test]
    #[serial]
    fn interfere_1() {
        // One thread keeps enforcing safe logging while another keeps
        // disabling it.  Either request may be refused, but whoever holds a
        // guard must always observe the state that guard promises.
        run_pair(
            || {
                for _ in 0..ROUNDS {
                    if let Ok(_guard) = enforce_safe_logging() {
                        assert!(!unsafe_logging_enabled());
                        yield_now();
                        assert!(disable_safe_logging().is_err());
                    }
                    yield_now();
                }
            },
            || {
                for _ in 0..ROUNDS {
                    if let Ok(_guard) = disable_safe_logging() {
                        assert!(unsafe_logging_enabled());
                        yield_now();
                        assert!(enforce_safe_logging().is_err());
                    }
                    yield_now();
                }
            },
        );
    }

    #[test]
    #[serial]
    fn interfere_2() {
        // Two threads that both disable safe logging must never get in each
        // other's way.
        fn disable_repeatedly() {
            for _ in 0..ROUNDS {
                let guard = disable_safe_logging().unwrap();
                assert!(unsafe_logging_enabled());
                yield_now();
                drop(guard);
                yield_now();
            }
        }
        run_pair(disable_repeatedly, disable_repeatedly);
    }

    #[test]
    #[serial]
    fn interfere_3() {
        // Suppression via `with_safe_logging_suppressed` in one thread must
        // not leak into another thread.
        run_pair(
            || {
                for _ in 0..ROUNDS {
                    assert!(!unsafe_logging_enabled());
                    yield_now();
                }
            },
            || {
                for _ in 0..ROUNDS {
                    assert!(!unsafe_logging_enabled());
                    with_safe_logging_suppressed(|| {
                        assert!(unsafe_logging_enabled());
                        yield_now();
                    });
                }
            },
        );
    }
}