1 use pin_project_lite::pin_project; 2 use std::io::{IoSlice, Result}; 3 use std::pin::Pin; 4 use std::task::{ready, Context, Poll}; 5 6 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 7 8 pin_project! { 9 /// An adapter that lets you inspect the data that's being read. 10 /// 11 /// This is useful for things like hashing data as it's read in. 12 pub struct InspectReader<R, F> { 13 #[pin] 14 reader: R, 15 f: F, 16 } 17 } 18 19 impl<R, F> InspectReader<R, F> { 20 /// Create a new `InspectReader`, wrapping `reader` and calling `f` for the 21 /// new data supplied by each read call. 22 /// 23 /// The closure will only be called with an empty slice if the inner reader 24 /// returns without reading data into the buffer. This happens at EOF, or if 25 /// `poll_read` is called with a zero-size buffer. new(reader: R, f: F) -> InspectReader<R, F> where R: AsyncRead, F: FnMut(&[u8]),26 pub fn new(reader: R, f: F) -> InspectReader<R, F> 27 where 28 R: AsyncRead, 29 F: FnMut(&[u8]), 30 { 31 InspectReader { reader, f } 32 } 33 34 /// Consumes the `InspectReader`, returning the wrapped reader into_inner(self) -> R35 pub fn into_inner(self) -> R { 36 self.reader 37 } 38 } 39 40 impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<Result<()>>41 fn poll_read( 42 self: Pin<&mut Self>, 43 cx: &mut Context<'_>, 44 buf: &mut ReadBuf<'_>, 45 ) -> Poll<Result<()>> { 46 let me = self.project(); 47 let filled_length = buf.filled().len(); 48 ready!(me.reader.poll_read(cx, buf))?; 49 (me.f)(&buf.filled()[filled_length..]); 50 Poll::Ready(Ok(())) 51 } 52 } 53 54 impl<R: AsyncWrite, F> AsyncWrite for InspectReader<R, F> { poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<std::result::Result<usize, std::io::Error>>55 fn poll_write( 56 self: Pin<&mut Self>, 57 cx: &mut Context<'_>, 58 buf: &[u8], 59 ) -> Poll<std::result::Result<usize, std::io::Error>> { 60 self.project().reader.poll_write(cx, buf) 61 } 62 poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<std::result::Result<(), std::io::Error>>63 fn poll_flush( 64 self: Pin<&mut Self>, 65 cx: &mut Context<'_>, 66 ) -> Poll<std::result::Result<(), std::io::Error>> { 67 self.project().reader.poll_flush(cx) 68 } 69 poll_shutdown( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<std::result::Result<(), std::io::Error>>70 fn poll_shutdown( 71 self: Pin<&mut Self>, 72 cx: &mut Context<'_>, 73 ) -> Poll<std::result::Result<(), std::io::Error>> { 74 self.project().reader.poll_shutdown(cx) 75 } 76 poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize>>77 fn poll_write_vectored( 78 self: Pin<&mut Self>, 79 cx: &mut Context<'_>, 80 bufs: &[IoSlice<'_>], 81 ) -> Poll<Result<usize>> { 82 self.project().reader.poll_write_vectored(cx, bufs) 83 } 84 is_write_vectored(&self) -> bool85 fn is_write_vectored(&self) -> bool { 86 self.reader.is_write_vectored() 87 } 88 } 89 90 pin_project! { 91 /// An adapter that lets you inspect the data that's being written. 92 /// 93 /// This is useful for things like hashing data as it's written out. 94 pub struct InspectWriter<W, F> { 95 #[pin] 96 writer: W, 97 f: F, 98 } 99 } 100 101 impl<W, F> InspectWriter<W, F> { 102 /// Create a new `InspectWriter`, wrapping `write` and calling `f` for the 103 /// data successfully written by each write call. 104 /// 105 /// The closure `f` will never be called with an empty slice. A vectored 106 /// write can result in multiple calls to `f` - at most one call to `f` per 107 /// buffer supplied to `poll_write_vectored`. new(writer: W, f: F) -> InspectWriter<W, F> where W: AsyncWrite, F: FnMut(&[u8]),108 pub fn new(writer: W, f: F) -> InspectWriter<W, F> 109 where 110 W: AsyncWrite, 111 F: FnMut(&[u8]), 112 { 113 InspectWriter { writer, f } 114 } 115 116 /// Consumes the `InspectWriter`, returning the wrapped writer into_inner(self) -> W117 pub fn into_inner(self) -> W { 118 self.writer 119 } 120 } 121 122 impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> { poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>123 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> { 124 let me = self.project(); 125 let res = me.writer.poll_write(cx, buf); 126 if let Poll::Ready(Ok(count)) = res { 127 if count != 0 { 128 (me.f)(&buf[..count]); 129 } 130 } 131 res 132 } 133 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>>134 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { 135 let me = self.project(); 136 me.writer.poll_flush(cx) 137 } 138 poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>>139 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { 140 let me = self.project(); 141 me.writer.poll_shutdown(cx) 142 } 143 poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<Result<usize>>144 fn poll_write_vectored( 145 self: Pin<&mut Self>, 146 cx: &mut Context<'_>, 147 bufs: &[IoSlice<'_>], 148 ) -> Poll<Result<usize>> { 149 let me = self.project(); 150 let res = me.writer.poll_write_vectored(cx, bufs); 151 if let Poll::Ready(Ok(mut count)) = res { 152 for buf in bufs { 153 if count == 0 { 154 break; 155 } 156 let size = count.min(buf.len()); 157 if size != 0 { 158 (me.f)(&buf[..size]); 159 count -= size; 160 } 161 } 162 } 163 res 164 } 165 is_write_vectored(&self) -> bool166 fn is_write_vectored(&self) -> bool { 167 self.writer.is_write_vectored() 168 } 169 } 170 171 impl<W: AsyncRead, F> AsyncRead for InspectWriter<W, F> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>172 fn poll_read( 173 self: Pin<&mut Self>, 174 cx: &mut Context<'_>, 175 buf: &mut ReadBuf<'_>, 176 ) -> Poll<std::io::Result<()>> { 177 self.project().writer.poll_read(cx, buf) 178 } 179 } 180