1 use crate::io::util::DEFAULT_BUF_SIZE; 2 use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; 3 4 use pin_project_lite::pin_project; 5 use std::fmt; 6 use std::io::{self, IoSlice, SeekFrom, Write}; 7 use std::pin::Pin; 8 use std::task::{ready, Context, Poll}; 9 10 pin_project! { 11 /// Wraps a writer and buffers its output. 12 /// 13 /// It can be excessively inefficient to work directly with something that 14 /// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and 15 /// writes it to an underlying writer in large, infrequent batches. 16 /// 17 /// `BufWriter` can improve the speed of programs that make *small* and 18 /// *repeated* write calls to the same file or network socket. It does not 19 /// help when writing very large amounts at once, or writing just one or a few 20 /// times. It also provides no advantage when writing to a destination that is 21 /// in memory, like a `Vec<u8>`. 22 /// 23 /// When the `BufWriter` is dropped, the contents of its buffer will be 24 /// discarded. Creating multiple instances of a `BufWriter` on the same 25 /// stream can cause data loss. If you need to write out the contents of its 26 /// buffer, you must manually call flush before the writer is dropped. 27 /// 28 /// [`AsyncWrite`]: AsyncWrite 29 /// [`flush`]: super::AsyncWriteExt::flush 30 /// 31 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] 32 pub struct BufWriter<W> { 33 #[pin] 34 pub(super) inner: W, 35 pub(super) buf: Vec<u8>, 36 pub(super) written: usize, 37 pub(super) seek_state: SeekState, 38 } 39 } 40 41 impl<W: AsyncWrite> BufWriter<W> { 42 /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, 43 /// but may change in the future. new(inner: W) -> Self44 pub fn new(inner: W) -> Self { 45 Self::with_capacity(DEFAULT_BUF_SIZE, inner) 46 } 47 48 /// Creates a new `BufWriter` with the specified buffer capacity. with_capacity(cap: usize, inner: W) -> Self49 pub fn with_capacity(cap: usize, inner: W) -> Self { 50 Self { 51 inner, 52 buf: Vec::with_capacity(cap), 53 written: 0, 54 seek_state: SeekState::Init, 55 } 56 } 57 flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>58 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 59 let mut me = self.project(); 60 61 let len = me.buf.len(); 62 let mut ret = Ok(()); 63 while *me.written < len { 64 match ready!(me.inner.as_mut().poll_write(cx, &me.buf[*me.written..])) { 65 Ok(0) => { 66 ret = Err(io::Error::new( 67 io::ErrorKind::WriteZero, 68 "failed to write the buffered data", 69 )); 70 break; 71 } 72 Ok(n) => *me.written += n, 73 Err(e) => { 74 ret = Err(e); 75 break; 76 } 77 } 78 } 79 if *me.written > 0 { 80 me.buf.drain(..*me.written); 81 } 82 *me.written = 0; 83 Poll::Ready(ret) 84 } 85 86 /// Gets a reference to the underlying writer. get_ref(&self) -> &W87 pub fn get_ref(&self) -> &W { 88 &self.inner 89 } 90 91 /// Gets a mutable reference to the underlying writer. 92 /// 93 /// It is inadvisable to directly write to the underlying writer. get_mut(&mut self) -> &mut W94 pub fn get_mut(&mut self) -> &mut W { 95 &mut self.inner 96 } 97 98 /// Gets a pinned mutable reference to the underlying writer. 99 /// 100 /// It is inadvisable to directly write to the underlying writer. get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W>101 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { 102 self.project().inner 103 } 104 105 /// Consumes this `BufWriter`, returning the underlying writer. 106 /// 107 /// Note that any leftover data in the internal buffer is lost. into_inner(self) -> W108 pub fn into_inner(self) -> W { 109 self.inner 110 } 111 112 /// Returns a reference to the internally buffered data. buffer(&self) -> &[u8]113 pub fn buffer(&self) -> &[u8] { 114 &self.buf 115 } 116 } 117 118 impl<W: AsyncWrite> AsyncWrite for BufWriter<W> { poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>119 fn poll_write( 120 mut self: Pin<&mut Self>, 121 cx: &mut Context<'_>, 122 buf: &[u8], 123 ) -> Poll<io::Result<usize>> { 124 if self.buf.len() + buf.len() > self.buf.capacity() { 125 ready!(self.as_mut().flush_buf(cx))?; 126 } 127 128 let me = self.project(); 129 if buf.len() >= me.buf.capacity() { 130 me.inner.poll_write(cx, buf) 131 } else { 132 Poll::Ready(me.buf.write(buf)) 133 } 134 } 135 poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, mut bufs: &[IoSlice<'_>], ) -> Poll<io::Result<usize>>136 fn poll_write_vectored( 137 mut self: Pin<&mut Self>, 138 cx: &mut Context<'_>, 139 mut bufs: &[IoSlice<'_>], 140 ) -> Poll<io::Result<usize>> { 141 if self.inner.is_write_vectored() { 142 let total_len = bufs 143 .iter() 144 .fold(0usize, |acc, b| acc.saturating_add(b.len())); 145 if total_len > self.buf.capacity() - self.buf.len() { 146 ready!(self.as_mut().flush_buf(cx))?; 147 } 148 let me = self.as_mut().project(); 149 if total_len >= me.buf.capacity() { 150 // It's more efficient to pass the slices directly to the 151 // underlying writer than to buffer them. 152 // The case when the total_len calculation saturates at 153 // usize::MAX is also handled here. 154 me.inner.poll_write_vectored(cx, bufs) 155 } else { 156 bufs.iter().for_each(|b| me.buf.extend_from_slice(b)); 157 Poll::Ready(Ok(total_len)) 158 } 159 } else { 160 // Remove empty buffers at the beginning of bufs. 161 while bufs.first().map(|buf| buf.len()) == Some(0) { 162 bufs = &bufs[1..]; 163 } 164 if bufs.is_empty() { 165 return Poll::Ready(Ok(0)); 166 } 167 // Flush if the first buffer doesn't fit. 168 let first_len = bufs[0].len(); 169 if first_len > self.buf.capacity() - self.buf.len() { 170 ready!(self.as_mut().flush_buf(cx))?; 171 debug_assert!(self.buf.is_empty()); 172 } 173 let me = self.as_mut().project(); 174 if first_len >= me.buf.capacity() { 175 // The slice is at least as large as the buffering capacity, 176 // so it's better to write it directly, bypassing the buffer. 177 debug_assert!(me.buf.is_empty()); 178 return me.inner.poll_write(cx, &bufs[0]); 179 } else { 180 me.buf.extend_from_slice(&bufs[0]); 181 bufs = &bufs[1..]; 182 } 183 let mut total_written = first_len; 184 debug_assert!(total_written != 0); 185 // Append the buffers that fit in the internal buffer. 186 for buf in bufs { 187 if buf.len() > me.buf.capacity() - me.buf.len() { 188 break; 189 } else { 190 me.buf.extend_from_slice(buf); 191 total_written += buf.len(); 192 } 193 } 194 Poll::Ready(Ok(total_written)) 195 } 196 } 197 is_write_vectored(&self) -> bool198 fn is_write_vectored(&self) -> bool { 199 true 200 } 201 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>202 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 203 ready!(self.as_mut().flush_buf(cx))?; 204 self.get_pin_mut().poll_flush(cx) 205 } 206 poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>207 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 208 ready!(self.as_mut().flush_buf(cx))?; 209 self.get_pin_mut().poll_shutdown(cx) 210 } 211 } 212 213 #[derive(Debug, Clone, Copy)] 214 pub(super) enum SeekState { 215 /// `start_seek` has not been called. 216 Init, 217 /// `start_seek` has been called, but `poll_complete` has not yet been called. 218 Start(SeekFrom), 219 /// Waiting for completion of `poll_complete`. 220 Pending, 221 } 222 223 /// Seek to the offset, in bytes, in the underlying writer. 224 /// 225 /// Seeking always writes out the internal buffer before seeking. 226 impl<W: AsyncWrite + AsyncSeek> AsyncSeek for BufWriter<W> { start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()>227 fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { 228 // We need to flush the internal buffer before seeking. 229 // It receives a `Context` and returns a `Poll`, so it cannot be called 230 // inside `start_seek`. 231 *self.project().seek_state = SeekState::Start(pos); 232 Ok(()) 233 } 234 poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>235 fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { 236 let pos = match self.seek_state { 237 SeekState::Init => { 238 return self.project().inner.poll_complete(cx); 239 } 240 SeekState::Start(pos) => Some(pos), 241 SeekState::Pending => None, 242 }; 243 244 // Flush the internal buffer before seeking. 245 ready!(self.as_mut().flush_buf(cx))?; 246 247 let mut me = self.project(); 248 if let Some(pos) = pos { 249 // Ensure previous seeks have finished before starting a new one 250 ready!(me.inner.as_mut().poll_complete(cx))?; 251 if let Err(e) = me.inner.as_mut().start_seek(pos) { 252 *me.seek_state = SeekState::Init; 253 return Poll::Ready(Err(e)); 254 } 255 } 256 match me.inner.poll_complete(cx) { 257 Poll::Ready(res) => { 258 *me.seek_state = SeekState::Init; 259 Poll::Ready(res) 260 } 261 Poll::Pending => { 262 *me.seek_state = SeekState::Pending; 263 Poll::Pending 264 } 265 } 266 } 267 } 268 269 impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>270 fn poll_read( 271 self: Pin<&mut Self>, 272 cx: &mut Context<'_>, 273 buf: &mut ReadBuf<'_>, 274 ) -> Poll<io::Result<()>> { 275 self.get_pin_mut().poll_read(cx, buf) 276 } 277 } 278 279 impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> { poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>280 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { 281 self.get_pin_mut().poll_fill_buf(cx) 282 } 283 consume(self: Pin<&mut Self>, amt: usize)284 fn consume(self: Pin<&mut Self>, amt: usize) { 285 self.get_pin_mut().consume(amt); 286 } 287 } 288 289 impl<W: fmt::Debug> fmt::Debug for BufWriter<W> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result290 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 291 f.debug_struct("BufWriter") 292 .field("writer", &self.inner) 293 .field( 294 "buffer", 295 &format_args!("{}/{}", self.buf.len(), self.buf.capacity()), 296 ) 297 .field("written", &self.written) 298 .finish() 299 } 300 } 301 302 #[cfg(test)] 303 mod tests { 304 use super::*; 305 306 #[test] assert_unpin()307 fn assert_unpin() { 308 crate::is_unpin::<BufWriter<()>>(); 309 } 310 } 311