1
//! Functionality to copy from an `AsyncBufRead` to an `AsyncWrite`.
2

            
3
use std::{
4
    future::Future,
5
    pin::Pin,
6
    task::{Context, Poll, ready},
7
};
8

            
9
use crate::{
10
    arc_io_result::{ArcIoResult, ArcIoResultExt},
11
    fuse_buf_reader::FuseBufReader,
12
};
13
use futures::{AsyncBufRead, AsyncWrite};
14
use pin_project::pin_project;
15

            
16
/// Return a future to copy all bytes interactively from `reader` to `writer`.
17
///
18
/// Unlike [`futures::io::copy`], this future makes sure that
19
/// if `reader` pauses (returns `Pending`),
20
/// all as-yet-received bytes are still flushed to `writer`.
21
///
22
/// The future continues copying data until either an error occurs
23
/// (in which case it yields an error),
24
/// or the reader returns an EOF
25
/// (in which case it flushes any pending data,
26
/// and returns the number of bytes copied).
27
///
28
/// # Limitations
29
///
30
/// See the crate-level documentation for
31
/// [discussion of this function's limitations](crate#Limitations).
32
72
pub fn copy_buf<R, W>(reader: R, writer: W) -> CopyBuf<R, W>
33
72
where
34
72
    R: AsyncBufRead,
35
72
    W: AsyncWrite,
36
{
37
72
    CopyBuf {
38
72
        reader: FuseBufReader::new(reader),
39
72
        writer,
40
72
        copied: 0,
41
72
    }
42
72
}
43

            
44
/// A future returned by [`copy_buf`].
45
#[derive(Debug)]
46
#[pin_project]
47
#[must_use = "futures do nothing unless you `.await` or poll them"]
48
pub struct CopyBuf<R, W> {
49
    /// The reader that we're taking data from.
50
    ///
51
    /// This is `FuseBufReader` to make our logic simpler.
52
    #[pin]
53
    reader: FuseBufReader<R>,
54

            
55
    /// The writer that we're pushing
56
    #[pin]
57
    writer: W,
58

            
59
    /// The number of bytes written to the writer so far.
60
    copied: u64,
61
}
62

            
63
impl<R, W> CopyBuf<R, W> {
64
    /// Consume this CopyBuf future, and return the underlying reader and writer.
65
    pub fn into_inner(self) -> (R, W) {
66
        (self.reader.into_inner(), self.writer)
67
    }
68
}
69

            
70
impl<R, W> Future for CopyBuf<R, W>
71
where
72
    R: AsyncBufRead,
73
    W: AsyncWrite,
74
{
75
    type Output = std::io::Result<u64>;
76

            
77
28640
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78
28640
        let this = self.project();
79
28640
        let () = ready!(poll_copy_r_to_w(
80
28640
            cx,
81
28640
            this.reader,
82
28640
            this.writer,
83
28640
            this.copied,
84
28640
            false
85
28640
        ))
86
60
        .io_result()?;
87
48
        Poll::Ready(Ok(*this.copied))
88
28640
    }
89
}
90

            
91
/// Core implementation function:
92
/// Try to make progress copying bytes from `reader` to `writer`,
93
/// and add the number of bytes written to `*total_copied`.
94
///
95
/// Returns `Ready` when an error has occurred,
96
/// or when the reader has reached EOF and the writer has been flushed.
97
/// Otherwise, returns `Pending`, and registers itself with `cx`.
98
///
99
/// (This is a separate function so we can use it to implement CopyBuf and CopyBufBidirectional.)
100
48368
pub(crate) fn poll_copy_r_to_w<R, W>(
101
48368
    cx: &mut Context<'_>,
102
48368
    mut reader: Pin<&mut FuseBufReader<R>>,
103
48368
    mut writer: Pin<&mut W>,
104
48368
    total_copied: &mut u64,
105
48368
    flush_on_err: bool,
106
48368
) -> Poll<ArcIoResult<()>>
107
48368
where
108
48368
    R: AsyncBufRead,
109
48368
    W: AsyncWrite,
110
{
111
    // TODO: Instead of using poll_fill_buf() unconditionally,
112
    // it might be a neat idea to use the buffer by reference and just keep writing
113
    // if the buffer is already "full enough".  The futures::io AsyncBufRead API
114
    // doesn't really make that possible, though.  If specialization is ever stabilized,
115
    // we could have a special implementation for BufReader, I guess.
116

            
117
    // TODO: We assume that 'flush' is pretty fast when it has nothing to do.
118
    // If that's wrong, we may need to remember whether we've written data but not flushed it.
119

            
120
    loop {
121
287660
        match reader.as_mut().poll_fill_buf(cx) {
122
            Poll::Pending => {
123
                // If there's nothing to read now, we may need to make sure that the writer
124
                // is flushed.
125
21668
                let () = ready!(writer.as_mut().poll_flush(cx))?;
126
21668
                return Poll::Pending;
127
            }
128
12
            Poll::Ready(Err(e)) => {
129
                //  On error, flush, and propagate the error.
130
12
                if flush_on_err {
131
                    let _ignore_flush_error = ready!(writer.as_mut().poll_flush(cx));
132
12
                }
133
12
                return Poll::Ready(Err(e));
134
            }
135
265980
            Poll::Ready(Ok(&[])) => {
136
                // On EOF, we have already written all the data; make sure we flush it,
137
                // and then return the amount that we copied.
138
152
                let () = ready!(writer.as_mut().poll_flush(cx))?;
139
152
                return Poll::Ready(Ok(()));
140
            }
141
265828
            Poll::Ready(Ok(data)) => {
142
                // If there is pending data, we copy as much as we can.
143
                // We return "pending" if we can't write any.
144
265828
                let n_written: usize = ready!(writer.as_mut().poll_write(cx, data))?;
145
                // Remove the data from the reader.
146
239292
                reader.as_mut().consume(n_written);
147
239292
                *total_copied += n_written as u64;
148
            }
149
        }
150
    }
151
48368
}
152

            
153
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use super::*;
    use crate::test::{ErrorRW, PausedRead};

    use futures::{
        AsyncReadExt as _,
        future::poll_fn,
        io::{BufReader, Cursor},
    };
    use std::io;
    use tor_rtcompat::SpawnExt as _;
    use tor_rtmock::{MockRuntime, io::stream_pair};

    /// Copy `data` between two in-memory cursors, with no spawned tasks involved.
    async fn test_copy_cursor(data: &[u8]) {
        let mut sink: Vec<u8> = Vec::new();
        let mut source = BufReader::new(Cursor::new(data));
        let mut dest = Cursor::new(&mut sink);

        let copied = copy_buf(&mut source, &mut dest).await.unwrap();
        assert_eq!(copied, data.len() as u64);
        assert_eq!(&sink[..], data);
    }

    /// Chain two spawned copies (source -> stream pair -> sink) and check that
    /// both report the full length and that the sink receives exactly `data`.
    async fn test_copy_stream(rt: &MockRuntime, data: &[u8]) {
        let source = BufReader::new(Cursor::new(data.to_vec()));
        let (pipe_w, pipe_r) = stream_pair();
        let pipe_r = BufReader::new(pipe_r);
        let mut sink = Cursor::new(Vec::<u8>::new());

        let upstream = rt.spawn_with_handle(copy_buf(source, pipe_w)).unwrap();
        let downstream = rt
            .spawn_with_handle(async move {
                let outcome = copy_buf(pipe_r, &mut sink).await;
                (outcome, sink)
            })
            .unwrap();

        let upstream_result = upstream.await;
        let (downstream_result, sink) = downstream.await;

        assert_eq!(upstream_result.unwrap(), data.len() as u64);
        assert_eq!(downstream_result.unwrap(), data.len() as u64);
        assert_eq!(&sink.into_inner()[..], data);
    }

    /// Check that data is delivered even when the source pauses forever
    /// without reaching EOF, and that the copy therefore never completes.
    async fn test_copy_stream_paused(rt: &MockRuntime, data: &[u8]) {
        let source = BufReader::new(Cursor::new(data.to_vec()).chain(PausedRead));
        let (pipe_w, mut pipe_r) = stream_pair();
        let mut copier = rt.spawn_with_handle(copy_buf(source, pipe_w)).unwrap();

        let mut received = vec![0_u8; data.len()];
        pipe_r.read_exact(&mut received[..]).await.unwrap();
        assert_eq!(&received[..], data);

        // The source never reaches EOF, so the copy can never finish.
        let copier_status = poll_fn(|cx| Poll::Ready(Pin::new(&mut copier).poll(cx))).await;
        assert!(copier_status.is_pending());
    }

    /// Check that a read error in the source is reported by the first copy,
    /// while every byte preceding the error still reaches the sink.
    async fn test_copy_stream_error(rt: &MockRuntime, data: &[u8]) {
        let source = BufReader::new(
            Cursor::new(data.to_vec()).chain(ErrorRW(io::ErrorKind::ResourceBusy)),
        );
        let (pipe_w, pipe_r) = stream_pair();
        let pipe_r = BufReader::new(pipe_r);
        let mut sink = Cursor::new(Vec::<u8>::new());

        let upstream = rt.spawn_with_handle(copy_buf(source, pipe_w)).unwrap();
        let downstream = rt
            .spawn_with_handle(async move {
                let outcome = copy_buf(pipe_r, &mut sink).await;
                (outcome, sink)
            })
            .unwrap();

        let upstream_result = upstream.await;
        let (downstream_result, sink) = downstream.await;

        assert_eq!(
            upstream_result.unwrap_err().kind(),
            io::ErrorKind::ResourceBusy
        );
        assert_eq!(downstream_result.unwrap(), data.len() as u64);
        assert_eq!(&sink.into_inner()[..], data);
    }

    /// Run every copy scenario on `data` under each mock runtime configuration.
    fn test_copy(data: &[u8]) {
        MockRuntime::test_with_various(async |rt| {
            test_copy_cursor(data).await;
            test_copy_stream(&rt, data).await;
            test_copy_stream_paused(&rt, data).await;
            test_copy_stream_error(&rt, data).await;
        });
    }

    #[test]
    fn copy_nothing() {
        test_copy(b"");
    }

    #[test]
    fn copy_small() {
        test_copy(b"hEllo world");
    }

    #[test]
    fn copy_huge() {
        // The byte pattern 0, 1, ..., 77, 0, 1, ... repeated out to 1.5 MB.
        let huge: Vec<u8> = (0..1_500_000_usize).map(|i| (i % 78) as u8).collect();
        test_copy(&huge);
    }
}