1 use std::marker::Unpin; 2 use std::pin::Pin; 3 use std::task::{Context, Poll}; 4 use std::{cmp, io}; 5 6 use bytes::{Buf, Bytes}; 7 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 8 9 /// Combine a buffer with an IO, rewinding reads to use the buffer. 10 #[derive(Debug)] 11 pub(crate) struct Rewind<T> { 12 pre: Option<Bytes>, 13 inner: T, 14 } 15 16 impl<T> Rewind<T> { 17 #[cfg(any(all(feature = "http2", feature = "server"), test))] new(io: T) -> Self18 pub(crate) fn new(io: T) -> Self { 19 Rewind { 20 pre: None, 21 inner: io, 22 } 23 } 24 new_buffered(io: T, buf: Bytes) -> Self25 pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { 26 Rewind { 27 pre: Some(buf), 28 inner: io, 29 } 30 } 31 32 #[cfg(any(all(feature = "http1", feature = "http2", feature = "server"), test))] rewind(&mut self, bs: Bytes)33 pub(crate) fn rewind(&mut self, bs: Bytes) { 34 debug_assert!(self.pre.is_none()); 35 self.pre = Some(bs); 36 } 37 into_inner(self) -> (T, Bytes)38 pub(crate) fn into_inner(self) -> (T, Bytes) { 39 (self.inner, self.pre.unwrap_or_else(Bytes::new)) 40 } 41 42 // pub(crate) fn get_mut(&mut self) -> &mut T { 43 // &mut self.inner 44 // } 45 } 46 47 impl<T> AsyncRead for Rewind<T> 48 where 49 T: AsyncRead + Unpin, 50 { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>51 fn poll_read( 52 mut self: Pin<&mut Self>, 53 cx: &mut Context<'_>, 54 buf: &mut ReadBuf<'_>, 55 ) -> Poll<io::Result<()>> { 56 if let Some(mut prefix) = self.pre.take() { 57 // If there are no remaining bytes, let the bytes get dropped. 58 if !prefix.is_empty() { 59 let copy_len = cmp::min(prefix.len(), buf.remaining()); 60 // TODO: There should be a way to do following two lines cleaner... 61 buf.put_slice(&prefix[..copy_len]); 62 prefix.advance(copy_len); 63 // Put back what's left 64 if !prefix.is_empty() { 65 self.pre = Some(prefix); 66 } 67 68 return Poll::Ready(Ok(())); 69 } 70 } 71 Pin::new(&mut self.inner).poll_read(cx, buf) 72 } 73 } 74 75 impl<T> AsyncWrite for Rewind<T> 76 where 77 T: AsyncWrite + Unpin, 78 { poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>79 fn poll_write( 80 mut self: Pin<&mut Self>, 81 cx: &mut Context<'_>, 82 buf: &[u8], 83 ) -> Poll<io::Result<usize>> { 84 Pin::new(&mut self.inner).poll_write(cx, buf) 85 } 86 poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>87 fn poll_write_vectored( 88 mut self: Pin<&mut Self>, 89 cx: &mut Context<'_>, 90 bufs: &[io::IoSlice<'_>], 91 ) -> Poll<io::Result<usize>> { 92 Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) 93 } 94 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>95 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 96 Pin::new(&mut self.inner).poll_flush(cx) 97 } 98 poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>99 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 100 Pin::new(&mut self.inner).poll_shutdown(cx) 101 } 102 is_write_vectored(&self) -> bool103 fn is_write_vectored(&self) -> bool { 104 self.inner.is_write_vectored() 105 } 106 } 107 108 #[cfg(test)] 109 mod tests { 110 // FIXME: re-implement tests with `async/await`, this import should 111 // trigger a warning to remind us 112 use super::Rewind; 113 use bytes::Bytes; 114 use tokio::io::AsyncReadExt; 115 116 #[tokio::test] partial_rewind()117 async fn partial_rewind() { 118 let underlying = [104, 101, 108, 108, 111]; 119 120 let mock = tokio_test::io::Builder::new().read(&underlying).build(); 121 122 let mut stream = Rewind::new(mock); 123 124 // Read off some bytes, ensure we filled o1 125 let mut buf = [0; 2]; 126 stream.read_exact(&mut buf).await.expect("read1"); 127 128 // Rewind the stream so that it is as if we never read in the first place. 129 stream.rewind(Bytes::copy_from_slice(&buf[..])); 130 131 let mut buf = [0; 5]; 132 stream.read_exact(&mut buf).await.expect("read1"); 133 134 // At this point we should have read everything that was in the MockStream 135 assert_eq!(&buf, &underlying); 136 } 137 138 #[tokio::test] full_rewind()139 async fn full_rewind() { 140 let underlying = [104, 101, 108, 108, 111]; 141 142 let mock = tokio_test::io::Builder::new().read(&underlying).build(); 143 144 let mut stream = Rewind::new(mock); 145 146 let mut buf = [0; 5]; 147 stream.read_exact(&mut buf).await.expect("read1"); 148 149 // Rewind the stream so that it is as if we never read in the first place. 150 stream.rewind(Bytes::copy_from_slice(&buf[..])); 151 152 let mut buf = [0; 5]; 153 stream.read_exact(&mut buf).await.expect("read1"); 154 } 155 } 156