//! Queues for stream messages.
//!
//! While these are technically "channels", we call them "queues" to indicate that they're mostly
//! just dumb pipes. They do some tracking (memquota and size), but nothing else. The higher-level
//! object is [`StreamReceiver`](crate::stream::raw::StreamReceiver) which tracks SENDME and END
//! messages. So the idea is that the "queue" (e.g. [`StreamQueueReceiver`]) just holds data and the
//! "channel" (e.g. `StreamReceiver`) adds the Tor logic.
//!
//! The main purpose of these types is so that we can count how many bytes of stream data are
//! stored for the stream. Ideally we'd use a channel type that tracks and reports this as part of
//! its implementation, but popular channel implementations don't seem to do that.
12

            
13
use std::fmt::Debug;
14
use std::pin::Pin;
15
use std::sync::{Arc, Mutex};
16
use std::task::{Context, Poll};
17

            
18
use futures::{Sink, SinkExt, Stream};
19
use tor_async_utils::SinkTrySend;
20
use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
21
use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
22
use tor_cell::relaycell::UnparsedRelayMsg;
23
use tor_memquota::mq_queue::{self, ChannelSpec, MpscSpec};
24
use tor_rtcompat::DynTimeProvider;
25

            
26
use crate::memquota::{SpecificAccount, StreamAccount};
27

            
28
/// Create a new stream queue for incoming messages
29
/// (messages arriving on the stream from the Tor network).
30
392
pub(crate) fn stream_queue(
31
392
    size: usize,
32
392
    memquota: &StreamAccount,
33
392
    time_prov: &DynTimeProvider,
34
392
) -> Result<(StreamQueueSender, StreamQueueReceiver), tor_memquota::Error> {
35
    // Note that the size here may be very large,
36
    // for example when used with XON/XOFF flow control.
37
    //
38
    // Someday if we remove support for window-based flow control
39
    // and only support XON/XOFF flow control,
40
    // we may want to make this unbounded instead.
41
    // https://gitlab.torproject.org/tpo/core/arti/-/work_items/2412
42
392
    let (sender, receiver) =
43
392
        MpscSpec::new(size).new_mq(time_prov.clone(), memquota.as_raw_account())?;
44

            
45
392
    let receiver = StreamUnobtrusivePeeker::new(receiver);
46
392
    let counter = Arc::new(Mutex::new(0));
47
392
    Ok((
48
392
        StreamQueueSender {
49
392
            sender,
50
392
            counter: Arc::clone(&counter),
51
392
        },
52
392
        StreamQueueReceiver { receiver, counter },
53
392
    ))
54
392
}
55

            
56
/// For testing purposes, create a stream queue with a no-op memquota account and a fake time
/// provider.
///
/// # Panics
///
/// Panics if [`stream_queue`] returns an error.
#[cfg(test)]
pub(crate) fn fake_stream_queue(size: usize) -> (StreamQueueSender, StreamQueueReceiver) {
    // The fake Account doesn't care about the data ages, so this will do.
    //
    // This would be wrong to use generally in tests, where we might want to mock time,
    // since we end up here with totally *different* mocked time.
    // But it's OK here, and saves passing a runtime parameter into this function.
    stream_queue(
        size,
        &StreamAccount::new_noop(),
        &DynTimeProvider::new(tor_rtmock::MockRuntime::default()),
    )
    .expect("create fake stream queue")
}
72

            
73
/// The sending end of a channel of incoming stream messages.
74
#[derive(Debug)]
75
#[pin_project::pin_project]
76
pub(crate) struct StreamQueueSender {
77
    /// The inner sender.
78
    #[pin]
79
    sender: mq_queue::Sender<UnparsedRelayMsg, MpscSpec>,
80
    /// Number of bytes within the queue.
81
    counter: Arc<Mutex<usize>>,
82
}
83

            
84
/// The receiving end of a channel of incoming stream messages.
85
#[derive(Debug)]
86
#[pin_project::pin_project]
87
pub(crate) struct StreamQueueReceiver {
88
    /// The inner receiver.
89
    ///
90
    /// We add the [`StreamUnobtrusivePeeker`] here so that peeked messages are included in
91
    /// `counter`.
92
    // TODO(arti#534): the possible extra msg held by the `StreamUnobtrusivePeeker` isn't tracked by
93
    // memquota
94
    #[pin]
95
    receiver: StreamUnobtrusivePeeker<mq_queue::Receiver<UnparsedRelayMsg, MpscSpec>>,
96
    /// Number of bytes within the queue.
97
    counter: Arc<Mutex<usize>>,
98
}
99

            
100
impl StreamQueueSender {
101
    /// Get the approximate number of data bytes queued for this stream.
102
    ///
103
    /// As messages can be dequeued at any time, the return value may be larger than the actual
104
    /// number of bytes queued for this stream.
105
68
    pub(crate) fn approx_stream_bytes(&self) -> usize {
106
68
        *self.counter.lock().expect("poisoned")
107
68
    }
108
}
109

            
110
impl StreamQueueReceiver {
111
    /// Get the approximate number of data bytes queued for this stream.
112
    ///
113
    /// As messages can be enqueued at any time, the return value may be smaller than the actual
114
    /// number of bytes queued for this stream.
115
    pub(crate) fn approx_stream_bytes(&self) -> usize {
116
        *self.counter.lock().expect("poisoned")
117
    }
118
}
119

            
120
impl Sink<UnparsedRelayMsg> for StreamQueueSender {
121
    type Error = <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as Sink<UnparsedRelayMsg>>::Error;
122

            
123
    fn poll_ready(
124
        mut self: Pin<&mut Self>,
125
        cx: &mut Context<'_>,
126
    ) -> Poll<std::result::Result<(), Self::Error>> {
127
        self.sender.poll_ready_unpin(cx)
128
    }
129

            
130
    fn start_send(
131
        mut self: Pin<&mut Self>,
132
        item: UnparsedRelayMsg,
133
    ) -> std::result::Result<(), Self::Error> {
134
        let mut self_ = self.as_mut().project();
135

            
136
        let stream_data_len = data_len(&item);
137

            
138
        // This lock ensures that us sending the item and the counter increase are done
139
        // "atomically", so that the receiver doesn't see the item and try to decrement the
140
        // counter before we've incremented the counter, which could cause an underflow.
141
        let mut counter = self_.counter.lock().expect("poisoned");
142

            
143
        self_.sender.start_send_unpin(item)?;
144

            
145
        *counter = counter
146
            .checked_add(stream_data_len.into())
147
            .expect("queue has more than `usize::MAX` bytes?!");
148

            
149
        Ok(())
150
    }
151

            
152
    fn poll_flush(
153
        mut self: Pin<&mut Self>,
154
        cx: &mut Context<'_>,
155
    ) -> Poll<std::result::Result<(), Self::Error>> {
156
        self.sender.poll_flush_unpin(cx)
157
    }
158

            
159
    fn poll_close(
160
        mut self: Pin<&mut Self>,
161
        cx: &mut Context<'_>,
162
    ) -> Poll<std::result::Result<(), Self::Error>> {
163
        self.sender.poll_close_unpin(cx)
164
    }
165
}
166

            
167
impl SinkTrySend<UnparsedRelayMsg> for StreamQueueSender {
168
    type Error =
169
        <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as SinkTrySend<UnparsedRelayMsg>>::Error;
170

            
171
212
    fn try_send_or_return(
172
212
        mut self: Pin<&mut Self>,
173
212
        item: UnparsedRelayMsg,
174
212
    ) -> Result<
175
212
        (),
176
212
        (
177
212
            <Self as SinkTrySend<UnparsedRelayMsg>>::Error,
178
212
            UnparsedRelayMsg,
179
212
        ),
180
212
    > {
181
212
        let self_ = self.as_mut().project();
182

            
183
212
        let stream_data_len = data_len(&item);
184

            
185
        // See comments in `StreamQueueSender::start_send`.
186
212
        let mut counter = self_.counter.lock().expect("poisoned");
187

            
188
212
        self_.sender.try_send_or_return(item)?;
189

            
190
212
        *counter = counter
191
212
            .checked_add(stream_data_len.into())
192
212
            .expect("queue has more than `usize::MAX` bytes?!");
193

            
194
212
        Ok(())
195
212
    }
196
}
197

            
198
impl Stream for StreamQueueReceiver {
199
    type Item = UnparsedRelayMsg;
200

            
201
440
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
202
440
        let self_ = self.as_mut().project();
203

            
204
        // This lock ensures that us receiving the item and the counter decrease are done
205
        // "atomically", so that the sender doesn't send a new item and try to increase the
206
        // counter before we've decreased the counter, which could cause an overflow.
207
440
        let mut counter = self_.counter.lock().expect("poisoned");
208

            
209
440
        let item = match self_.receiver.poll_next(cx) {
210
212
            Poll::Ready(Some(x)) => x,
211
            Poll::Ready(None) => return Poll::Ready(None),
212
228
            Poll::Pending => return Poll::Pending,
213
        };
214

            
215
212
        let stream_data_len = data_len(&item);
216

            
217
212
        if stream_data_len != 0 {
218
92
            *counter = counter
219
92
                .checked_sub(stream_data_len.into())
220
92
                .expect("we've removed more bytes than we've added?!");
221
120
        }
222

            
223
212
        Poll::Ready(Some(item))
224
440
    }
225
}
226

            
227
impl UnobtrusivePeekableStream for StreamQueueReceiver {
228
    fn unobtrusive_peek_mut<'s>(
229
        self: Pin<&'s mut Self>,
230
    ) -> Option<&'s mut <Self as futures::Stream>::Item> {
231
        self.project().receiver.unobtrusive_peek_mut()
232
    }
233
}
234

            
235
/// The `length` field of the message, or 0 if not a data message.
236
///
237
/// If the RELAY_DATA message had an invalid length field, we just ignore the message.
238
/// The receiver will find out eventually when it tries to parse the message.
239
/// We could return an error here, but for now I think it's best not to behave as if this
240
/// queue is performing any validation.
241
///
242
/// This is its own function so that all parts of the code use the same logic.
243
424
fn data_len(item: &UnparsedRelayMsg) -> u16 {
244
424
    item.data_len().unwrap_or(0)
245
424
}