1 use std::marker::Unpin;
2 use std::pin::Pin;
3 use std::task::{Context, Poll};
4 use std::{cmp, io};
5 
6 use bytes::{Buf, Bytes};
7 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8 
/// Combine a buffer with an IO, rewinding reads to use the buffer.
///
/// Reads hand out the buffered bytes (if any) before reading from the wrapped
/// IO; writes, flushes and shutdowns always go straight to the wrapped IO.
#[derive(Debug)]
pub(crate) struct Rewind<T> {
    /// Bytes to return from upcoming reads, ahead of anything from `inner`.
    /// Reset to `None` once they have been fully read out (or if nothing was
    /// ever buffered).
    pre: Option<Bytes>,
    /// The wrapped IO object that reads fall through to and writes go to.
    inner: T,
}
15 
16 impl<T> Rewind<T> {
17     #[cfg(any(all(feature = "http2", feature = "server"), test))]
new(io: T) -> Self18     pub(crate) fn new(io: T) -> Self {
19         Rewind {
20             pre: None,
21             inner: io,
22         }
23     }
24 
new_buffered(io: T, buf: Bytes) -> Self25     pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
26         Rewind {
27             pre: Some(buf),
28             inner: io,
29         }
30     }
31 
32     #[cfg(any(all(feature = "http1", feature = "http2", feature = "server"), test))]
rewind(&mut self, bs: Bytes)33     pub(crate) fn rewind(&mut self, bs: Bytes) {
34         debug_assert!(self.pre.is_none());
35         self.pre = Some(bs);
36     }
37 
into_inner(self) -> (T, Bytes)38     pub(crate) fn into_inner(self) -> (T, Bytes) {
39         (self.inner, self.pre.unwrap_or_else(Bytes::new))
40     }
41 
42     // pub(crate) fn get_mut(&mut self) -> &mut T {
43     //     &mut self.inner
44     // }
45 }
46 
47 impl<T> AsyncRead for Rewind<T>
48 where
49     T: AsyncRead + Unpin,
50 {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>51     fn poll_read(
52         mut self: Pin<&mut Self>,
53         cx: &mut Context<'_>,
54         buf: &mut ReadBuf<'_>,
55     ) -> Poll<io::Result<()>> {
56         if let Some(mut prefix) = self.pre.take() {
57             // If there are no remaining bytes, let the bytes get dropped.
58             if !prefix.is_empty() {
59                 let copy_len = cmp::min(prefix.len(), buf.remaining());
60                 // TODO: There should be a way to do following two lines cleaner...
61                 buf.put_slice(&prefix[..copy_len]);
62                 prefix.advance(copy_len);
63                 // Put back what's left
64                 if !prefix.is_empty() {
65                     self.pre = Some(prefix);
66                 }
67 
68                 return Poll::Ready(Ok(()));
69             }
70         }
71         Pin::new(&mut self.inner).poll_read(cx, buf)
72     }
73 }
74 
75 impl<T> AsyncWrite for Rewind<T>
76 where
77     T: AsyncWrite + Unpin,
78 {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>79     fn poll_write(
80         mut self: Pin<&mut Self>,
81         cx: &mut Context<'_>,
82         buf: &[u8],
83     ) -> Poll<io::Result<usize>> {
84         Pin::new(&mut self.inner).poll_write(cx, buf)
85     }
86 
poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>87     fn poll_write_vectored(
88         mut self: Pin<&mut Self>,
89         cx: &mut Context<'_>,
90         bufs: &[io::IoSlice<'_>],
91     ) -> Poll<io::Result<usize>> {
92         Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
93     }
94 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>95     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
96         Pin::new(&mut self.inner).poll_flush(cx)
97     }
98 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>99     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
100         Pin::new(&mut self.inner).poll_shutdown(cx)
101     }
102 
is_write_vectored(&self) -> bool103     fn is_write_vectored(&self) -> bool {
104         self.inner.is_write_vectored()
105     }
106 }
107 
#[cfg(test)]
mod tests {
    use super::Rewind;
    use bytes::Bytes;
    use tokio::io::AsyncReadExt;

    #[tokio::test]
    async fn partial_rewind() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        let mut stream = Rewind::new(mock);

        // Read off some bytes, ensuring the small buffer is completely filled.
        let mut buf = [0; 2];
        stream.read_exact(&mut buf).await.expect("read1");

        // Rewind the stream so that it is as if we never read in the first place.
        stream.rewind(Bytes::copy_from_slice(&buf[..]));

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read2");

        // The rewound prefix plus the rest of the mock must reproduce the input.
        assert_eq!(&buf, &underlying);
    }

    #[tokio::test]
    async fn full_rewind() {
        let underlying = [104, 101, 108, 108, 111];

        let mock = tokio_test::io::Builder::new().read(&underlying).build();

        let mut stream = Rewind::new(mock);

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read1");

        // Rewind the stream so that it is as if we never read in the first place.
        stream.rewind(Bytes::copy_from_slice(&buf[..]));

        let mut buf = [0; 5];
        stream.read_exact(&mut buf).await.expect("read2");

        // The mock is exhausted, so these bytes can only come from the rewind.
        assert_eq!(&buf, &underlying);
    }

    #[tokio::test]
    async fn into_inner_returns_unread_prefix() {
        // No read actions: every byte must be served from the buffered prefix.
        let mock = tokio_test::io::Builder::new().build();

        let mut stream = Rewind::new_buffered(mock, Bytes::from_static(b"hello"));

        let mut buf = [0; 2];
        stream.read_exact(&mut buf).await.expect("read1");
        assert_eq!(&buf, b"he");

        // Whatever was not read yet is handed back alongside the inner IO.
        let (_io, rest) = stream.into_inner();
        assert_eq!(rest, Bytes::from_static(b"llo"));
    }
}
156