use crate::io::util::DEFAULT_BUF_SIZE;
use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::io::{self, IoSlice, SeekFrom};
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::{cmp, fmt, mem};

pin_project! {
    /// The `BufReader` struct adds buffering to any reader.
    ///
    /// Working directly with an [`AsyncRead`] instance can be excessively
    /// inefficient. A `BufReader` performs large, infrequent reads on the
    /// underlying [`AsyncRead`] and serves subsequent short reads from an
    /// in-memory buffer of the results.
    ///
    /// `BufReader` can improve the speed of programs that make *small* and
    /// *repeated* read calls to the same file or network socket. It does not
    /// help when reading very large amounts at once, or reading just one or a
    /// few times. It also provides no advantage when reading from a source
    /// that is already in memory, like a `Vec<u8>`.
    ///
    /// When the `BufReader` is dropped, the contents of its buffer will be
    /// discarded. Creating multiple instances of a `BufReader` on the same
    /// stream can cause data loss.
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct BufReader<R> {
        #[pin]
        pub(super) inner: R,
        pub(super) buf: Box<[u8]>,
        pub(super) pos: usize,
        pub(super) cap: usize,
        pub(super) seek_state: SeekState,
    }
}

impl<R: AsyncRead> BufReader<R> {
    /// Creates a new `BufReader` with a default buffer capacity. The default
    /// is currently 8 KB, but may change in the future.
    pub fn new(inner: R) -> Self {
        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
    }

    /// Creates a new `BufReader` with the specified buffer capacity.
    pub fn with_capacity(capacity: usize, inner: R) -> Self {
        Self {
            inner,
            // Zero-initialized, fixed-size heap buffer; `pos..cap` tracks the
            // valid region, so nothing is buffered yet.
            buf: vec![0; capacity].into_boxed_slice(),
            pos: 0,
            cap: 0,
            seek_state: SeekState::Init,
        }
    }

    /// Gets a reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_ref(&self) -> &R {
        &self.inner
    }

    /// Gets a mutable reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.inner
    }

    /// Gets a pinned mutable reference to the underlying reader.
    ///
    /// It is inadvisable to directly read from the underlying reader.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
        self.project().inner
    }

    /// Consumes this `BufReader`, returning the underlying reader.
    ///
    /// Note that any leftover data in the internal buffer is lost.
    pub fn into_inner(self) -> R {
        self.inner
    }

    /// Returns a reference to the internally buffered data.
    ///
    /// Unlike `fill_buf`, this will not attempt to fill the buffer if it is
    /// empty.
    pub fn buffer(&self) -> &[u8] {
        &self.buf[self.pos..self.cap]
    }

    /// Invalidates all data in the internal buffer.
92 #[inline] discard_buffer(self: Pin<&mut Self>)93 fn discard_buffer(self: Pin<&mut Self>) { 94 let me = self.project(); 95 *me.pos = 0; 96 *me.cap = 0; 97 } 98 } 99 100 impl<R: AsyncRead> AsyncRead for BufReader<R> { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>101 fn poll_read( 102 mut self: Pin<&mut Self>, 103 cx: &mut Context<'_>, 104 buf: &mut ReadBuf<'_>, 105 ) -> Poll<io::Result<()>> { 106 // If we don't have any buffered data and we're doing a massive read 107 // (larger than our internal buffer), bypass our internal buffer 108 // entirely. 109 if self.pos == self.cap && buf.remaining() >= self.buf.len() { 110 let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf)); 111 self.discard_buffer(); 112 return Poll::Ready(res); 113 } 114 let rem = ready!(self.as_mut().poll_fill_buf(cx))?; 115 let amt = std::cmp::min(rem.len(), buf.remaining()); 116 buf.put_slice(&rem[..amt]); 117 self.consume(amt); 118 Poll::Ready(Ok(())) 119 } 120 } 121 122 impl<R: AsyncRead> AsyncBufRead for BufReader<R> { poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>123 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { 124 let me = self.project(); 125 126 // If we've reached the end of our internal buffer then we need to fetch 127 // some more data from the underlying reader. 128 // Branch using `>=` instead of the more correct `==` 129 // to tell the compiler that the pos..cap slice is always valid. 
130 if *me.pos >= *me.cap { 131 debug_assert!(*me.pos == *me.cap); 132 let mut buf = ReadBuf::new(me.buf); 133 ready!(me.inner.poll_read(cx, &mut buf))?; 134 *me.cap = buf.filled().len(); 135 *me.pos = 0; 136 } 137 Poll::Ready(Ok(&me.buf[*me.pos..*me.cap])) 138 } 139 consume(self: Pin<&mut Self>, amt: usize)140 fn consume(self: Pin<&mut Self>, amt: usize) { 141 let me = self.project(); 142 *me.pos = cmp::min(*me.pos + amt, *me.cap); 143 } 144 } 145 146 #[derive(Debug, Clone, Copy)] 147 pub(super) enum SeekState { 148 /// `start_seek` has not been called. 149 Init, 150 /// `start_seek` has been called, but `poll_complete` has not yet been called. 151 Start(SeekFrom), 152 /// Waiting for completion of the first `poll_complete` in the `n.checked_sub(remainder).is_none()` branch. 153 PendingOverflowed(i64), 154 /// Waiting for completion of `poll_complete`. 155 Pending, 156 } 157 158 /// Seeks to an offset, in bytes, in the underlying reader. 159 /// 160 /// The position used for seeking with `SeekFrom::Current(_)` is the 161 /// position the underlying reader would be at if the `BufReader` had no 162 /// internal buffer. 163 /// 164 /// Seeking always discards the internal buffer, even if the seek position 165 /// would otherwise fall within it. This guarantees that calling 166 /// `.into_inner()` immediately after a seek yields the underlying reader 167 /// at the same position. 168 /// 169 /// See [`AsyncSeek`] for more details. 170 /// 171 /// Note: In the edge case where you're seeking with `SeekFrom::Current(n)` 172 /// where `n` minus the internal buffer length overflows an `i64`, two 173 /// seeks will be performed instead of one. If the second seek returns 174 /// `Err`, the underlying reader will be left at the same position it would 175 /// have if you called `seek` with `SeekFrom::Current(0)`. 
impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> {
    fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
        // We need to call the seek operation multiple times.
        // And we should always call both start_seek and poll_complete,
        // as start_seek alone cannot guarantee that the operation will be completed.
        // poll_complete receives a Context and returns a Poll, so it cannot be called
        // inside start_seek.
        *self.project().seek_state = SeekState::Start(pos);
        Ok(())
    }

    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        // Take the current state, leaving `Init` behind; on `Poll::Pending`
        // the appropriate pending state is written back before returning.
        let res = match mem::replace(self.as_mut().project().seek_state, SeekState::Init) {
            SeekState::Init => {
                // 1.x AsyncSeek recommends calling poll_complete before start_seek.
                // We don't have to guarantee that the value returned by
                // poll_complete called without start_seek is correct,
                // so we'll return 0.
                return Poll::Ready(Ok(0));
            }
            SeekState::Start(SeekFrom::Current(n)) => {
                // The caller's notion of "current position" excludes the
                // bytes we have buffered but not yet handed out, so subtract
                // the unread remainder from the requested offset.
                let remainder = (self.cap - self.pos) as i64;
                // it should be safe to assume that remainder fits within an i64 as the alternative
                // means we managed to allocate 8 exbibytes and that's absurd.
                // But it's not out of the realm of possibility for some weird underlying reader to
                // support seeking by i64::MIN so we need to handle underflow when subtracting
                // remainder.
                if let Some(offset) = n.checked_sub(remainder) {
                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(offset))?;
                } else {
                    // seek backwards by our remainder, and then by the offset
                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(-remainder))?;
                    if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() {
                        *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n);
                        return Poll::Pending;
                    }

                    // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676
                    self.as_mut().discard_buffer();

                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(n))?;
                }
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::PendingOverflowed(n) => {
                // Resume the second seek of the two-seek overflow path begun
                // in the `SeekFrom::Current` branch above.
                if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() {
                    *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n);
                    return Poll::Pending;
                }

                // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676
                self.as_mut().discard_buffer();

                self.as_mut()
                    .get_pin_mut()
                    .start_seek(SeekFrom::Current(n))?;
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::Start(pos) => {
                // Seeking with Start/End doesn't care about our buffer length.
                self.as_mut().get_pin_mut().start_seek(pos)?;
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            // A previous poll returned Pending; just drive the inner seek.
            SeekState::Pending => self.as_mut().get_pin_mut().poll_complete(cx)?,
        };

        match res {
            Poll::Ready(res) => {
                // Any completed seek invalidates the buffered data (see the
                // impl-level documentation above).
                self.discard_buffer();
                Poll::Ready(Ok(res))
            }
            Poll::Pending => {
                *self.as_mut().project().seek_state = SeekState::Pending;
                Poll::Pending
            }
        }
    }
}

impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
    // Writes are not buffered by `BufReader`; every method is a straight
    // pass-through to the underlying stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.get_pin_mut().poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.get_pin_mut().poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.get_ref().is_write_vectored()
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_shutdown(cx)
    }
}

impl<R: fmt::Debug> fmt::Debug for BufReader<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BufReader")
            .field("reader", &self.inner)
            .field(
                "buffer",
                // Rendered as "unread bytes / total capacity".
                &format_args!("{}/{}", self.cap - self.pos, self.buf.len()),
            )
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufReader<()>>();
    }
}