1 use pin_project_lite::pin_project;
2 use std::io::{IoSlice, Result};
3 use std::pin::Pin;
4 use std::task::{ready, Context, Poll};
5 
6 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7 
pin_project! {
    /// An adapter that lets you inspect the data that's being read.
    ///
    /// This is useful for things like hashing data as it's read in.
    ///
    /// Build one with [`InspectReader::new`] and recover the wrapped reader
    /// with [`InspectReader::into_inner`]. If the wrapped type also implements
    /// `AsyncWrite`, writes are forwarded to it unchanged and are not shown to
    /// the inspection closure.
    pub struct InspectReader<R, F> {
        // The wrapped reader. Marked `#[pin]` so it is structurally pinned and
        // can be polled through the projection produced by `project()`.
        #[pin]
        reader: R,
        // The inspection callback. Not pinned; it is reached as `&mut F`
        // through the projection.
        f: F,
    }
}
18 
19 impl<R, F> InspectReader<R, F> {
20     /// Create a new `InspectReader`, wrapping `reader` and calling `f` for the
21     /// new data supplied by each read call.
22     ///
23     /// The closure will only be called with an empty slice if the inner reader
24     /// returns without reading data into the buffer. This happens at EOF, or if
25     /// `poll_read` is called with a zero-size buffer.
new(reader: R, f: F) -> InspectReader<R, F> where R: AsyncRead, F: FnMut(&[u8]),26     pub fn new(reader: R, f: F) -> InspectReader<R, F>
27     where
28         R: AsyncRead,
29         F: FnMut(&[u8]),
30     {
31         InspectReader { reader, f }
32     }
33 
34     /// Consumes the `InspectReader`, returning the wrapped reader
into_inner(self) -> R35     pub fn into_inner(self) -> R {
36         self.reader
37     }
38 }
39 
40 impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<Result<()>>41     fn poll_read(
42         self: Pin<&mut Self>,
43         cx: &mut Context<'_>,
44         buf: &mut ReadBuf<'_>,
45     ) -> Poll<Result<()>> {
46         let me = self.project();
47         let filled_length = buf.filled().len();
48         ready!(me.reader.poll_read(cx, buf))?;
49         (me.f)(&buf.filled()[filled_length..]);
50         Poll::Ready(Ok(()))
51     }
52 }
53 
54 impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<std::result::Result<usize, std::io::Error>>55     fn poll_write(
56         self: Pin<&mut Self>,
57         cx: &mut Context<'_>,
58         buf: &[u8],
59     ) -> Poll<std::result::Result<usize, std::io::Error>> {
60         self.project().reader.poll_write(cx, buf)
61     }
62 
poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<std::result::Result<(), std::io::Error>>63     fn poll_flush(
64         self: Pin<&mut Self>,
65         cx: &mut Context<'_>,
66     ) -> Poll<std::result::Result<(), std::io::Error>> {
67         self.project().reader.poll_flush(cx)
68     }
69 
poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<std::result::Result<(), std::io::Error>>70     fn poll_shutdown(
71         self: Pin<&mut Self>,
72         cx: &mut Context<'_>,
73     ) -> Poll<std::result::Result<(), std::io::Error>> {
74         self.project().reader.poll_shutdown(cx)
75     }
76 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize>>77     fn poll_write_vectored(
78         self: Pin<&mut Self>,
79         cx: &mut Context<'_>,
80         bufs: &[IoSlice<'_>],
81     ) -> Poll<Result<usize>> {
82         self.project().reader.poll_write_vectored(cx, bufs)
83     }
84 
is_write_vectored(&self) -> bool85     fn is_write_vectored(&self) -> bool {
86         self.reader.is_write_vectored()
87     }
88 }
89 
pin_project! {
    /// An adapter that lets you inspect the data that's being written.
    ///
    /// This is useful for things like hashing data as it's written out.
    ///
    /// Build one with [`InspectWriter::new`] and recover the wrapped writer
    /// with [`InspectWriter::into_inner`]. If the wrapped type also implements
    /// `AsyncRead`, reads are forwarded to it unchanged and are not shown to
    /// the inspection closure.
    pub struct InspectWriter<W, F> {
        // The wrapped writer. Marked `#[pin]` so it is structurally pinned and
        // can be polled through the projection produced by `project()`.
        #[pin]
        writer: W,
        // The inspection callback. Not pinned; it is reached as `&mut F`
        // through the projection.
        f: F,
    }
}
100 
101 impl<W, F> InspectWriter<W, F> {
102     /// Create a new `InspectWriter`, wrapping `write` and calling `f` for the
103     /// data successfully written by each write call.
104     ///
105     /// The closure `f` will never be called with an empty slice. A vectored
106     /// write can result in multiple calls to `f` - at most one call to `f` per
107     /// buffer supplied to `poll_write_vectored`.
new(writer: W, f: F) -> InspectWriter<W, F> where W: AsyncWrite, F: FnMut(&[u8]),108     pub fn new(writer: W, f: F) -> InspectWriter<W, F>
109     where
110         W: AsyncWrite,
111         F: FnMut(&[u8]),
112     {
113         InspectWriter { writer, f }
114     }
115 
116     /// Consumes the `InspectWriter`, returning the wrapped writer
into_inner(self) -> W117     pub fn into_inner(self) -> W {
118         self.writer
119     }
120 }
121 
122 impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>123     fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
124         let me = self.project();
125         let res = me.writer.poll_write(cx, buf);
126         if let Poll::Ready(Ok(count)) = res {
127             if count != 0 {
128                 (me.f)(&buf[..count]);
129             }
130         }
131         res
132     }
133 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>>134     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
135         let me = self.project();
136         me.writer.poll_flush(cx)
137     }
138 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>>139     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
140         let me = self.project();
141         me.writer.poll_shutdown(cx)
142     }
143 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize>>144     fn poll_write_vectored(
145         self: Pin<&mut Self>,
146         cx: &mut Context<'_>,
147         bufs: &[IoSlice<'_>],
148     ) -> Poll<Result<usize>> {
149         let me = self.project();
150         let res = me.writer.poll_write_vectored(cx, bufs);
151         if let Poll::Ready(Ok(mut count)) = res {
152             for buf in bufs {
153                 if count == 0 {
154                     break;
155                 }
156                 let size = count.min(buf.len());
157                 if size != 0 {
158                     (me.f)(&buf[..size]);
159                     count -= size;
160                 }
161             }
162         }
163         res
164     }
165 
is_write_vectored(&self) -> bool166     fn is_write_vectored(&self) -> bool {
167         self.writer.is_write_vectored()
168     }
169 }
170 
171 impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>172     fn poll_read(
173         self: Pin<&mut Self>,
174         cx: &mut Context<'_>,
175         buf: &mut ReadBuf<'_>,
176     ) -> Poll<std::io::Result<()>> {
177         self.project().writer.poll_read(cx, buf)
178     }
179 }
180