1
//! Implement a sink-blocking policy that allows a limited number of items to be sent.
2

            
3
use nonany::NonMaxU32;
4
use tor_error::{Bug, internal};
5

            
6
/// A sink-blocking [`Policy`](super::Policy) that can allow a limited number of items to be sent.
7
///
8
/// This policy may be in three states:
9
///  - Completely blocked
10
///  - Completely unblocked: Able to send an unlimited number of items.
11
///  - _Will become blocked_ after a certain number of items are sent.
12
#[derive(Debug, Clone, Copy)]
13
pub(crate) struct CountingPolicy {
14
    /// The number of items that may currently be sent.
15
    ///
16
    /// `None` represents an unlimited number.
17
    remaining: Option<NonMaxU32>,
18
}
19

            
20
/// The largest possible limited number of cells in a CountingPolicy.
21
const MAX_LIMIT: NonMaxU32 = NonMaxU32::new(u32::MAX - 1).expect("Couldn't construct MAX_LIMIT");
22

            
23
impl CountingPolicy {
24
    /// Return a new unlimited CountingPolicy.
25
526
    pub(crate) fn new_unlimited() -> Self {
26
526
        Self { remaining: None }
27
526
    }
28

            
29
    /// Return a new completely blocked CountingPolicy.
30
2
    pub(crate) fn new_blocked() -> Self {
31
        Self {
32
            remaining: Some(
33
                const { NonMaxU32::new(0).expect("Couldn't construct NonMaxU32 from zero.") },
34
            ),
35
        }
36
2
    }
37

            
38
    /// Return a new CountingPolicy that allows `n` items, and then becomes blocked.
39
    ///
40
    /// # Limitations:
41
    ///
42
    /// If `n` is greater than `MAX_LIMIT`, only `MAX_LIMIT` items will be allowed.
43
12
    pub(crate) fn new_limited(n: u32) -> Self {
44
12
        Self {
45
12
            remaining: Some(NonMaxU32::new(n).unwrap_or(MAX_LIMIT)),
46
12
        }
47
12
    }
48

            
49
    /// Return a new CountingPolicy that allows up to `n` more items to be sent
50
    /// than this one.
51
    ///
52
    /// # Limitations:
53
    ///
54
    /// If the total number of allowed items would be greater than `MAX_LIMIT`,
55
    /// only `MAX_LIMIT` items will be allowed.
56
    //
57
    // Correctness: Note that this method returns a new CountingPolicy,
58
    // and does not change self.
59
    // Therefore it obeys the invariants of the `Policy` trait.
60
8
    fn saturating_add(&self, n: u32) -> Self {
61
8
        match self.remaining {
62
6
            Some(current) => Self::new_limited(current.get().saturating_add(n)),
63
2
            None => Self::new_unlimited(),
64
        }
65
8
    }
66
}
67

            
68
impl super::Policy for CountingPolicy {
69
5908
    fn is_blocking(&self) -> bool {
70
5913
        self.remaining.is_some_and(|n| n.get() == 0)
71
5908
    }
72

            
73
    // Correctness:
74
    //
75
    // This is the only method that takes a `&mut CountingPolicy`.
76
    // It can decrement the counter, but never increment it.
77
    // Therefore, it can cause `self` to become blocked,
78
    // but it cannot cause a blocked `self` to become unblocked.
79
    // Thus the invariants of the `Policy` trait are preserved.
80
4444
    fn take_one(&mut self) -> Result<(), Bug> {
81
4444
        match &mut self.remaining {
82
            // Unlimited: nothing to do.
83
4434
            None => Ok(()),
84

            
85
10
            Some(remaining) => {
86
10
                if let Some(n) = remaining.get().checked_sub(1) {
87
6
                    *remaining = n
88
6
                        .try_into()
89
6
                        .expect("Somehow subtracting 1 made us exceed MAX_LIMIT!?");
90
6
                    Ok(())
91
                } else {
92
4
                    Err(internal!(
93
4
                        "Tried to take_one() from a blocked CountingPolicy."
94
4
                    ))
95
                }
96
            }
97
        }
98
4444
    }
99
}
100

            
101
impl<S> super::SinkBlocker<S, CountingPolicy> {
102
    /// Put this `SinkBlocker` into a blocked state.
103
    pub(crate) fn set_blocked(&mut self) {
104
        self.update_policy(CountingPolicy::new_blocked());
105
    }
106

            
107
    /// Put this `SinkBlocker` into an unlimited state.
108
    pub(crate) fn set_unlimited(&mut self) {
109
        // Correctness: Note that this _replaces_ the Policy object,
110
        // and does not modify an existing Policy object.
111
        // This is the permitted way to make a SinkBlocker unblocked.
112
        self.update_policy(CountingPolicy::new_unlimited());
113
    }
114

            
115
    /// Allow `n` additional items to bypass the current blocking of this `SinkBlocker`.
116
    ///
117
    /// (This function has no effect if the `SinkBlocker` is currently unlimited.)
118
    pub(crate) fn allow_n_additional_items(&mut self, n: u32) {
119
        // Correctness: Note that this _replaces_ the Policy object,
120
        // and does not modify an existing Policy object.
121
        // This is the permitted way to make a SinkBlocker unblocked.
122
        self.update_policy(self.policy.saturating_add(n));
123
    }
124

            
125
    /// Return true if there is no limit on this policy.
126
    pub(crate) fn is_unlimited(&self) -> bool {
127
        self.policy.remaining.is_none()
128
    }
129
}
130

            
131
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::util::sink_blocker::Policy as _;

    /// An unlimited policy never blocks, and stays unlimited when extended.
    #[test]
    fn counting_unlimited() {
        let mut policy = CountingPolicy::new_unlimited();
        assert!(!policy.is_blocking());
        // Taking items always succeeds and never causes blocking.
        for _ in 0..2 {
            assert!(policy.take_one().is_ok());
        }
        assert!(!policy.is_blocking());
        let extended = policy.saturating_add(99);
        assert!(extended.remaining.is_none()); // still unlimited.
    }

    /// A blocked policy refuses items; extending it yields a limited policy.
    #[test]
    fn counting_blocked() {
        let mut policy = CountingPolicy::new_blocked();
        assert!(policy.is_blocking());
        assert!(policy.take_one().is_err());
        let mut extended = policy.saturating_add(99);
        // New policy is limited to 99.
        assert_eq!(extended.remaining.map(|n| n.get()), Some(99));
        assert!(!extended.is_blocking());
        assert!(extended.take_one().is_ok());
        // You take one down, you pass it around...
        assert_eq!(extended.remaining.map(|n| n.get()), Some(98));
    }

    /// A limited policy counts down to blocked, and saturates at `MAX_LIMIT`.
    #[test]
    fn counting_limited() {
        let mut policy = CountingPolicy::new_limited(2);
        assert!(!policy.is_blocking());
        assert!(policy.take_one().is_ok());
        assert!(!policy.is_blocking());
        assert!(policy.take_one().is_ok());
        assert!(policy.is_blocking());
        assert!(policy.take_one().is_err());

        let base = CountingPolicy::new_limited(99);
        let plus_25 = base.saturating_add(25);
        assert_eq!(plus_25.remaining.map(|n| n.get()), Some(25 + 99));
        let saturated = base.saturating_add(u32::MAX);
        assert_eq!(saturated.remaining, Some(MAX_LIMIT));

        let clamped = CountingPolicy::new_limited(u32::MAX);
        assert_eq!(clamped.remaining, Some(MAX_LIMIT));
    }
}