1 use crate::io::util::DEFAULT_BUF_SIZE;
2 use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
3 
4 use pin_project_lite::pin_project;
5 use std::io::{self, IoSlice, SeekFrom};
6 use std::pin::Pin;
7 use std::task::{ready, Context, Poll};
8 use std::{cmp, fmt, mem};
9 
pin_project! {
    /// The `BufReader` struct adds buffering to any reader.
    ///
    /// It can be excessively inefficient to work directly with a [`AsyncRead`]
    /// instance. A `BufReader` performs large, infrequent reads on the underlying
    /// [`AsyncRead`] and maintains an in-memory buffer of the results.
    ///
    /// `BufReader` can improve the speed of programs that make *small* and
    /// *repeated* read calls to the same file or network socket. It does not
    /// help when reading very large amounts at once, or reading just one or a few
    /// times. It also provides no advantage when reading from a source that is
    /// already in memory, like a `Vec<u8>`.
    ///
    /// When the `BufReader` is dropped, the contents of its buffer will be
    /// discarded. Creating multiple instances of a `BufReader` on the same
    /// stream can cause data loss.
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct BufReader<R> {
        // The underlying reader; structurally pinned so poll methods can be
        // forwarded to it.
        #[pin]
        pub(super) inner: R,
        // Fixed-size internal buffer. The bytes at `pos..cap` are valid,
        // buffered data that has not yet been handed to the consumer.
        pub(super) buf: Box<[u8]>,
        // Index of the next buffered byte to return. Invariant: `pos <= cap`.
        pub(super) pos: usize,
        // Number of valid bytes currently in `buf`.
        pub(super) cap: usize,
        // Tracks the progress of a (possibly multi-step) seek operation on
        // the underlying reader; used by the `AsyncSeek` implementation.
        pub(super) seek_state: SeekState,
    }
}
36 
37 impl<R: AsyncRead> BufReader<R> {
38     /// Creates a new `BufReader` with a default buffer capacity. The default is currently 8 KB,
39     /// but may change in the future.
new(inner: R) -> Self40     pub fn new(inner: R) -> Self {
41         Self::with_capacity(DEFAULT_BUF_SIZE, inner)
42     }
43 
44     /// Creates a new `BufReader` with the specified buffer capacity.
with_capacity(capacity: usize, inner: R) -> Self45     pub fn with_capacity(capacity: usize, inner: R) -> Self {
46         let buffer = vec![0; capacity];
47         Self {
48             inner,
49             buf: buffer.into_boxed_slice(),
50             pos: 0,
51             cap: 0,
52             seek_state: SeekState::Init,
53         }
54     }
55 
56     /// Gets a reference to the underlying reader.
57     ///
58     /// It is inadvisable to directly read from the underlying reader.
get_ref(&self) -> &R59     pub fn get_ref(&self) -> &R {
60         &self.inner
61     }
62 
63     /// Gets a mutable reference to the underlying reader.
64     ///
65     /// It is inadvisable to directly read from the underlying reader.
get_mut(&mut self) -> &mut R66     pub fn get_mut(&mut self) -> &mut R {
67         &mut self.inner
68     }
69 
70     /// Gets a pinned mutable reference to the underlying reader.
71     ///
72     /// It is inadvisable to directly read from the underlying reader.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R>73     pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
74         self.project().inner
75     }
76 
77     /// Consumes this `BufReader`, returning the underlying reader.
78     ///
79     /// Note that any leftover data in the internal buffer is lost.
into_inner(self) -> R80     pub fn into_inner(self) -> R {
81         self.inner
82     }
83 
84     /// Returns a reference to the internally buffered data.
85     ///
86     /// Unlike `fill_buf`, this will not attempt to fill the buffer if it is empty.
buffer(&self) -> &[u8]87     pub fn buffer(&self) -> &[u8] {
88         &self.buf[self.pos..self.cap]
89     }
90 
91     /// Invalidates all data in the internal buffer.
92     #[inline]
discard_buffer(self: Pin<&mut Self>)93     fn discard_buffer(self: Pin<&mut Self>) {
94         let me = self.project();
95         *me.pos = 0;
96         *me.cap = 0;
97     }
98 }
99 
100 impl<R: AsyncRead> AsyncRead for BufReader<R> {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>101     fn poll_read(
102         mut self: Pin<&mut Self>,
103         cx: &mut Context<'_>,
104         buf: &mut ReadBuf<'_>,
105     ) -> Poll<io::Result<()>> {
106         // If we don't have any buffered data and we're doing a massive read
107         // (larger than our internal buffer), bypass our internal buffer
108         // entirely.
109         if self.pos == self.cap && buf.remaining() >= self.buf.len() {
110             let res = ready!(self.as_mut().get_pin_mut().poll_read(cx, buf));
111             self.discard_buffer();
112             return Poll::Ready(res);
113         }
114         let rem = ready!(self.as_mut().poll_fill_buf(cx))?;
115         let amt = std::cmp::min(rem.len(), buf.remaining());
116         buf.put_slice(&rem[..amt]);
117         self.consume(amt);
118         Poll::Ready(Ok(()))
119     }
120 }
121 
impl<R: AsyncRead> AsyncBufRead for BufReader<R> {
    // Returns the buffered bytes, reading from the underlying reader first if
    // the buffer is currently empty. An empty returned slice signals EOF
    // (the inner read filled zero bytes).
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let me = self.project();

        // If we've reached the end of our internal buffer then we need to fetch
        // some more data from the underlying reader.
        // Branch using `>=` instead of the more correct `==`
        // to tell the compiler that the pos..cap slice is always valid.
        if *me.pos >= *me.cap {
            debug_assert!(*me.pos == *me.cap);
            let mut buf = ReadBuf::new(me.buf);
            ready!(me.inner.poll_read(cx, &mut buf))?;
            // The inner reader reported how many bytes it filled; reset the
            // cursor to the start of the freshly filled region.
            *me.cap = buf.filled().len();
            *me.pos = 0;
        }
        Poll::Ready(Ok(&me.buf[*me.pos..*me.cap]))
    }

    // Marks `amt` buffered bytes as consumed. Clamped to `cap` so an
    // over-large `amt` cannot push `pos` past the valid region.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let me = self.project();
        *me.pos = cmp::min(*me.pos + amt, *me.cap);
    }
}
145 
/// Tracks where a (possibly multi-step) seek operation currently stands, so
/// `poll_complete` can resume correctly after returning `Poll::Pending`.
#[derive(Debug, Clone, Copy)]
pub(super) enum SeekState {
    /// `start_seek` has not been called.
    Init,
    /// `start_seek` has been called, but `poll_complete` has not yet been called.
    Start(SeekFrom),
    /// Waiting for completion of the first `poll_complete` in the `n.checked_sub(remainder).is_none()` branch.
    PendingOverflowed(i64),
    /// Waiting for completion of `poll_complete`.
    Pending,
}
157 
/// Seeks to an offset, in bytes, in the underlying reader.
///
/// The position used for seeking with `SeekFrom::Current(_)` is the
/// position the underlying reader would be at if the `BufReader` had no
/// internal buffer.
///
/// Seeking always discards the internal buffer, even if the seek position
/// would otherwise fall within it. This guarantees that calling
/// `.into_inner()` immediately after a seek yields the underlying reader
/// at the same position.
///
/// See [`AsyncSeek`] for more details.
///
/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)`
/// where `n` minus the internal buffer length overflows an `i64`, two
/// seeks will be performed instead of one. If the second seek returns
/// `Err`, the underlying reader will be left at the same position it would
/// have if you called `seek` with `SeekFrom::Current(0)`.
impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> {
    fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
        // We may need to perform the seek operation in multiple steps, and we
        // should always call both start_seek and poll_complete, as start_seek
        // alone cannot guarantee that the operation will be completed.
        // poll_complete receives a Context and returns a Poll, so it cannot be
        // called inside start_seek; record the request and do the work there.
        *self.project().seek_state = SeekState::Start(pos);
        Ok(())
    }

    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        // Take the current state, leaving `Init` behind; the state is
        // re-stored below whenever we have to return `Pending`.
        let res = match mem::replace(self.as_mut().project().seek_state, SeekState::Init) {
            SeekState::Init => {
                // 1.x AsyncSeek recommends calling poll_complete before start_seek.
                // We don't have to guarantee that the value returned by
                // poll_complete called without start_seek is correct,
                // so we'll return 0.
                return Poll::Ready(Ok(0));
            }
            SeekState::Start(SeekFrom::Current(n)) => {
                // Unconsumed buffered bytes: the inner reader is this far
                // ahead of the position the caller observes.
                let remainder = (self.cap - self.pos) as i64;
                // it should be safe to assume that remainder fits within an i64 as the alternative
                // means we managed to allocate 8 exbibytes and that's absurd.
                // But it's not out of the realm of possibility for some weird underlying reader to
                // support seeking by i64::MIN so we need to handle underflow when subtracting
                // remainder.
                if let Some(offset) = n.checked_sub(remainder) {
                    // Common case: fold the buffered offset into a single seek.
                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(offset))?;
                } else {
                    // seek backwards by our remainder, and then by the offset
                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(-remainder))?;
                    if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() {
                        // First of the two seeks is still in flight; remember
                        // the outstanding offset so we can resume later.
                        *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n);
                        return Poll::Pending;
                    }

                    // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676
                    self.as_mut().discard_buffer();

                    self.as_mut()
                        .get_pin_mut()
                        .start_seek(SeekFrom::Current(n))?;
                }
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::PendingOverflowed(n) => {
                // Resuming after the first seek of the overflow path returned
                // Pending: wait for it, then issue the second seek.
                if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() {
                    *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n);
                    return Poll::Pending;
                }

                // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676
                self.as_mut().discard_buffer();

                self.as_mut()
                    .get_pin_mut()
                    .start_seek(SeekFrom::Current(n))?;
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::Start(pos) => {
                // Seeking with Start/End doesn't care about our buffer length.
                self.as_mut().get_pin_mut().start_seek(pos)?;
                self.as_mut().get_pin_mut().poll_complete(cx)?
            }
            SeekState::Pending => self.as_mut().get_pin_mut().poll_complete(cx)?,
        };

        match res {
            Poll::Ready(res) => {
                // Seek finished: buffered data no longer matches the inner
                // reader's position, so drop it.
                self.discard_buffer();
                Poll::Ready(Ok(res))
            }
            Poll::Pending => {
                *self.as_mut().project().seek_state = SeekState::Pending;
                Poll::Pending
            }
        }
    }
}
260 
261 impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>262     fn poll_write(
263         self: Pin<&mut Self>,
264         cx: &mut Context<'_>,
265         buf: &[u8],
266     ) -> Poll<io::Result<usize>> {
267         self.get_pin_mut().poll_write(cx, buf)
268     }
269 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<io::Result<usize>>270     fn poll_write_vectored(
271         self: Pin<&mut Self>,
272         cx: &mut Context<'_>,
273         bufs: &[IoSlice<'_>],
274     ) -> Poll<io::Result<usize>> {
275         self.get_pin_mut().poll_write_vectored(cx, bufs)
276     }
277 
is_write_vectored(&self) -> bool278     fn is_write_vectored(&self) -> bool {
279         self.get_ref().is_write_vectored()
280     }
281 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>282     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
283         self.get_pin_mut().poll_flush(cx)
284     }
285 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>286     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
287         self.get_pin_mut().poll_shutdown(cx)
288     }
289 }
290 
291 impl<R: fmt::Debug> fmt::Debug for BufReader<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result292     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293         f.debug_struct("BufReader")
294             .field("reader", &self.inner)
295             .field(
296                 "buffer",
297                 &format_args!("{}/{}", self.cap - self.pos, self.buf.len()),
298             )
299             .finish()
300     }
301 }
302 
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check: `BufReader` must remain `Unpin` (for `R: Unpin`)
    // so users can poll it without pinning gymnastics.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufReader<()>>();
    }
}
312