//! Functionality to copy bidirectionally between two streams
//! that implement `AsyncBufRead` and `AsyncWrite`.

            
4
use std::{
5
    io,
6
    pin::Pin,
7
    task::{Context, Poll, ready},
8
};
9

            
10
use futures::{AsyncBufRead, AsyncWrite};
11
use pin_project::pin_project;
12

            
13
use crate::{
14
    arc_io_result::{ArcIoResult, wrap_error},
15
    copy_buf::poll_copy_r_to_w,
16
    eof::EofStrategy,
17
    fuse_buf_reader::FuseBufReader,
18
};
19

            
20
/// Return a future to copies bytes from `stream_a` to `stream_b`,
21
/// and from `stream_b` to `stream_a`.
22
///
23
/// The future makes sure that
24
/// if a stream pauses (returns Pending),
25
/// all as-yet-received bytes are still flushed to the other stream.
26
///
27
/// If an EOF is read from `stream_a`,
28
/// the future uses `on_a_eof` to report the EOF to `stream_b`.
29
/// Similarly, if an EOF is read from  `stream_b`,
30
/// the future uses `on_b_eof` to report the EOF to `stream_a`.
31
///
32
/// The future will continue running until either an error has occurred
33
/// (in which case it yields an error),
34
/// or until both streams have returned an EOF as readers
35
/// and have both been flushed as writers
36
/// (in which case it yields a tuple of the number of bytes copied from a to b,
37
/// and the number of bytes copied from b to a.)
38
///
39
/// # Limitations
40
///
41
/// See the crate-level documentation for
42
/// [discussion of this function's limitations](crate#Limitations).
43
52
pub fn copy_buf_bidirectional<A, B, AE, BE>(
44
52
    stream_a: A,
45
52
    stream_b: B,
46
52
    on_a_eof: AE,
47
52
    on_b_eof: BE,
48
52
) -> CopyBufBidirectional<A, B, AE, BE>
49
52
where
50
52
    A: AsyncBufRead + AsyncWrite,
51
52
    B: AsyncBufRead + AsyncWrite,
52
52
    AE: EofStrategy<B>,
53
52
    BE: EofStrategy<A>,
54
{
55
52
    CopyBufBidirectional {
56
52
        stream_a: FuseBufReader::new(stream_a),
57
52
        stream_b: FuseBufReader::new(stream_b),
58
52
        on_a_eof,
59
52
        on_b_eof,
60
52
        copied_a_to_b: 0,
61
52
        copied_b_to_a: 0,
62
52
        a_to_b_status: DirectionStatus::Copying,
63
52
        b_to_a_status: DirectionStatus::Copying,
64
52
    }
65
52
}
66

            
67
/// A future returned by [`copy_buf_bidirectional`].
///
/// When polled to completion, it yields `(bytes copied from a to b, bytes copied
/// from b to a)` once both directions have finished, or the first I/O error
/// encountered in either direction.
//
// Note to the reader: You might think it's a good idea to have two separate CopyBuf futures here.
// That won't work, though, since each one would need to own both `stream_a` and `stream_b`.
// We could use `split` to share the streams, but that would introduce needless locking overhead.
//
// Instead, we implement the shared functionality via poll_copy_r_to_w.
#[derive(Debug)]
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct CopyBufBidirectional<A, B, AE, BE> {
    /// The first stream.
    ///
    /// It is read from in the a-to-b direction and written to in the b-to-a direction.
    #[pin]
    stream_a: FuseBufReader<A>,

    /// The second stream.
    ///
    /// It is read from in the b-to-a direction and written to in the a-to-b direction.
    #[pin]
    stream_b: FuseBufReader<B>,

    /// An [`EofStrategy`] to use when `stream_a` reaches EOF.
    // Pinned so that it can be handed to `one_direction` as a `Pin<&mut AE>`.
    #[pin]
    on_a_eof: AE,

    /// An [`EofStrategy`] to use when `stream_b` reaches EOF.
    // Pinned so that it can be handed to `one_direction` as a `Pin<&mut BE>`.
    #[pin]
    on_b_eof: BE,

    /// The number of bytes from `a` written onto `b` so far.
    copied_a_to_b: u64,
    /// The number of bytes from `b` written onto `a` so far.
    copied_b_to_a: u64,

    /// The current status of copying from `a` to `b`.
    a_to_b_status: DirectionStatus,

    /// The current status of copying from `b` to `a`.
    b_to_a_status: DirectionStatus,
}
105

            
106
impl<A, B, AE, BE> CopyBufBidirectional<A, B, AE, BE> {
107
    /// Consume this CopyBufBirectional future, and return the underlying streams.
108
    pub fn into_inner(self) -> (A, B) {
109
        (self.stream_a.into_inner(), self.stream_b.into_inner())
110
    }
111
}
112

            
113
impl<A, B, AE, BE> Future for CopyBufBidirectional<A, B, AE, BE>
114
where
115
    A: AsyncBufRead + AsyncWrite,
116
    B: AsyncBufRead + AsyncWrite,
117
    AE: EofStrategy<B>,
118
    BE: EofStrategy<A>,
119
{
120
    type Output = io::Result<(u64, u64)>;
121

            
122
9874
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
123
        use DirectionStatus::*;
124

            
125
9874
        let mut this = self.project();
126

            
127
9874
        if *this.a_to_b_status != DirectionStatus::Done {
128
9862
            let _ignore_completion = one_direction(
129
9862
                cx,
130
9862
                this.stream_a.as_mut(),
131
9862
                this.stream_b.as_mut(),
132
9862
                this.on_a_eof,
133
9862
                this.copied_a_to_b,
134
9862
                this.a_to_b_status,
135
            )
136
9862
            .map_err(|e| wrap_error(&e))?;
137
12
        }
138

            
139
9874
        if *this.b_to_a_status != DirectionStatus::Done {
140
9866
            let _ignore_completion = one_direction(
141
9866
                cx,
142
9866
                this.stream_b.as_mut(),
143
9866
                this.stream_a.as_mut(),
144
9866
                this.on_b_eof,
145
9866
                this.copied_b_to_a,
146
9866
                this.b_to_a_status,
147
            )
148
9866
            .map_err(|e| wrap_error(&e))?;
149
8
        }
150

            
151
9874
        if (*this.a_to_b_status, *this.b_to_a_status) == (Done, Done) {
152
52
            Poll::Ready(Ok((*this.copied_a_to_b, *this.copied_b_to_a)))
153
        } else {
154
9822
            Poll::Pending
155
        }
156
9874
    }
157
}
158

            
159
/// How far the copy in one direction (reader to writer) has progressed.
///
/// A direction only ever moves forward: `Copying`, then `SendingEof`, then `Done`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum DirectionStatus {
    /// Still moving data; the reader has not reported EOF yet.
    Copying,

    /// The reader hit EOF, and an [`EofStrategy`] is now passing that EOF on to the writer.
    SendingEof,

    /// The EOF has been delivered; this direction has no work left.
    Done,
}
171

            
172
/// Try to make progress copying data in a single data, and propagating the EOF.
173
19728
fn one_direction<A, B, AE>(
174
19728
    cx: &mut Context<'_>,
175
19728
    r: Pin<&mut FuseBufReader<A>>,
176
19728
    mut w: Pin<&mut FuseBufReader<B>>,
177
19728
    eof_strategy: Pin<&mut AE>,
178
19728
    n_copied: &mut u64,
179
19728
    status: &mut DirectionStatus,
180
19728
) -> Poll<ArcIoResult<()>>
181
19728
where
182
19728
    A: AsyncBufRead,
183
19728
    B: AsyncWrite,
184
19728
    AE: EofStrategy<B>,
185
{
186
    use DirectionStatus::*;
187

            
188
19728
    if *status == Copying {
189
19728
        let () = ready!(poll_copy_r_to_w(cx, r, w.as_mut(), n_copied, false))?;
190
104
        *status = SendingEof;
191
    }
192

            
193
104
    if *status == SendingEof {
194
104
        let () = ready!(eof_strategy.poll_send_eof(cx, w.get_pin_mut()))?;
195
104
        *status = Done;
196
    }
197

            
198
104
    assert_eq!(*status, Done);
199
104
    Poll::Ready(Ok(()))
200
19728
}
201

            
202
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::{eof, test::RWPair};

    use futures::{
        AsyncBufReadExt,
        io::{BufReader, BufWriter, Cursor},
    };
    use tor_rtcompat::SpawnExt as _;
    use tor_rtmock::{MockRuntime, io::stream_pair};

    /// Return a stream implemented with a pair of Vec-backed cursors.
    ///
    /// Reading from the stream yields `init_data`.
    /// Everything written to the stream lands in a second cursor, which starts empty.
    /// The tests below retrieve those written bytes with `into_inner().1.into_inner()`.
    // The nested return type is spelled out in full, hence the lint allowance.
    #[allow(clippy::type_complexity)]
    fn cursor_stream(init_data: &[u8]) -> BufReader<RWPair<Cursor<Vec<u8>>, Cursor<Vec<u8>>>> {
        BufReader::new(RWPair(
            Cursor::new(init_data.to_vec()),
            Cursor::new(Vec::new()),
        ))
    }

    /// Relay directly between two cursor-backed streams.
    ///
    /// `data_1` flows from `s1` into `s2`, and `data_2` flows from `s2` into `s1`.
    /// The test checks both the reported byte counts and the bytes each side received.
    async fn test_transfer_cursor(data_1: &[u8], data_2: &[u8]) {
        let mut s1 = cursor_stream(data_1);
        let mut s2 = cursor_stream(data_2);

        let (t1, t2) = copy_buf_bidirectional(&mut s1, &mut s2, eof::Close, eof::Close)
            .await
            .unwrap();
        assert_eq!(t1, data_1.len() as u64);
        assert_eq!(t2, data_2.len() as u64);
        let out1 = s1.into_inner().1.into_inner();
        let out2 = s2.into_inner().1.into_inner();
        assert_eq!(&out1[..], data_2);
        assert_eq!(&out2[..], data_1);
    }

    /// Relay through a mock-runtime `stream_pair`, with one copy task at each end.
    ///
    /// The topology is:
    ///
    /// s1 <--> task h1 <--> s2 <-Rt-> s3 <--> task h2 <--> s4
    ///
    /// Each copy task must report the same byte counts.
    /// The cursor-backed ends `s1` and `s4` must each receive the other's data.
    async fn test_transfer_streams(rt: &MockRuntime, data_1: &[u8], data_2: &[u8]) {
        let mut s1 = cursor_stream(data_1);
        let (s2, s3) = stream_pair();
        let mut s4 = cursor_stream(data_2);

        let h1 = rt
            .spawn_with_handle(async move {
                let r = copy_buf_bidirectional(&mut s1, BufReader::new(s2), eof::Close, eof::Close)
                    .await;
                (r, s1.into_inner().1.into_inner())
            })
            .unwrap();
        let h2 = rt
            .spawn_with_handle(async move {
                let r = copy_buf_bidirectional(BufReader::new(s3), &mut s4, eof::Close, eof::Close)
                    .await;
                (r, s4.into_inner().1.into_inner())
            })
            .unwrap();
        let (r1, buf1) = h1.await;
        let (r2, buf2) = h2.await;

        assert_eq!(r1.unwrap(), (data_1.len() as u64, data_2.len() as u64));
        assert_eq!(r2.unwrap(), (data_1.len() as u64, data_2.len() as u64));
        assert_eq!(&buf1, data_2);
        assert_eq!(&buf2, data_1);
    }

    /// Run both transfer scenarios, cursor-to-cursor and via mock streams, using
    /// `MockRuntime::test_with_various`.
    fn test_transfer(data_1: &[u8], data_2: &[u8]) {
        MockRuntime::test_with_various(async |rt| {
            test_transfer_cursor(data_1, data_2).await;
            test_transfer_streams(&rt, data_1, data_2).await;
        });
    }

    /// Return a 1,234,567-byte payload that cycles through the values `1..=x`.
    fn big(x: u8) -> Vec<u8> {
        (1..=x).cycle().take(1_234_567).collect()
    }

    /// Both directions are empty: only EOF propagation is exercised.
    #[test]
    fn transfer_empty() {
        test_transfer(&[], &[]);
    }

    /// One direction is empty while the other carries a short message.
    #[test]
    fn transfer_empty_small() {
        test_transfer(&[], b"hello world");
    }

    /// Short messages in both directions.
    #[test]
    fn transfer_small() {
        test_transfer(b"hola mundo", b"hello world");
    }

    /// Large payloads in both directions.
    ///
    /// The different cycle lengths (79 vs 81) make the two payloads differ,
    /// so data sent to the wrong side would be detected.
    #[test]
    fn transfer_huge() {
        let big1 = big(79);
        let big2 = big(81);
        test_transfer(&big1, &big2);
    }

    /// Relay a request/response exchange in which each side waits for the other
    /// before sending more.
    ///
    /// This can only make progress if `copy_buf_bidirectional` flushes data
    /// promptly rather than holding it until more arrives.
    #[test]
    fn interactive_protocol() {
        use futures::io::AsyncWriteExt as _;
        // Test our flush behavior by relaying traffic between a pair of communicators that
        // don't say anything until they get a message.

        MockRuntime::test_with_various(async |rt| {
            let (s1, s2) = stream_pair();
            let (s3, s4) = stream_pair();

            // Using BufWriter here means that unless we propagate the flush correctly,
            // flushing won't happen soon enough to cause a reply.
            let mut s1 = BufReader::new(s1);
            let s2 = BufReader::new(BufWriter::with_capacity(1024, s2));
            let s3 = BufReader::new(BufWriter::with_capacity(1024, s3));
            let mut s4 = BufReader::new(s4);

            // That's a lot of streams!  Here's how they all connect:
            //
            // Task 1 <--> s1  <-Rt-> s2 <-> Task 2 <--> s3 <-Rt-> s4 <--> Task 3
            //
            // In other words, s1 and s2 are automatically connected under the hood by
            // the MockRuntime, as are s3 and s4.  Task 1 reads and writes from s1.
            // Task 2 tests copy_buf_bidirectional by relaying between s2 and s3.
            // And Task 3 reads and writes to s4.
            //
            // Thus task 1 and task 3 can only communicate with one another if
            // task 2 (and copy_buf_bidirectional) do their job.

            // Task 1:
            // Write a number starting with 1, then read numbers and write back 1 more.
            // Continue until you read a number >= 100.
            let h1 = rt
                .spawn_with_handle(async move {
                    let mut buf = String::new();
                    let mut num: u32 = 1;

                    loop {
                        s1.write_all(format!("{num}\n").as_bytes()).await?;
                        s1.flush().await?;

                        let written = num;

                        let n_bytes_read = s1.read_line(&mut buf).await?;
                        if n_bytes_read == 0 {
                            break;
                        }
                        num = buf.trim_ascii().parse().unwrap();
                        buf.clear();
                        // Task 3 must have replied with exactly one more than we sent.
                        assert_eq!(num, written + 1);

                        if num >= 100 {
                            break;
                        }
                        num += 1;
                    }

                    // Closing s1 lets the EOF propagate through task 2 to task 3,
                    // which ends task 3's loop.
                    s1.close().await?;

                    Ok::<u32, io::Error>(num)
                })
                .unwrap();

            // Task 2: Use copy_buf_bidirectional to relay traffic.
            let h2 = rt
                .spawn_with_handle(copy_buf_bidirectional(s2, s3, eof::Close, eof::Close))
                .unwrap();

            // Task 3: Forever: read a number on a line, and write back 1 more.
            let h3 = rt
                .spawn_with_handle(async move {
                    let mut buf = String::new();
                    let mut last_written = None;

                    loop {
                        let n_bytes_read = s4.read_line(&mut buf).await?;
                        if n_bytes_read == 0 {
                            break;
                        }
                        let num: u32 = buf.trim_ascii().parse().unwrap();
                        buf.clear();
                        if let Some(last) = last_written {
                            assert_eq!(num, last + 1);
                        }

                        let num = num + 1;
                        s4.write_all(format!("{num}\n").as_bytes()).await?;
                        s4.flush().await?;
                        last_written = Some(num);
                    }
                    Ok::<_, io::Error>(())
                })
                .unwrap();

            let outcome1 = h1.await;
            let outcome2 = h2.await;
            let outcome3 = h3.await;

            assert_eq!(outcome1.unwrap(), 100);
            let (_, _) = outcome2.unwrap();
            let () = outcome3.unwrap();
        });
    }
}