//! Functionality to copy bidirectionally between two streams
//! that implement `AsyncBufRead` and `AsyncWrite`.

4
use std::{
5
    io,
6
    pin::Pin,
7
    task::{Context, Poll, ready},
8
};
9

            
10
use futures::{AsyncBufRead, AsyncWrite};
11
use pin_project::pin_project;
12

            
13
use crate::{
14
    arc_io_result::{ArcIoResult, wrap_error},
15
    copy_buf::poll_copy_r_to_w,
16
    eof::EofStrategy,
17
    fuse_buf_reader::FuseBufReader,
18
};
19

            
20
/// Return a future to copies bytes from `stream_a` to `stream_b`,
21
/// and from `stream_b` to `stream_a`.
22
///
23
/// The future makes sure that
24
/// if a stream pauses (returns Pending),
25
/// all as-yet-received bytes are still flushed to the other stream.
26
///
27
/// If an EOF is read from `stream_a`,
28
/// the future uses `on_a_eof` to report the EOF to `stream_b`.
29
/// Similarly, if an EOF is read from  `stream_b`,
30
/// the future uses `on_b_eof` to report the EOF to `stream_a`.
31
///
32
/// The future will continue running until either an error has occurred
33
/// (in which case it yields an error),
34
/// or until both streams have returned an EOF as readers
35
/// and have both been flushed as writers
36
/// (in which case it yields a tuple of the number of bytes copied from a to b,
37
/// and the number of bytes copied from b to a.)
38
///
39
/// # Limitations
40
///
41
/// See the crate-level documentation for
42
/// [discussion of this function's limitations](crate#Limitations).
43
52
pub fn copy_buf_bidirectional<A, B, AE, BE>(
44
52
    stream_a: A,
45
52
    stream_b: B,
46
52
    on_a_eof: AE,
47
52
    on_b_eof: BE,
48
52
) -> CopyBufBidirectional<A, B, AE, BE>
49
52
where
50
52
    A: AsyncBufRead + AsyncWrite,
51
52
    B: AsyncBufRead + AsyncWrite,
52
52
    AE: EofStrategy<B>,
53
52
    BE: EofStrategy<A>,
54
{
55
52
    CopyBufBidirectional {
56
52
        stream_a: FuseBufReader::new(stream_a),
57
52
        stream_b: FuseBufReader::new(stream_b),
58
52
        on_a_eof,
59
52
        on_b_eof,
60
52
        copied_a_to_b: 0,
61
52
        copied_b_to_a: 0,
62
52
        a_to_b_status: DirectionStatus::Copying,
63
52
        b_to_a_status: DirectionStatus::Copying,
64
52
    }
65
52
}
66

            
67
/// A future returned by [`copy_buf_bidirectional`].
68
//
69
// Note to the reader: You might think it's a good idea to have two separate CopyBuf futures here.
70
// That won't work, though, since each one would need to own both `stream_a` and `stream_b`.
71
// We could use `split` to share the streams, but that would introduce needless locking overhead.
72
//
73
// Instead, we implement the shared functionality via poll_copy_r_to_w.
74
#[derive(Debug)]
75
#[pin_project]
76
#[must_use = "futures do nothing unless you `.await` or poll them"]
77
pub struct CopyBufBidirectional<A, B, AE, BE> {
78
    /// The first stream.
79
    #[pin]
80
    stream_a: FuseBufReader<A>,
81

            
82
    /// The second stream.
83
    #[pin]
84
    stream_b: FuseBufReader<B>,
85

            
86
    /// An [`EofStrategy`] to use when `stream_a` reaches EOF.
87
    #[pin]
88
    on_a_eof: AE,
89

            
90
    /// An [`EofStrategy`] to use when `stream_b` reaches EOF.
91
    #[pin]
92
    on_b_eof: BE,
93

            
94
    /// The number of bytes from `a` written onto `b` so far.
95
    copied_a_to_b: u64,
96
    /// The number of bytes from `b` written onto `a` so far.
97
    copied_b_to_a: u64,
98

            
99
    /// The current status of copying from `a` to `b`.
100
    a_to_b_status: DirectionStatus,
101

            
102
    /// The current status of copying from `b` to `a`.
103
    b_to_a_status: DirectionStatus,
104
}
105

            
106
impl<A, B, AE, BE> CopyBufBidirectional<A, B, AE, BE> {
107
    /// Consume this CopyBufBirectional future, and return the underlying streams.
108
    pub fn into_inner(self) -> (A, B) {
109
        (self.stream_a.into_inner(), self.stream_b.into_inner())
110
    }
111
}
112

            
113
impl<A, B, AE, BE> Future for CopyBufBidirectional<A, B, AE, BE>
114
where
115
    A: AsyncBufRead + AsyncWrite,
116
    B: AsyncBufRead + AsyncWrite,
117
    AE: EofStrategy<B>,
118
    BE: EofStrategy<A>,
119
{
120
    type Output = io::Result<(u64, u64)>;
121

            
122
9874
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123
        use DirectionStatus::*;
124

            
125
9874
        let mut this = self.project();
126

            
127
9874
        if *this.a_to_b_status != DirectionStatus::Done {
128
9862
            let _ignore_completion = one_direction(
129
9862
                cx,
130
9862
                this.stream_a.as_mut(),
131
9862
                this.stream_b.as_mut(),
132
9862
                this.on_a_eof,
133
9862
                this.copied_a_to_b,
134
9862
                this.a_to_b_status,
135
            )
136
9862
            .map_err(|e| wrap_error(&e))?;
137
12
        }
138

            
139
9874
        if *this.b_to_a_status != DirectionStatus::Done {
140
9866
            let _ignore_completion = one_direction(
141
9866
                cx,
142
9866
                this.stream_b.as_mut(),
143
9866
                this.stream_a.as_mut(),
144
9866
                this.on_b_eof,
145
9866
                this.copied_b_to_a,
146
9866
                this.b_to_a_status,
147
            )
148
9866
            .map_err(|e| wrap_error(&e))?;
149
8
        }
150

            
151
9874
        if (*this.a_to_b_status, *this.b_to_a_status) == (Done, Done) {
152
52
            Poll::Ready(Ok((*this.copied_a_to_b, *this.copied_b_to_a)))
153
        } else {
154
9822
            Poll::Pending
155
        }
156
9874
    }
157
}
158

            
159
/// A possible status for copying in a single direction.
///
/// A direction only ever advances `Copying` -> `SendingEof` -> `Done`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum DirectionStatus {
    /// Copying data: we have not yet reached an EOF.
    Copying,

    /// Reached EOF: using an [`EofStrategy`] to propagate the EOF to the writer.
    SendingEof,

    /// EOF sent: Nothing more to do.
    Done,
}

/// Try to make progress copying data in a single data, and propagating the EOF.
173
19728
fn one_direction<A, B, AE>(
174
19728
    cx: &mut Context<'_>,
175
19728
    r: Pin<&mut FuseBufReader<A>>,
176
19728
    mut w: Pin<&mut FuseBufReader<B>>,
177
19728
    eof_strategy: Pin<&mut AE>,
178
19728
    n_copied: &mut u64,
179
19728
    status: &mut DirectionStatus,
180
19728
) -> Poll<ArcIoResult<()>>
181
19728
where
182
19728
    A: AsyncBufRead,
183
19728
    B: AsyncWrite,
184
19728
    AE: EofStrategy<B>,
185
{
186
    use DirectionStatus::*;
187

            
188
19728
    if *status == Copying {
189
19728
        let () = ready!(poll_copy_r_to_w(cx, r, w.as_mut(), n_copied, false))?;
190
104
        *status = SendingEof;
191
    }
192

            
193
104
    if *status == SendingEof {
194
104
        let () = ready!(eof_strategy.poll_send_eof(cx, w.get_pin_mut()))?;
195
104
        *status = Done;
196
    }
197

            
198
104
    assert_eq!(*status, Done);
199
104
    Poll::Ready(Ok(()))
200
19728
}
201

            
202
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::{eof, test::RWPair};

    use futures::{
        AsyncBufReadExt,
        io::{BufReader, BufWriter, Cursor},
    };
    use tor_rtcompat::SpawnExt as _;
    use tor_rtmock::{MockRuntime, io::stream_pair};

    /// Return a stream implemented with a pair of Vec-backed cursors.
    ///
    /// Reads come from a cursor over `init_data`; writes go to a separate,
    /// initially empty cursor.
    #[allow(clippy::type_complexity)]
    fn cursor_stream(init_data: &[u8]) -> BufReader<RWPair<Cursor<Vec<u8>>, Cursor<Vec<u8>>>> {
        BufReader::new(RWPair(
            Cursor::new(init_data.to_vec()),
            Cursor::new(Vec::new()),
        ))
    }

    /// Relay directly between two in-memory cursor streams, and check that
    /// each side received exactly the other side's data.
    async fn test_transfer_cursor(data_1: &[u8], data_2: &[u8]) {
        let mut s1 = cursor_stream(data_1);
        let mut s2 = cursor_stream(data_2);

        let (t1, t2) = copy_buf_bidirectional(&mut s1, &mut s2, eof::Close, eof::Close)
            .await
            .unwrap();
        assert_eq!(t1, data_1.len() as u64);
        assert_eq!(t2, data_2.len() as u64);
        let out1 = s1.into_inner().1.into_inner();
        let out2 = s2.into_inner().1.into_inner();
        assert_eq!(&out1[..], data_2);
        assert_eq!(&out2[..], data_1);
    }

    /// Relay through two chained `copy_buf_bidirectional` tasks joined by a
    /// mock stream pair, and check that data arrives intact at both ends.
    async fn test_transfer_streams(rt: &MockRuntime, data_1: &[u8], data_2: &[u8]) {
        let mut s1 = cursor_stream(data_1);
        let (s2, s3) = stream_pair();
        let mut s4 = cursor_stream(data_2);

        let h1 = rt
            .spawn_with_handle(async move {
                let r = copy_buf_bidirectional(&mut s1, BufReader::new(s2), eof::Close, eof::Close)
                    .await;
                (r, s1.into_inner().1.into_inner())
            })
            .unwrap();
        let h2 = rt
            .spawn_with_handle(async move {
                let r = copy_buf_bidirectional(BufReader::new(s3), &mut s4, eof::Close, eof::Close)
                    .await;
                (r, s4.into_inner().1.into_inner())
            })
            .unwrap();
        let (r1, buf1) = h1.await;
        let (r2, buf2) = h2.await;

        assert_eq!(r1.unwrap(), (data_1.len() as u64, data_2.len() as u64));
        assert_eq!(r2.unwrap(), (data_1.len() as u64, data_2.len() as u64));
        assert_eq!(&buf1, data_2);
        assert_eq!(&buf2, data_1);
    }

    /// Run both transfer checks under the various mock runtimes.
    fn test_transfer(data_1: &[u8], data_2: &[u8]) {
        MockRuntime::test_with_various(async |rt| {
            test_transfer_cursor(data_1, data_2).await;
            test_transfer_streams(&rt, data_1, data_2).await;
        });
    }

    /// Return a large buffer of 1,234,567 bytes cycling through `1..=x`.
    fn big(x: u8) -> Vec<u8> {
        (1..=x).cycle().take(1_234_567).collect()
    }

    #[test]
    fn transfer_empty() {
        test_transfer(&[], &[]);
    }

    #[test]
    fn transfer_empty_small() {
        test_transfer(&[], b"hello world");
    }

    #[test]
    fn transfer_small() {
        test_transfer(b"hola mundo", b"hello world");
    }

    #[test]
    fn transfer_huge() {
        let big1 = big(79);
        let big2 = big(81);
        test_transfer(&big1, &big2);
    }

    #[test]
    fn interactive_protocol() {
        use futures::io::AsyncWriteExt as _;
        // Test our flush behavior by relaying traffic between a pair of communicators that
        // don't say anything until they get a message.

        MockRuntime::test_with_various(async |rt| {
            let (s1, s2) = stream_pair();
            let (s3, s4) = stream_pair();

            // Using BufWriter here means that unless we propagate the flush correctly,
            // flushing won't happen soon enough to cause a reply.
            let mut s1 = BufReader::new(s1);
            let s2 = BufReader::new(BufWriter::with_capacity(1024, s2));
            let s3 = BufReader::new(BufWriter::with_capacity(1024, s3));
            let mut s4 = BufReader::new(s4);

            // That's a lot of streams!  Here's how they all connect:
            //
            // Task 1 <--> s1  <-Rt-> s2 <-> Task 2 <--> s3 <-Rt-> s4 <--> Task 3
            //
            // In other words, s1 and s2 are automatically connected under the hood by
            // the MockRuntime, as are s3 and s4.  Task 1 reads and writes from s1.
            // Task 2 tests copy_buf_bidirectional by relaying between s2 and s3.
            // And Task 3 reads and writes to s4.
            //
            // Thus task 1 and task 3 can only communicate with one another if
            // task 2 (and copy_buf_bidirectional) do their job.

            // Task 1:
            // Write a number starting with 1, then read numbers and write back 1 more.
            // Continue until you read a number >= 100.
            let h1 = rt
                .spawn_with_handle(async move {
                    let mut buf = String::new();
                    let mut num: u32 = 1;

                    loop {
                        s1.write_all(format!("{num}\n").as_bytes()).await?;
                        s1.flush().await?;

                        let written = num;

                        let n_bytes_read = s1.read_line(&mut buf).await?;
                        if n_bytes_read == 0 {
                            break;
                        }
                        num = buf.trim_ascii().parse().unwrap();
                        buf.clear();
                        assert_eq!(num, written + 1);

                        if num >= 100 {
                            break;
                        }
                        num += 1;
                    }

                    s1.close().await?;

                    Ok::<u32, io::Error>(num)
                })
                .unwrap();

            // Task 2: Use copy_buf_bidirectional to relay traffic.
            let h2 = rt
                .spawn_with_handle(copy_buf_bidirectional(s2, s3, eof::Close, eof::Close))
                .unwrap();

            // Task 3: Forever: read a number on a line, and write back 1 more.
            let h3 = rt
                .spawn_with_handle(async move {
                    let mut buf = String::new();
                    let mut last_written = None;

                    loop {
                        let n_bytes_read = s4.read_line(&mut buf).await?;
                        if n_bytes_read == 0 {
                            break;
                        }
                        let num: u32 = buf.trim_ascii().parse().unwrap();
                        buf.clear();
                        if let Some(last) = last_written {
                            assert_eq!(num, last + 1);
                        }

                        let num = num + 1;
                        s4.write_all(format!("{num}\n").as_bytes()).await?;
                        s4.flush().await?;
                        last_written = Some(num);
                    }
                    Ok::<_, io::Error>(())
                })
                .unwrap();

            let outcome1 = h1.await;
            let outcome2 = h2.await;
            let outcome3 = h3.await;

            assert_eq!(outcome1.unwrap(), 100);
            let (_, _) = outcome2.unwrap();
            let () = outcome3.unwrap();
        });
    }
}