1
//! Functionality to copy from an `AsyncBufRead` to an `AsyncWrite`.
2

            
3
use std::{
4
    future::Future,
5
    pin::Pin,
6
    task::{Context, Poll, ready},
7
};
8

            
9
use crate::{
10
    arc_io_result::{ArcIoResult, ArcIoResultExt},
11
    fuse_buf_reader::FuseBufReader,
12
};
13
use futures::{AsyncBufRead, AsyncWrite};
14
use pin_project::pin_project;
15

            
16
/// Return a future to copy all bytes interactively from `reader` to `writer`.
17
///
18
/// Unlike [`futures::io::copy`], this future makes sure that
19
/// if `reader` pauses (returns `Pending`),
20
/// all as-yet-received bytes are still flushed to `writer`.
21
///
22
/// The future continues copying data until either an error occurs
23
/// (in which case it yields an error),
24
/// or the reader returns an EOF
25
/// (in which case it flushes any pending data,
26
/// and returns the number of bytes copied).
27
///
28
/// # Limitations
29
///
30
/// See the crate-level documentation for
31
/// [discussion of this function's limitations](crate#Limitations).
32
72
pub fn copy_buf<R, W>(reader: R, writer: W) -> CopyBuf<R, W>
33
72
where
34
72
    R: AsyncBufRead,
35
72
    W: AsyncWrite,
36
{
37
72
    CopyBuf {
38
72
        reader: FuseBufReader::new(reader),
39
72
        writer,
40
72
        copied: 0,
41
72
    }
42
72
}
43

            
44
/// A future returned by [`copy_buf`].
45
#[derive(Debug)]
46
#[pin_project]
47
#[must_use = "futures do nothing unless you `.await` or poll them"]
48
pub struct CopyBuf<R, W> {
49
    /// The reader that we're taking data from.
50
    ///
51
    /// This is `FuseBufReader` to make our logic simpler.
52
    #[pin]
53
    reader: FuseBufReader<R>,
54

            
55
    /// The writer that we're pushing
56
    #[pin]
57
    writer: W,
58

            
59
    /// The number of bytes written to the writer so far.
60
    copied: u64,
61
}
62

            
63
impl<R, W> CopyBuf<R, W> {
64
    /// Consume this CopyBuf future, and return the underlying reader and writer.
65
    pub fn into_inner(self) -> (R, W) {
66
        (self.reader.into_inner(), self.writer)
67
    }
68
}
69

            
70
impl<R, W> Future for CopyBuf<R, W>
71
where
72
    R: AsyncBufRead,
73
    W: AsyncWrite,
74
{
75
    type Output = std::io::Result<u64>;
76

            
77
28640
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78
28640
        let this = self.project();
79
28640
        let () = ready!(poll_copy_r_to_w(
80
28640
            cx,
81
28640
            this.reader,
82
28640
            this.writer,
83
28640
            this.copied,
84
28640
            false
85
28640
        ))
86
60
        .io_result()?;
87
48
        Poll::Ready(Ok(*this.copied))
88
28640
    }
89
}
90

            
91
/// Core implementation function:
92
/// Try to make progress copying bytes from `reader` to `writer`,
93
/// and add the number of bytes written to `*total_copied`.
94
///
95
/// Returns `Ready` when an error has occurred,
96
/// or when the reader has reached EOF and the writer has been flushed.
97
/// Otherwise, returns `Pending`, and registers itself with `cx`.
98
///
99
/// (This is a separate function so we can use it to implement CopyBuf and CopyBufBidirectional.)
100
48368
pub(crate) fn poll_copy_r_to_w<R, W>(
101
48368
    cx: &mut Context<'_>,
102
48368
    mut reader: Pin<&mut FuseBufReader<R>>,
103
48368
    mut writer: Pin<&mut W>,
104
48368
    total_copied: &mut u64,
105
48368
    flush_on_err: bool,
106
48368
) -> Poll<ArcIoResult<()>>
107
48368
where
108
48368
    R: AsyncBufRead,
109
48368
    W: AsyncWrite,
110
{
111
    // TODO: Instead of using poll_fill_buf() unconditionally,
112
    // it might be a neat idea to use the buffer by reference and just keep writing
113
    // if the buffer is already "full enough".  The futures::io AsyncBufRead API
114
    // doesn't really make that possible, though.  If specialization is ever stabilized,
115
    // we could have a special implementation for BufReader, I guess.
116

            
117
    // TODO: We assume that 'flush' is pretty fast when it has nothing to do.
118
    // If that's wrong, we may need to remember whether we've written data but not flushed it.
119

            
120
    loop {
121
287660
        match reader.as_mut().poll_fill_buf(cx) {
122
            Poll::Pending => {
123
                // If there's nothing to read now, we may need to make sure that the writer
124
                // is flushed.
125
21668
                let () = ready!(writer.as_mut().poll_flush(cx))?;
126
21668
                return Poll::Pending;
127
            }
128
12
            Poll::Ready(Err(e)) => {
129
                //  On error, flush, and propagate the error.
130
12
                if flush_on_err {
131
                    let _ignore_flush_error = ready!(writer.as_mut().poll_flush(cx));
132
12
                }
133
12
                return Poll::Ready(Err(e));
134
            }
135
265980
            Poll::Ready(Ok(&[])) => {
136
                // On EOF, we have already written all the data; make sure we flush it,
137
                // and then return the amount that we copied.
138
152
                let () = ready!(writer.as_mut().poll_flush(cx))?;
139
152
                return Poll::Ready(Ok(()));
140
            }
141
265828
            Poll::Ready(Ok(data)) => {
142
                // If there is pending data, we copy as much as we can.
143
                // We return "pending" if we can't write any.
144
265828
                let n_written: usize = ready!(writer.as_mut().poll_write(cx, data))?;
145
                // Remove the data from the reader.
146
239292
                reader.as_mut().consume(n_written);
147
239292
                *total_copied += n_written as u64;
148
            }
149
        }
150
    }
151
48368
}
152

            
153
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)] // See arti#2571
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::test::{ErrorRW, PausedRead};

    use futures::{
        AsyncReadExt as _,
        future::poll_fn,
        io::{BufReader, Cursor},
    };
    use std::io;
    use tor_rtcompat::SpawnExt as _;
    use tor_rtmock::{MockRuntime, io::stream_pair};

    /// Copy `data` straight from an in-memory cursor into a `Vec`,
    /// and check that every byte arrived and was counted.
    async fn test_copy_cursor(data: &[u8]) {
        let mut sink: Vec<u8> = Vec::new();
        let mut sink_cursor = Cursor::new(&mut sink);
        let mut source = BufReader::new(Cursor::new(data));

        let total = copy_buf(&mut source, &mut sink_cursor).await.unwrap();

        assert_eq!(total, data.len() as u64);
        assert_eq!(&sink[..], data);
    }

    /// Copy `data` through a mock stream pair, using one task to push bytes
    /// into the stream and another to drain them into a `Vec`.
    async fn test_copy_stream(rt: &MockRuntime, data: &[u8]) {
        let source = Cursor::new(data.to_vec());
        let (stream_tx, stream_rx) = stream_pair();
        let mut sink = Cursor::new(Vec::<u8>::new());
        let source = BufReader::new(source);
        let stream_rx = BufReader::new(stream_rx);

        let pusher = rt.spawn_with_handle(copy_buf(source, stream_tx)).unwrap();
        let drainer = rt
            .spawn_with_handle(async move {
                let outcome = copy_buf(stream_rx, &mut sink).await;
                (outcome, sink)
            })
            .unwrap();

        let pushed = pusher.await;
        let (drained, sink) = drainer.await;

        assert_eq!(pushed.unwrap(), data.len() as u64);
        assert_eq!(drained.unwrap(), data.len() as u64);
        assert_eq!(&sink.into_inner()[..], data);
    }

    /// Feed `data` followed by a reader that pauses forever: every byte must
    /// still be delivered (flushed), but the copy itself must never finish.
    async fn test_copy_stream_paused(rt: &MockRuntime, data: &[u8]) {
        let source = BufReader::new(Cursor::new(data.to_vec()).chain(PausedRead));
        let (stream_tx, mut stream_rx) = stream_pair();
        let mut pusher = rt.spawn_with_handle(copy_buf(source, stream_tx)).unwrap();

        let mut received = vec![0_u8; data.len()];
        stream_rx.read_exact(&mut received[..]).await.unwrap();
        assert_eq!(&received[..], data);

        // The source never reaches EOF, so the copy can never complete.
        let pusher_status = poll_fn(|cx| Poll::Ready(Pin::new(&mut pusher).poll(cx))).await;
        assert!(pusher_status.is_pending());
    }

    /// Feed `data` followed by a reader that fails: the pushing side must
    /// report that error, while the draining side still receives every byte.
    async fn test_copy_stream_error(rt: &MockRuntime, data: &[u8]) {
        let source = Cursor::new(data.to_vec()).chain(ErrorRW(io::ErrorKind::ResourceBusy));
        let (stream_tx, stream_rx) = stream_pair();
        let mut sink = Cursor::new(Vec::<u8>::new());
        let source = BufReader::new(source);
        let stream_rx = BufReader::new(stream_rx);

        let pusher = rt.spawn_with_handle(copy_buf(source, stream_tx)).unwrap();
        let drainer = rt
            .spawn_with_handle(async move {
                let outcome = copy_buf(stream_rx, &mut sink).await;
                (outcome, sink)
            })
            .unwrap();

        let pushed = pusher.await;
        let (drained, sink) = drainer.await;

        let push_error = pushed.unwrap_err();
        assert_eq!(push_error.kind(), io::ErrorKind::ResourceBusy);
        assert_eq!(drained.unwrap(), data.len() as u64);
        assert_eq!(&sink.into_inner()[..], data);
    }

    /// Run every copy scenario above against `data` on a variety of mock runtimes.
    fn test_copy(data: &[u8]) {
        MockRuntime::test_with_various(async |rt| {
            test_copy_cursor(data).await;
            test_copy_stream(&rt, data).await;
            test_copy_stream_paused(&rt, data).await;
            test_copy_stream_error(&rt, data).await;
        });
    }

    #[test]
    fn copy_nothing() {
        test_copy(&[]);
    }

    #[test]
    fn copy_small() {
        test_copy(b"hEllo world");
    }

    #[test]
    fn copy_huge() {
        let huge: Vec<u8> = (0..=77).cycle().take(1_500_000).collect();
        test_copy(&huge[..]);
    }
}