1 use crate::io::util::DEFAULT_BUF_SIZE;
2 use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
3 
4 use pin_project_lite::pin_project;
5 use std::fmt;
6 use std::io::{self, IoSlice, SeekFrom, Write};
7 use std::pin::Pin;
8 use std::task::{ready, Context, Poll};
9 
pin_project! {
    /// Wraps a writer and buffers its output.
    ///
    /// It can be excessively inefficient to work directly with something that
    /// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and
    /// writes it to an underlying writer in large, infrequent batches.
    ///
    /// `BufWriter` can improve the speed of programs that make *small* and
    /// *repeated* write calls to the same file or network socket. It does not
    /// help when writing very large amounts at once, or writing just one or a few
    /// times. It also provides no advantage when writing to a destination that is
    /// in memory, like a `Vec<u8>`.
    ///
    /// When the `BufWriter` is dropped, the contents of its buffer will be
    /// discarded. Creating multiple instances of a `BufWriter` on the same
    /// stream can cause data loss. If you need to write out the contents of its
    /// buffer, you must manually call [`flush`] before the writer is dropped.
    /// Shutting the writer down also writes out the buffer before the
    /// underlying writer is shut down.
    ///
    /// [`AsyncWrite`]: AsyncWrite
    /// [`flush`]: super::AsyncWriteExt::flush
    ///
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct BufWriter<W> {
        // The wrapped writer. `#[pin]` makes projections yield `Pin<&mut W>`.
        #[pin]
        pub(super) inner: W,
        // Bytes accepted from callers but not yet fully handed to `inner`.
        // Its capacity acts as the buffering threshold: writes are only
        // buffered while they fit within it.
        pub(super) buf: Vec<u8>,
        // Number of bytes at the front of `buf` that `inner` has already
        // accepted during a flush that has not finished yet. This lets a
        // flush resume after `Poll::Pending`. It is reset to 0 once the
        // flush finishes.
        pub(super) written: usize,
        // Progress of a seek requested through `AsyncSeek::start_seek`.
        pub(super) seek_state: SeekState,
    }
}
40 
41 impl<W: AsyncWrite> BufWriter<W> {
42     /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
43     /// but may change in the future.
new(inner: W) -> Self44     pub fn new(inner: W) -> Self {
45         Self::with_capacity(DEFAULT_BUF_SIZE, inner)
46     }
47 
48     /// Creates a new `BufWriter` with the specified buffer capacity.
with_capacity(cap: usize, inner: W) -> Self49     pub fn with_capacity(cap: usize, inner: W) -> Self {
50         Self {
51             inner,
52             buf: Vec::with_capacity(cap),
53             written: 0,
54             seek_state: SeekState::Init,
55         }
56     }
57 
flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>58     fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
59         let mut me = self.project();
60 
61         let len = me.buf.len();
62         let mut ret = Ok(());
63         while *me.written < len {
64             match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) {
65                 Ok(0) => {
66                     ret = Err(io::Error::new(
67                         io::ErrorKind::WriteZero,
68                         "failed to write the buffered data",
69                     ));
70                     break;
71                 }
72                 Ok(n) => *me.written += n,
73                 Err(e) => {
74                     ret = Err(e);
75                     break;
76                 }
77             }
78         }
79         if *me.written > 0 {
80             me.buf.drain(..*me.written);
81         }
82         *me.written = 0;
83         Poll::Ready(ret)
84     }
85 
86     /// Gets a reference to the underlying writer.
get_ref(&self) -> &W87     pub fn get_ref(&self) -> &W {
88         &self.inner
89     }
90 
91     /// Gets a mutable reference to the underlying writer.
92     ///
93     /// It is inadvisable to directly write to the underlying writer.
get_mut(&mut self) -> &mut W94     pub fn get_mut(&mut self) -> &mut W {
95         &mut self.inner
96     }
97 
98     /// Gets a pinned mutable reference to the underlying writer.
99     ///
100     /// It is inadvisable to directly write to the underlying writer.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W>101     pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
102         self.project().inner
103     }
104 
105     /// Consumes this `BufWriter`, returning the underlying writer.
106     ///
107     /// Note that any leftover data in the internal buffer is lost.
into_inner(self) -> W108     pub fn into_inner(self) -> W {
109         self.inner
110     }
111 
112     /// Returns a reference to the internally buffered data.
buffer(&self) -> &[u8]113     pub fn buffer(&self) -> &[u8] {
114         &self.buf
115     }
116 }
117 
impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
    /// Writes `buf` through the internal buffer.
    ///
    /// If `buf` would not fit in the remaining buffer capacity, the buffered
    /// data is flushed to the inner writer first. Any flush error is returned
    /// before `buf` is touched.
    ///
    /// After that, if `buf` is at least as large as the buffer's capacity, it
    /// is handed directly to the inner writer. Otherwise it is appended to the
    /// buffer and reported as fully written.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Flush first if `buf` would not fit, so buffered bytes stay ahead of it.
        if self.buf.len() + buf.len() > self.buf.capacity() {
            ready!(self.as_mut().flush_buf(cx))?;
        }

        let me = self.project();
        if buf.len() >= me.buf.capacity() {
            // Too large to be worth buffering: write straight through.
            me.inner.poll_write(cx, buf)
        } else {
            // Fits in the remaining capacity; `Write for Vec<u8>` appends the
            // whole slice.
            Poll::Ready(me.buf.write(buf))
        }
    }

    /// Writes a sequence of slices through the internal buffer.
    ///
    /// The strategy depends on whether the inner writer supports vectored
    /// writes:
    ///
    /// * If it does, the slices are treated as one unit. The buffer is flushed
    ///   if their total length does not fit. Then the slices are either
    ///   forwarded together with `poll_write_vectored` (when the total length
    ///   is at least the buffer capacity) or all appended to the buffer.
    /// * If it does not, leading empty slices are skipped. The first non-empty
    ///   slice is either written directly (when it is at least the buffer
    ///   capacity) or buffered. Later slices are then buffered for as long as
    ///   each one fits entirely. The returned count may therefore be smaller
    ///   than the total length of `bufs`.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        if self.inner.is_write_vectored() {
            // Saturating sum, so a very long slice list cannot overflow.
            let total_len = bufs
                .iter()
                .fold(0usize, |acc, b| acc.saturating_add(b.len()));
            if total_len > self.buf.capacity() - self.buf.len() {
                ready!(self.as_mut().flush_buf(cx))?;
            }
            let me = self.as_mut().project();
            if total_len >= me.buf.capacity() {
                // It's more efficient to pass the slices directly to the
                // underlying writer than to buffer them.
                // The case when the total_len calculation saturates at
                // usize::MAX is also handled here.
                me.inner.poll_write_vectored(cx, bufs)
            } else {
                // Every slice fits in the remaining capacity: buffer them all.
                bufs.iter().for_each(|b| me.buf.extend_from_slice(b));
                Poll::Ready(Ok(total_len))
            }
        } else {
            // The inner writer has no native vectored support, so slices are
            // handled one at a time.
            // Remove empty buffers at the beginning of bufs.
            while bufs.first().map(|buf| buf.len()) == Some(0) {
                bufs = &bufs[1..];
            }
            if bufs.is_empty() {
                return Poll::Ready(Ok(0));
            }
            // Flush if the first buffer doesn't fit.
            let first_len = bufs[0].len();
            if first_len > self.buf.capacity() - self.buf.len() {
                ready!(self.as_mut().flush_buf(cx))?;
                debug_assert!(self.buf.is_empty());
            }
            let me = self.as_mut().project();
            if first_len >= me.buf.capacity() {
                // The slice is at least as large as the buffering capacity,
                // so it's better to write it directly, bypassing the buffer.
                debug_assert!(me.buf.is_empty());
                return me.inner.poll_write(cx, &bufs[0]);
            } else {
                me.buf.extend_from_slice(&bufs[0]);
                bufs = &bufs[1..];
            }
            let mut total_written = first_len;
            // Leading empty slices were skipped, so the first slice was non-empty.
            debug_assert!(total_written != 0);
            // Append the buffers that fit in the internal buffer.
            // Stop at the first one that does not, rather than flushing again.
            for buf in bufs {
                if buf.len() > me.buf.capacity() - me.buf.len() {
                    break;
                } else {
                    me.buf.extend_from_slice(buf);
                    total_written += buf.len();
                }
            }
            Poll::Ready(Ok(total_written))
        }
    }

    /// Always reports vectored-write support.
    ///
    /// `poll_write_vectored` handles inner writers without native vectored
    /// support by buffering the slices one at a time.
    fn is_write_vectored(&self) -> bool {
        true
    }

    /// Writes out the internal buffer, then flushes the inner writer.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        ready!(self.as_mut().flush_buf(cx))?;
        self.get_pin_mut().poll_flush(cx)
    }

    /// Writes out the internal buffer, then shuts down the inner writer.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        ready!(self.as_mut().flush_buf(cx))?;
        self.get_pin_mut().poll_shutdown(cx)
    }
}
212 
213 #[derive(Debug, Clone, Copy)]
214 pub(super) enum SeekState {
215     /// `start_seek` has not been called.
216     Init,
217     /// `start_seek` has been called, but `poll_complete` has not yet been called.
218     Start(SeekFrom),
219     /// Waiting for completion of `poll_complete`.
220     Pending,
221 }
222 
/// Seek to the offset, in bytes, in the underlying writer.
///
/// Seeking always writes out the internal buffer before seeking.
///
/// The work is split across two methods:
///
/// * `start_seek` only records the target in `seek_state`.
/// * `poll_complete` flushes the buffer, starts the inner seek, and waits
///   for it to finish.
impl<W: AsyncWrite + AsyncSeek> AsyncSeek for BufWriter<W> {
    /// Records `pos` as the pending seek target; no I/O happens here.
    ///
    /// Always returns `Ok(())`. Any previously recorded state is replaced. An
    /// inner seek that is already in flight is awaited by `poll_complete`
    /// before the new one starts.
    fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
        // We need to flush the internal buffer before seeking.
        // It receives a `Context` and returns a `Poll`, so it cannot be called
        // inside `start_seek`.
        *self.project().seek_state = SeekState::Start(pos);
        Ok(())
    }

    /// Drives a seek recorded by `start_seek` to completion.
    ///
    /// * `Init`: nothing was requested through this wrapper, so the call is
    ///   forwarded straight to the inner writer's `poll_complete`.
    /// * `Start(pos)`: flush the buffer, wait for any earlier inner seek, start
    ///   the seek to `pos`, then poll it.
    /// * `Pending`: flush (if anything is buffered), then keep polling the
    ///   inner seek.
    ///
    /// If the flush or the wait for an earlier seek returns `Pending`,
    /// `seek_state` is left unchanged so the next poll retries the same step.
    /// The state returns to `Init` once the inner seek resolves or
    /// `start_seek` on the inner writer fails.
    fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let pos = match self.seek_state {
            SeekState::Init => {
                return self.project().inner.poll_complete(cx);
            }
            SeekState::Start(pos) => Some(pos),
            SeekState::Pending => None,
        };

        // Flush the internal buffer before seeking.
        ready!(self.as_mut().flush_buf(cx))?;

        let mut me = self.project();
        if let Some(pos) = pos {
            // Ensure previous seeks have finished before starting a new one
            ready!(me.inner.as_mut().poll_complete(cx))?;
            if let Err(e) = me.inner.as_mut().start_seek(pos) {
                *me.seek_state = SeekState::Init;
                return Poll::Ready(Err(e));
            }
        }
        match me.inner.poll_complete(cx) {
            Poll::Ready(res) => {
                *me.seek_state = SeekState::Init;
                Poll::Ready(res)
            }
            Poll::Pending => {
                // Remember that the inner seek is already in flight so it is
                // not restarted on the next poll.
                *me.seek_state = SeekState::Pending;
                Poll::Pending
            }
        }
    }
}
268 
269 impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>270     fn poll_read(
271         self: Pin<&mut Self>,
272         cx: &mut Context<'_>,
273         buf: &mut ReadBuf<'_>,
274     ) -> Poll<io::Result<()>> {
275         self.get_pin_mut().poll_read(cx, buf)
276     }
277 }
278 
279 impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>280     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
281         self.get_pin_mut().poll_fill_buf(cx)
282     }
283 
consume(self: Pin<&mut Self>, amt: usize)284     fn consume(self: Pin<&mut Self>, amt: usize) {
285         self.get_pin_mut().consume(amt);
286     }
287 }
288 
289 impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result290     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291         f.debug_struct("BufWriter")
292             .field("writer", &self.inner)
293             .field(
294                 "buffer",
295                 &format_args!("{}/{}", self.buf.len(), self.buf.capacity()),
296             )
297             .field("written", &self.written)
298             .finish()
299     }
300 }
301 
#[cfg(test)]
mod tests {
    use super::BufWriter;

    /// Compile-time check (via the crate's `is_unpin` helper) that
    /// `BufWriter<()>` is `Unpin`.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<BufWriter<()>>()
    }
}
311