1
//! Queues for stream messages.
2
//!
3
//! While these are technically "channels", we call them "queues" to indicate that they're mostly
4
//! just dumb pipes. They do some tracking (memquota and size), but nothing else. The higher-level
5
//! object is [`StreamReceiver`](crate::stream::raw::StreamReceiver) which tracks SENDME and END
6
//! messages. So the idea is that the "queue" (e.g. [`StreamQueueReceiver`]) just holds data and the
//! "channel" (e.g. `StreamReceiver`) adds the Tor logic.
8
//!
9
//! The main purpose of these types is to let us count how many bytes of stream data are
//! stored for the stream. Ideally we'd use a channel type that tracks and reports this as part of
//! its implementation, but popular channel implementations don't seem to do that.
12

            
13
use std::fmt::Debug;
14
use std::pin::Pin;
15
use std::sync::{Arc, Mutex};
16
use std::task::{Context, Poll};
17

            
18
use futures::{Sink, SinkExt, Stream};
19
use tor_async_utils::SinkTrySend;
20
use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
21
use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
22
use tor_cell::relaycell::UnparsedRelayMsg;
23
use tor_memquota::mq_queue::{self, ChannelSpec, MpscSpec, MpscUnboundedSpec};
24
use tor_rtcompat::DynTimeProvider;
25

            
26
use crate::memquota::{SpecificAccount, StreamAccount};
27

            
28
// TODO(arti#534): remove these type aliases when we remove the "flowctl-cc" feature,
29
// and just use `MpscUnboundedSpec` everywhere
30
#[cfg(feature = "flowctl-cc")]
31
/// Alias for the memquota mpsc spec.
32
type Spec = MpscUnboundedSpec;
33
#[cfg(not(feature = "flowctl-cc"))]
34
/// Alias for the memquota mpsc spec.
35
type Spec = MpscSpec;
36

            
37
/// Create a new stream queue for incoming messages.
38
388
pub(crate) fn stream_queue(
39
388
    #[cfg(not(feature = "flowctl-cc"))] size: usize,
40
388
    memquota: &StreamAccount,
41
388
    time_prov: &DynTimeProvider,
42
388
) -> Result<(StreamQueueSender, StreamQueueReceiver), tor_memquota::Error> {
43
388
    let (sender, receiver) = {
44
        cfg_if::cfg_if! {
45
            if #[cfg(not(feature = "flowctl-cc"))] {
46
                MpscSpec::new(size).new_mq(time_prov.clone(), memquota.as_raw_account())?
47
            } else {
48
388
                MpscUnboundedSpec::new().new_mq(time_prov.clone(), memquota.as_raw_account())?
49
            }
50
        }
51
    };
52

            
53
388
    let receiver = StreamUnobtrusivePeeker::new(receiver);
54
388
    let counter = Arc::new(Mutex::new(0));
55
388
    Ok((
56
388
        StreamQueueSender {
57
388
            sender,
58
388
            counter: Arc::clone(&counter),
59
388
        },
60
388
        StreamQueueReceiver { receiver, counter },
61
388
    ))
62
388
}
63

            
64
/// For testing purposes, create a stream queue with a no-op memquota account and a fake time
/// provider.
///
/// # Panics
///
/// Panics if the underlying stream queue cannot be created.
#[cfg(test)]
pub(crate) fn fake_stream_queue(
    #[cfg(not(feature = "flowctl-cc"))] size: usize,
) -> (StreamQueueSender, StreamQueueReceiver) {
    // The fake Account doesn't care about the data ages, so this will do.
    //
    // This would be wrong to use generally in tests, where we might want to mock time,
    // since here we end up with a totally *different* mocked time.
    // But it's OK here, and saves passing a runtime parameter into this function.
    stream_queue(
        #[cfg(not(feature = "flowctl-cc"))]
        size,
        &StreamAccount::new_noop(),
        &DynTimeProvider::new(tor_rtmock::MockRuntime::default()),
    )
    .expect("create fake stream queue")
}
83

            
84
/// The sending end of a channel of incoming stream messages.
85
#[derive(Debug)]
86
#[pin_project::pin_project]
87
pub(crate) struct StreamQueueSender {
88
    /// The inner sender.
89
    #[pin]
90
    sender: mq_queue::Sender<UnparsedRelayMsg, Spec>,
91
    /// Number of bytes within the queue.
92
    counter: Arc<Mutex<usize>>,
93
}
94

            
95
/// The receiving end of a channel of incoming stream messages.
96
#[derive(Debug)]
97
#[pin_project::pin_project]
98
pub(crate) struct StreamQueueReceiver {
99
    /// The inner receiver.
100
    ///
101
    /// We add the [`StreamUnobtrusivePeeker`] here so that peeked messages are included in
102
    /// `counter`.
103
    // TODO(arti#534): the possible extra msg held by the `StreamUnobtrusivePeeker` isn't tracked by
104
    // memquota
105
    #[pin]
106
    receiver: StreamUnobtrusivePeeker<mq_queue::Receiver<UnparsedRelayMsg, Spec>>,
107
    /// Number of bytes within the queue.
108
    counter: Arc<Mutex<usize>>,
109
}
110

            
111
impl StreamQueueSender {
112
    /// Get the approximate number of data bytes queued for this stream.
113
    ///
114
    /// As messages can be dequeued at any time, the return value may be larger than the actual
115
    /// number of bytes queued for this stream.
116
64
    pub(crate) fn approx_stream_bytes(&self) -> usize {
117
64
        *self.counter.lock().expect("poisoned")
118
64
    }
119
}
120

            
121
impl StreamQueueReceiver {
122
    /// Get the approximate number of data bytes queued for this stream.
123
    ///
124
    /// As messages can be enqueued at any time, the return value may be smaller than the actual
125
    /// number of bytes queued for this stream.
126
    pub(crate) fn approx_stream_bytes(&self) -> usize {
127
        *self.counter.lock().expect("poisoned")
128
    }
129
}
130

            
131
impl Sink<UnparsedRelayMsg> for StreamQueueSender {
132
    type Error = <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as Sink<UnparsedRelayMsg>>::Error;
133

            
134
    fn poll_ready(
135
        mut self: Pin<&mut Self>,
136
        cx: &mut Context<'_>,
137
    ) -> Poll<std::result::Result<(), Self::Error>> {
138
        self.sender.poll_ready_unpin(cx)
139
    }
140

            
141
    fn start_send(
142
        mut self: Pin<&mut Self>,
143
        item: UnparsedRelayMsg,
144
    ) -> std::result::Result<(), Self::Error> {
145
        let mut self_ = self.as_mut().project();
146

            
147
        let stream_data_len = data_len(&item);
148

            
149
        // This lock ensures that us sending the item and the counter increase are done
150
        // "atomically", so that the receiver doesn't see the item and try to decrement the
151
        // counter before we've incremented the counter, which could cause an underflow.
152
        let mut counter = self_.counter.lock().expect("poisoned");
153

            
154
        self_.sender.start_send_unpin(item)?;
155

            
156
        *counter = counter
157
            .checked_add(stream_data_len.into())
158
            .expect("queue has more than `usize::MAX` bytes?!");
159

            
160
        Ok(())
161
    }
162

            
163
    fn poll_flush(
164
        mut self: Pin<&mut Self>,
165
        cx: &mut Context<'_>,
166
    ) -> Poll<std::result::Result<(), Self::Error>> {
167
        self.sender.poll_flush_unpin(cx)
168
    }
169

            
170
    fn poll_close(
171
        mut self: Pin<&mut Self>,
172
        cx: &mut Context<'_>,
173
    ) -> Poll<std::result::Result<(), Self::Error>> {
174
        self.sender.poll_close_unpin(cx)
175
    }
176
}
177

            
178
impl SinkTrySend<UnparsedRelayMsg> for StreamQueueSender {
179
    type Error =
180
        <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as SinkTrySend<UnparsedRelayMsg>>::Error;
181

            
182
208
    fn try_send_or_return(
183
208
        mut self: Pin<&mut Self>,
184
208
        item: UnparsedRelayMsg,
185
208
    ) -> Result<
186
208
        (),
187
208
        (
188
208
            <Self as SinkTrySend<UnparsedRelayMsg>>::Error,
189
208
            UnparsedRelayMsg,
190
208
        ),
191
208
    > {
192
208
        let self_ = self.as_mut().project();
193

            
194
208
        let stream_data_len = data_len(&item);
195

            
196
        // See comments in `StreamQueueSender::start_send`.
197
208
        let mut counter = self_.counter.lock().expect("poisoned");
198

            
199
208
        self_.sender.try_send_or_return(item)?;
200

            
201
208
        *counter = counter
202
208
            .checked_add(stream_data_len.into())
203
208
            .expect("queue has more than `usize::MAX` bytes?!");
204

            
205
208
        Ok(())
206
208
    }
207
}
208

            
209
impl Stream for StreamQueueReceiver {
210
    type Item = UnparsedRelayMsg;
211

            
212
428
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213
428
        let self_ = self.as_mut().project();
214

            
215
        // This lock ensures that us receiving the item and the counter decrease are done
216
        // "atomically", so that the sender doesn't send a new item and try to increase the
217
        // counter before we've decreased the counter, which could cause an overflow.
218
428
        let mut counter = self_.counter.lock().expect("poisoned");
219

            
220
428
        let item = match self_.receiver.poll_next(cx) {
221
208
            Poll::Ready(Some(x)) => x,
222
            Poll::Ready(None) => return Poll::Ready(None),
223
220
            Poll::Pending => return Poll::Pending,
224
        };
225

            
226
208
        let stream_data_len = data_len(&item);
227

            
228
208
        if stream_data_len != 0 {
229
88
            *counter = counter
230
88
                .checked_sub(stream_data_len.into())
231
88
                .expect("we've removed more bytes than we've added?!");
232
120
        }
233

            
234
208
        Poll::Ready(Some(item))
235
428
    }
236
}
237

            
238
impl UnobtrusivePeekableStream for StreamQueueReceiver {
239
    fn unobtrusive_peek_mut<'s>(
240
        self: Pin<&'s mut Self>,
241
    ) -> Option<&'s mut <Self as futures::Stream>::Item> {
242
        self.project().receiver.unobtrusive_peek_mut()
243
    }
244
}
245

            
246
/// The `length` field of the message, or 0 if not a data message.
247
///
248
/// If the RELAY_DATA message had an invalid length field, we just ignore the message.
249
/// The receiver will find out eventually when it tries to parse the message.
250
/// We could return an error here, but for now I think it's best not to behave as if this
251
/// queue is performing any validation.
252
///
253
/// This is its own function so that all parts of the code use the same logic.
254
416
fn data_len(item: &UnparsedRelayMsg) -> u16 {
255
416
    item.data_len().unwrap_or(0)
256
416
}