1 use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; 2 3 use std::future::Future; 4 use std::io; 5 use std::pin::Pin; 6 use std::task::{ready, Context, Poll}; 7 8 #[derive(Debug)] 9 pub(super) struct CopyBuffer { 10 read_done: bool, 11 need_flush: bool, 12 pos: usize, 13 cap: usize, 14 amt: u64, 15 buf: Box<[u8]>, 16 } 17 18 impl CopyBuffer { new(buf_size: usize) -> Self19 pub(super) fn new(buf_size: usize) -> Self { 20 Self { 21 read_done: false, 22 need_flush: false, 23 pos: 0, 24 cap: 0, 25 amt: 0, 26 buf: vec![0; buf_size].into_boxed_slice(), 27 } 28 } 29 poll_fill_buf<R>( &mut self, cx: &mut Context<'_>, reader: Pin<&mut R>, ) -> Poll<io::Result<()>> where R: AsyncRead + ?Sized,30 fn poll_fill_buf<R>( 31 &mut self, 32 cx: &mut Context<'_>, 33 reader: Pin<&mut R>, 34 ) -> Poll<io::Result<()>> 35 where 36 R: AsyncRead + ?Sized, 37 { 38 let me = &mut *self; 39 let mut buf = ReadBuf::new(&mut me.buf); 40 buf.set_filled(me.cap); 41 42 let res = reader.poll_read(cx, &mut buf); 43 if let Poll::Ready(Ok(())) = res { 44 let filled_len = buf.filled().len(); 45 me.read_done = me.cap == filled_len; 46 me.cap = filled_len; 47 } 48 res 49 } 50 poll_write_buf<R, W>( &mut self, cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll<io::Result<usize>> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized,51 fn poll_write_buf<R, W>( 52 &mut self, 53 cx: &mut Context<'_>, 54 mut reader: Pin<&mut R>, 55 mut writer: Pin<&mut W>, 56 ) -> Poll<io::Result<usize>> 57 where 58 R: AsyncRead + ?Sized, 59 W: AsyncWrite + ?Sized, 60 { 61 let me = &mut *self; 62 match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { 63 Poll::Pending => { 64 // Top up the buffer towards full if we can read a bit more 65 // data - this should improve the chances of a large write 66 if !me.read_done && me.cap < me.buf.len() { 67 ready!(me.poll_fill_buf(cx, reader.as_mut()))?; 68 } 69 Poll::Pending 70 } 71 res => res, 72 } 73 } 74 poll_copy<R, W>( &mut self, cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll<io::Result<u64>> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized,75 pub(super) fn poll_copy<R, W>( 76 &mut self, 77 cx: &mut Context<'_>, 78 mut reader: Pin<&mut R>, 79 mut writer: Pin<&mut W>, 80 ) -> Poll<io::Result<u64>> 81 where 82 R: AsyncRead + ?Sized, 83 W: AsyncWrite + ?Sized, 84 { 85 ready!(crate::trace::trace_leaf(cx)); 86 #[cfg(any( 87 feature = "fs", 88 feature = "io-std", 89 feature = "net", 90 feature = "process", 91 feature = "rt", 92 feature = "signal", 93 feature = "sync", 94 feature = "time", 95 ))] 96 // Keep track of task budget 97 let coop = ready!(crate::runtime::coop::poll_proceed(cx)); 98 loop { 99 // If there is some space left in our buffer, then we try to read some 100 // data to continue, thus maximizing the chances of a large write. 101 if self.cap < self.buf.len() && !self.read_done { 102 match self.poll_fill_buf(cx, reader.as_mut()) { 103 Poll::Ready(Ok(())) => { 104 #[cfg(any( 105 feature = "fs", 106 feature = "io-std", 107 feature = "net", 108 feature = "process", 109 feature = "rt", 110 feature = "signal", 111 feature = "sync", 112 feature = "time", 113 ))] 114 coop.made_progress(); 115 } 116 Poll::Ready(Err(err)) => { 117 #[cfg(any( 118 feature = "fs", 119 feature = "io-std", 120 feature = "net", 121 feature = "process", 122 feature = "rt", 123 feature = "signal", 124 feature = "sync", 125 feature = "time", 126 ))] 127 coop.made_progress(); 128 return Poll::Ready(Err(err)); 129 } 130 Poll::Pending => { 131 // Ignore pending reads when our buffer is not empty, because 132 // we can try to write data immediately. 133 if self.pos == self.cap { 134 // Try flushing when the reader has no progress to avoid deadlock 135 // when the reader depends on buffered writer. 136 if self.need_flush { 137 ready!(writer.as_mut().poll_flush(cx))?; 138 #[cfg(any( 139 feature = "fs", 140 feature = "io-std", 141 feature = "net", 142 feature = "process", 143 feature = "rt", 144 feature = "signal", 145 feature = "sync", 146 feature = "time", 147 ))] 148 coop.made_progress(); 149 self.need_flush = false; 150 } 151 152 return Poll::Pending; 153 } 154 } 155 } 156 } 157 158 // If our buffer has some data, let's write it out! 159 while self.pos < self.cap { 160 let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; 161 #[cfg(any( 162 feature = "fs", 163 feature = "io-std", 164 feature = "net", 165 feature = "process", 166 feature = "rt", 167 feature = "signal", 168 feature = "sync", 169 feature = "time", 170 ))] 171 coop.made_progress(); 172 if i == 0 { 173 return Poll::Ready(Err(io::Error::new( 174 io::ErrorKind::WriteZero, 175 "write zero byte into writer", 176 ))); 177 } else { 178 self.pos += i; 179 self.amt += i as u64; 180 self.need_flush = true; 181 } 182 } 183 184 // If pos larger than cap, this loop will never stop. 185 // In particular, user's wrong poll_write implementation returning 186 // incorrect written length may lead to thread blocking. 187 debug_assert!( 188 self.pos <= self.cap, 189 "writer returned length larger than input slice" 190 ); 191 192 // All data has been written, the buffer can be considered empty again 193 self.pos = 0; 194 self.cap = 0; 195 196 // If we've written all the data and we've seen EOF, flush out the 197 // data and finish the transfer. 198 if self.read_done { 199 ready!(writer.as_mut().poll_flush(cx))?; 200 #[cfg(any( 201 feature = "fs", 202 feature = "io-std", 203 feature = "net", 204 feature = "process", 205 feature = "rt", 206 feature = "signal", 207 feature = "sync", 208 feature = "time", 209 ))] 210 coop.made_progress(); 211 return Poll::Ready(Ok(self.amt)); 212 } 213 } 214 } 215 } 216 217 /// A future that asynchronously copies the entire contents of a reader into a 218 /// writer. 219 #[derive(Debug)] 220 #[must_use = "futures do nothing unless you `.await` or poll them"] 221 struct Copy<'a, R: ?Sized, W: ?Sized> { 222 reader: &'a mut R, 223 writer: &'a mut W, 224 buf: CopyBuffer, 225 } 226 227 cfg_io_util! { 228 /// Asynchronously copies the entire contents of a reader into a writer. 229 /// 230 /// This function returns a future that will continuously read data from 231 /// `reader` and then write it into `writer` in a streaming fashion until 232 /// `reader` returns EOF or fails. 233 /// 234 /// On success, the total number of bytes that were copied from `reader` to 235 /// `writer` is returned. 236 /// 237 /// This is an asynchronous version of [`std::io::copy`][std]. 238 /// 239 /// A heap-allocated copy buffer with 8 KB is created to take data from the 240 /// reader to the writer, check [`copy_buf`] if you want an alternative for 241 /// [`AsyncBufRead`]. You can use `copy_buf` with [`BufReader`] to change the 242 /// buffer capacity. 243 /// 244 /// [std]: std::io::copy 245 /// [`copy_buf`]: crate::io::copy_buf 246 /// [`AsyncBufRead`]: crate::io::AsyncBufRead 247 /// [`BufReader`]: crate::io::BufReader 248 /// 249 /// # Errors 250 /// 251 /// The returned future will return an error immediately if any call to 252 /// `poll_read` or `poll_write` returns an error. 253 /// 254 /// # Examples 255 /// 256 /// ``` 257 /// use tokio::io; 258 /// 259 /// # async fn dox() -> std::io::Result<()> { 260 /// let mut reader: &[u8] = b"hello"; 261 /// let mut writer: Vec<u8> = vec![]; 262 /// 263 /// io::copy(&mut reader, &mut writer).await?; 264 /// 265 /// assert_eq!(&b"hello"[..], &writer[..]); 266 /// # Ok(()) 267 /// # } 268 /// ``` 269 pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> 270 where 271 R: AsyncRead + Unpin + ?Sized, 272 W: AsyncWrite + Unpin + ?Sized, 273 { 274 Copy { 275 reader, 276 writer, 277 buf: CopyBuffer::new(super::DEFAULT_BUF_SIZE) 278 }.await 279 } 280 } 281 282 impl<R, W> Future for Copy<'_, R, W> 283 where 284 R: AsyncRead + Unpin + ?Sized, 285 W: AsyncWrite + Unpin + ?Sized, 286 { 287 type Output = io::Result<u64>; 288 poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>289 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { 290 let me = &mut *self; 291 292 me.buf 293 .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer)) 294 } 295 } 296