1 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
2 
3 use std::future::Future;
4 use std::io;
5 use std::pin::Pin;
6 use std::task::{ready, Context, Poll};
7 
/// Heap buffer plus bookkeeping for copying bytes from an `AsyncRead` into an
/// `AsyncWrite`.
///
/// The bytes that have been read but not yet written live in `buf[pos..cap]`;
/// `buf[cap..]` is free space for the next read.
#[derive(Debug)]
pub(super) struct CopyBuffer {
    /// Set once a successful read added no bytes, i.e. the reader reached EOF.
    read_done: bool,
    /// Set after bytes have been written; cleared once the writer is flushed
    /// while the reader is pending (so a reader waiting on the writer's
    /// buffered output is not starved).
    need_flush: bool,
    /// Index of the first buffered byte that has not been written yet.
    pos: usize,
    /// Index one past the last byte that has been read into `buf`.
    cap: usize,
    /// Running total of bytes the writer has accepted.
    amt: u64,
    /// Backing storage, allocated once when the buffer is created.
    buf: Box<[u8]>,
}
17 
18 impl CopyBuffer {
new(buf_size: usize) -> Self19     pub(super) fn new(buf_size: usize) -> Self {
20         Self {
21             read_done: false,
22             need_flush: false,
23             pos: 0,
24             cap: 0,
25             amt: 0,
26             buf: vec![0; buf_size].into_boxed_slice(),
27         }
28     }
29 
poll_fill_buf<R>( &mut self, cx: &mut Context<'_>, reader: Pin<&mut R>, ) -> Poll<io::Result<()>> where R: AsyncRead + ?Sized,30     fn poll_fill_buf<R>(
31         &mut self,
32         cx: &mut Context<'_>,
33         reader: Pin<&mut R>,
34     ) -> Poll<io::Result<()>>
35     where
36         R: AsyncRead + ?Sized,
37     {
38         let me = &mut *self;
39         let mut buf = ReadBuf::new(&mut me.buf);
40         buf.set_filled(me.cap);
41 
42         let res = reader.poll_read(cx, &mut buf);
43         if let Poll::Ready(Ok(())) = res {
44             let filled_len = buf.filled().len();
45             me.read_done = me.cap == filled_len;
46             me.cap = filled_len;
47         }
48         res
49     }
50 
poll_write_buf<R, W>( &mut self, cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll<io::Result<usize>> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized,51     fn poll_write_buf<R, W>(
52         &mut self,
53         cx: &mut Context<'_>,
54         mut reader: Pin<&mut R>,
55         mut writer: Pin<&mut W>,
56     ) -> Poll<io::Result<usize>>
57     where
58         R: AsyncRead + ?Sized,
59         W: AsyncWrite + ?Sized,
60     {
61         let me = &mut *self;
62         match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
63             Poll::Pending => {
64                 // Top up the buffer towards full if we can read a bit more
65                 // data - this should improve the chances of a large write
66                 if !me.read_done && me.cap < me.buf.len() {
67                     ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
68                 }
69                 Poll::Pending
70             }
71             res => res,
72         }
73     }
74 
poll_copy<R, W>( &mut self, cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll<io::Result<u64>> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized,75     pub(super) fn poll_copy<R, W>(
76         &mut self,
77         cx: &mut Context<'_>,
78         mut reader: Pin<&mut R>,
79         mut writer: Pin<&mut W>,
80     ) -> Poll<io::Result<u64>>
81     where
82         R: AsyncRead + ?Sized,
83         W: AsyncWrite + ?Sized,
84     {
85         ready!(crate::trace::trace_leaf(cx));
86         #[cfg(any(
87             feature = "fs",
88             feature = "io-std",
89             feature = "net",
90             feature = "process",
91             feature = "rt",
92             feature = "signal",
93             feature = "sync",
94             feature = "time",
95         ))]
96         // Keep track of task budget
97         let coop = ready!(crate::runtime::coop::poll_proceed(cx));
98         loop {
99             // If there is some space left in our buffer, then we try to read some
100             // data to continue, thus maximizing the chances of a large write.
101             if self.cap < self.buf.len() && !self.read_done {
102                 match self.poll_fill_buf(cx, reader.as_mut()) {
103                     Poll::Ready(Ok(())) => {
104                         #[cfg(any(
105                             feature = "fs",
106                             feature = "io-std",
107                             feature = "net",
108                             feature = "process",
109                             feature = "rt",
110                             feature = "signal",
111                             feature = "sync",
112                             feature = "time",
113                         ))]
114                         coop.made_progress();
115                     }
116                     Poll::Ready(Err(err)) => {
117                         #[cfg(any(
118                             feature = "fs",
119                             feature = "io-std",
120                             feature = "net",
121                             feature = "process",
122                             feature = "rt",
123                             feature = "signal",
124                             feature = "sync",
125                             feature = "time",
126                         ))]
127                         coop.made_progress();
128                         return Poll::Ready(Err(err));
129                     }
130                     Poll::Pending => {
131                         // Ignore pending reads when our buffer is not empty, because
132                         // we can try to write data immediately.
133                         if self.pos == self.cap {
134                             // Try flushing when the reader has no progress to avoid deadlock
135                             // when the reader depends on buffered writer.
136                             if self.need_flush {
137                                 ready!(writer.as_mut().poll_flush(cx))?;
138                                 #[cfg(any(
139                                     feature = "fs",
140                                     feature = "io-std",
141                                     feature = "net",
142                                     feature = "process",
143                                     feature = "rt",
144                                     feature = "signal",
145                                     feature = "sync",
146                                     feature = "time",
147                                 ))]
148                                 coop.made_progress();
149                                 self.need_flush = false;
150                             }
151 
152                             return Poll::Pending;
153                         }
154                     }
155                 }
156             }
157 
158             // If our buffer has some data, let's write it out!
159             while self.pos < self.cap {
160                 let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
161                 #[cfg(any(
162                     feature = "fs",
163                     feature = "io-std",
164                     feature = "net",
165                     feature = "process",
166                     feature = "rt",
167                     feature = "signal",
168                     feature = "sync",
169                     feature = "time",
170                 ))]
171                 coop.made_progress();
172                 if i == 0 {
173                     return Poll::Ready(Err(io::Error::new(
174                         io::ErrorKind::WriteZero,
175                         "write zero byte into writer",
176                     )));
177                 } else {
178                     self.pos += i;
179                     self.amt += i as u64;
180                     self.need_flush = true;
181                 }
182             }
183 
184             // If pos larger than cap, this loop will never stop.
185             // In particular, user's wrong poll_write implementation returning
186             // incorrect written length may lead to thread blocking.
187             debug_assert!(
188                 self.pos <= self.cap,
189                 "writer returned length larger than input slice"
190             );
191 
192             // All data has been written, the buffer can be considered empty again
193             self.pos = 0;
194             self.cap = 0;
195 
196             // If we've written all the data and we've seen EOF, flush out the
197             // data and finish the transfer.
198             if self.read_done {
199                 ready!(writer.as_mut().poll_flush(cx))?;
200                 #[cfg(any(
201                     feature = "fs",
202                     feature = "io-std",
203                     feature = "net",
204                     feature = "process",
205                     feature = "rt",
206                     feature = "signal",
207                     feature = "sync",
208                     feature = "time",
209                 ))]
210                 coop.made_progress();
211                 return Poll::Ready(Ok(self.amt));
212             }
213         }
214     }
215 }
216 
/// A future that asynchronously copies the entire contents of a reader into a
/// writer.
///
/// Returned (and awaited) by `copy`. It mutably borrows both the reader and
/// the writer for its lifetime and keeps the transfer state in a `CopyBuffer`.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Copy<'a, R: ?Sized, W: ?Sized> {
    /// Source of the bytes being copied.
    reader: &'a mut R,
    /// Destination for the bytes being copied.
    writer: &'a mut W,
    /// Staging buffer together with the copy progress.
    buf: CopyBuffer,
}
226 
227 cfg_io_util! {
228     /// Asynchronously copies the entire contents of a reader into a writer.
229     ///
230     /// This function returns a future that will continuously read data from
231     /// `reader` and then write it into `writer` in a streaming fashion until
232     /// `reader` returns EOF or fails.
233     ///
234     /// On success, the total number of bytes that were copied from `reader` to
235     /// `writer` is returned.
236     ///
237     /// This is an asynchronous version of [`std::io::copy`][std].
238     ///
239     /// A heap-allocated copy buffer with 8 KB is created to take data from the
240     /// reader to the writer, check [`copy_buf`] if you want an alternative for
241     /// [`AsyncBufRead`]. You can use `copy_buf` with [`BufReader`] to change the
242     /// buffer capacity.
243     ///
244     /// [std]: std::io::copy
245     /// [`copy_buf`]: crate::io::copy_buf
246     /// [`AsyncBufRead`]: crate::io::AsyncBufRead
247     /// [`BufReader`]: crate::io::BufReader
248     ///
249     /// # Errors
250     ///
251     /// The returned future will return an error immediately if any call to
252     /// `poll_read` or `poll_write` returns an error.
253     ///
254     /// # Examples
255     ///
256     /// ```
257     /// use tokio::io;
258     ///
259     /// # async fn dox() -> std::io::Result<()> {
260     /// let mut reader: &[u8] = b"hello";
261     /// let mut writer: Vec<u8> = vec![];
262     ///
263     /// io::copy(&mut reader, &mut writer).await?;
264     ///
265     /// assert_eq!(&b"hello"[..], &writer[..]);
266     /// # Ok(())
267     /// # }
268     /// ```
269     pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
270     where
271         R: AsyncRead + Unpin + ?Sized,
272         W: AsyncWrite + Unpin + ?Sized,
273     {
274         Copy {
275             reader,
276             writer,
277             buf: CopyBuffer::new(super::DEFAULT_BUF_SIZE)
278         }.await
279     }
280 }
281 
282 impl<R, W> Future for Copy<'_, R, W>
283 where
284     R: AsyncRead + Unpin + ?Sized,
285     W: AsyncWrite + Unpin + ?Sized,
286 {
287     type Output = io::Result<u64>;
288 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>289     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
290         let me = &mut *self;
291 
292         me.buf
293             .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
294     }
295 }
296