1
//! Implement [`SinkBlocker`], a wrapper type to allow policy-based blocking of
2
//! a [futures::Sink].
3

            
4
#![cfg_attr(not(feature = "circ-padding"), expect(dead_code))]
5

            
6
mod boolean_policy;
7
mod counting_policy;
8

            
9
pub(crate) use boolean_policy::BooleanPolicy;
10
pub(crate) use counting_policy::CountingPolicy;
11

            
12
use std::{
13
    pin::Pin,
14
    task::{Context, Poll, Waker},
15
};
16

            
17
use futures::Sink;
18
use pin_project::pin_project;
19
use tor_error::Bug;
20

            
21
/// A wrapper for a [`futures::Sink`] that allows its blocking status to be
22
/// turned on and off according to a policy.
23
///
24
/// While the policy is blocking, attempts to enqueue data on the sink
25
/// via this `Sink` trait will return [`Poll::Pending`].
26
/// Later, when the policy is replaced with a nonblocking one via [`Self::update_policy()`]
27
/// this sink can be written to again.
28
#[pin_project]
29
pub(crate) struct SinkBlocker<S, P = BooleanPolicy> {
30
    /// The inner sink.
31
    #[pin]
32
    inner: S,
33
    /// A policy state object, deciding whether we are blocking or not.
34
    ///
35
    /// Invariant: Whenever we try to send with a blocking Policy,
36
    /// we store the context's waker in self.waker.
37
    /// If later the policy becomes non-blocking,
38
    /// we we alert the `Waker`.
39
    policy: P,
40
    /// A waker that we should alert when `policy` transitions from
41
    /// a blocking to a non-blocking state.
42
    waker: Option<Waker>,
43
}
44

            
45
/// A policy that describes whether cells can be sent on a [`SinkBlocker`].
46
///
47
/// Each `Policy` object can be in different states:
48
/// some states cause the `SinkBlocker` to block traffic,
49
/// and some cause the `SinkBlocker` to permit traffic.
50
///
51
/// The user of a `SinkBlocker` is expected to call
52
/// [`update_policy()`](SinkBlocker::update_policy) from time to time,
53
/// when they need to make a manual change in the `SinkBlocker`'s status.
54
/// This is the only way for a blocked `SinkBlocker` to become unblocked.
55
///
56
/// Invariants:
57
///  - The state of a `Policy` object may transition from
58
///    non-blocking to blocking.
59
///  - The state of a `Policy` object may _not_ transition
60
///    from blocking to non-blocking.
61
///  - If [`is_blocking()`](Policy::is_blocking) has returned false,
62
///    and no intervening changes have been made to the `Policy`,
63
///    [`take_one()`](Policy::take_one) will succeed.
64
///
65
/// Note that because of this last invariant,
66
/// interior mutability is strongly discouraged for implementations of this trait.
67
pub(crate) trait Policy {
68
    /// Returns true if this policy is currently blocking.
69
    ///
70
    /// Invariant: If this returns true on a given Policy,
71
    /// it must always return true on that Policy in the future.
72
    /// (That is, a Policy may become blocked,
73
    /// but may not become unblocked.)
74
    fn is_blocking(&self) -> bool;
75

            
76
    /// Modify this policy in response to having queued one item.
77
    ///
78
    /// Requires that `self.is_blocking()` has just returned false.
79
    /// Returns an error, and does not change `self`, if this _is_ blocked.
80
    /// (That is, you must only call this function on a non-blocked Policy.)
81
    //
82
    // Notes: The above rules mean that `take_one` can transition from
83
    // unblocking to blocking, but never vice versa.
84
    fn take_one(&mut self) -> Result<(), Bug>;
85
}
86

            
87
impl<S, P> SinkBlocker<S, P> {
88
    /// Construct a new `SinkBlocker` wrapping a given sink, with a given
89
    /// initial blocking policy.
90
856
    pub(crate) fn new(inner: S, policy: P) -> Self {
91
856
        SinkBlocker {
92
856
            inner,
93
856
            policy,
94
856
            waker: None,
95
856
        }
96
856
    }
97

            
98
    /// Return a reference to the inner `Sink` of this object.
99
    ///
100
    /// See warnings on `as_inner_mut`.
101
4754
    pub(crate) fn as_inner(&self) -> &S {
102
4754
        &self.inner
103
4754
    }
104

            
105
    /// Return a mutable reference to the inner `Sink` of this object.
106
    ///
107
    /// Note that with this method, it is possible to bypass the blocking features
108
    /// of [`SinkBlocker`].  This is an intentional escape hatch.
109
11764
    pub(crate) fn as_inner_mut(&mut self) -> &mut S {
110
11764
        &mut self.inner
111
11764
    }
112
}
113

            
114
impl<S, P: Policy> SinkBlocker<S, P> {
    /// Install `new_policy` as the current [`Policy`] state object,
    /// discarding the previous one.
    ///
    /// Call this to turn a blocked `SinkBlocker` into an unblocked one,
    /// or the other way around.
    //
    // Invariants: If we become unblocked, alerts our `Waker`.
    //
    // (No other method can move us from blocked to unblocked,
    // so this is the only place that has to alert the waker.)
    pub(crate) fn update_policy(&mut self, new_policy: P) {
        // Sample both states before swapping: (old, new).
        let transition = (self.policy.is_blocking(), new_policy.is_blocking());
        self.policy = new_policy;

        // Only a blocked -> unblocked transition needs a wakeup.
        if !matches!(transition, (true, false)) {
            return;
        }
        if let Some(parked) = self.waker.take() {
            parked.wake();
        }
    }
}
135

            
136
impl<T, S: Sink<T>, P: Policy> Sink<T> for SinkBlocker<S, P> {
137
    type Error = S::Error;
138

            
139
13020
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140
13020
        let self_ = self.project();
141
13020
        if self_.policy.is_blocking() {
142
            // We're blocked.  We're going to store the context's Waker,
143
            // so that we can invoke it later when the policy changes.
144
12
            *self_.waker = Some(cx.waker().clone());
145
12
            Poll::Pending
146
        } else {
147
            // If this returns Ready, great!
148
            // If this returns Pending, it will wake up the context when it is
149
            // no longer blocked.
150
13008
            self_.inner.poll_ready(cx)
151
        }
152
13020
    }
153

            
154
4464
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
155
4464
        let self_ = self.project();
156
        // We're only allowed to call this method if poll_ready succeeded,
157
        // so we know that is_blocking() was false.
158
4464
        let () = self_.inner.start_send(item)?;
159

            
160
        // (Invoke take_one, to account for this item.)
161
        //
162
        // Note: Instead of calling expect, perhaps it would be better to have a custom error type
163
        // that wraps S::Error and also allows for a Bug.  But that might be overkill, since
164
        // we only expect this error to happen in the event of a bug.
165
4462
        let _: () = self_.policy.take_one().expect(
166
4462
            "take_one failed after is_blocking returned false: bug in Policy or SinkBlocker",
167
4462
        );
168
        // (Take_one is not allowed to cause us to become unblocked, so we don't
169
        // need to invoke the waiter.)
170

            
171
4462
        Ok(())
172
4464
    }
173

            
174
1262
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175
        // Note that we want to flush the inner sink,
176
        // even if we are blocking attempts to send onto it.
177
1262
        self.project().inner.poll_flush(cx)
178
1262
    }
179

            
180
4
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
181
4
        self.project().inner.poll_close(cx)
182
4
    }
183
}
184

            
185
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    };

    use super::*;

    use futures::{SinkExt as _, StreamExt as _, channel::mpsc, poll};
    use tor_rtmock::MockRuntime;

    /// Send four items through a buffered `SinkBlocker`, blocking it after
    /// the second one, and check that items 3 and 4 only arrive once the
    /// blocker has been unblocked again.
    #[test]
    fn block_and_unblock() {
        // Run under several schedulers, so we know the logic holds for each of them.
        MockRuntime::test_with_various(|runtime| async move {
            let (raw_tx, mut rx) = mpsc::channel::<u32>(1);
            let mut tx = SinkBlocker::new(raw_tx, BooleanPolicy::Unblocked).buffer(5);

            // Shared state: whether the transmitter currently believes it is
            // blocked, and how many items the receiver has seen so far.
            let blocked = Arc::new(AtomicBool::new(false));
            let n_received = Arc::new(AtomicUsize::new(0));

            // Handles for the receiver task and for the final check.
            let receiver_blocked = Arc::clone(&blocked);
            let receiver_count = Arc::clone(&n_received);
            let n_received_clone2 = Arc::clone(&n_received);

            runtime.spawn_identified("Transmitter", async move {
                tx.send(1).await.unwrap();
                tx.send(2).await.unwrap();
                blocked.store(true, Ordering::SeqCst);
                tx.get_mut().set_blocked();
                // "feed" is required here: "send" would flush, and the flush would block.
                tx.feed(3).await.unwrap();
                tx.feed(4).await.unwrap();
                assert!(dbg!(n_received.load(Ordering::SeqCst)) <= 2);
                // Flushing must _not_ be possible at this point.
                assert!(poll!(tx.flush()).is_pending());
                // Record that we're unblocked, then actually unblock.
                blocked.store(false, Ordering::SeqCst);
                tx.get_mut().set_unblocked();
                // Now the flush should really go through.
                tx.flush().await.unwrap();
                tx.close().await.unwrap();
            });

            runtime.spawn_identified("Receiver", async move {
                let mut expected = 1;
                while let Some(val) = rx.next().await {
                    // Items must arrive in order: 1, 2, 3, ...
                    assert_eq!(val, expected);
                    expected += 1;
                    receiver_count.fetch_add(1, Ordering::SeqCst);
                    // Anything from 3 onwards was queued while blocked, so it
                    // may only show up after the transmitter has unblocked.
                    if val >= 3 {
                        assert_eq!(receiver_blocked.load(Ordering::SeqCst), false);
                    }
                }
                dbg!(expected);
            });

            runtime.progress_until_stalled().await;

            assert_eq!(dbg!(n_received_clone2.load(Ordering::SeqCst)), 4);
        });
    }
}