1 use bytes::Buf; 2 use futures_core::stream::Stream; 3 use futures_sink::Sink; 4 use std::io; 5 use std::pin::Pin; 6 use std::task::{Context, Poll}; 7 use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; 8 9 /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`]. 10 /// 11 /// This type performs the inverse operation of [`ReaderStream`]. 12 /// 13 /// This type also implements the [`AsyncBufRead`] trait, so you can use it 14 /// to read a `Stream` of byte chunks line-by-line. See the examples below. 15 /// 16 /// # Example 17 /// 18 /// ``` 19 /// use bytes::Bytes; 20 /// use tokio::io::{AsyncReadExt, Result}; 21 /// use tokio_util::io::StreamReader; 22 /// # #[tokio::main(flavor = "current_thread")] 23 /// # async fn main() -> std::io::Result<()> { 24 /// 25 /// // Create a stream from an iterator. 26 /// let stream = tokio_stream::iter(vec![ 27 /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), 28 /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), 29 /// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])), 30 /// ]); 31 /// 32 /// // Convert it to an AsyncRead. 33 /// let mut read = StreamReader::new(stream); 34 /// 35 /// // Read five bytes from the stream. 36 /// let mut buf = [0; 5]; 37 /// read.read_exact(&mut buf).await?; 38 /// assert_eq!(buf, [0, 1, 2, 3, 4]); 39 /// 40 /// // Read the rest of the current chunk. 41 /// assert_eq!(read.read(&mut buf).await?, 3); 42 /// assert_eq!(&buf[..3], [5, 6, 7]); 43 /// 44 /// // Read the next chunk. 45 /// assert_eq!(read.read(&mut buf).await?, 4); 46 /// assert_eq!(&buf[..4], [8, 9, 10, 11]); 47 /// 48 /// // We have now reached the end. 49 /// assert_eq!(read.read(&mut buf).await?, 0); 50 /// 51 /// # Ok(()) 52 /// # } 53 /// ``` 54 /// 55 /// If the stream produces errors which are not [`std::io::Error`], 56 /// the errors can be converted using [`StreamExt`] to map each 57 /// element. 58 /// 59 /// ``` 60 /// use bytes::Bytes; 61 /// use tokio::io::AsyncReadExt; 62 /// use tokio_util::io::StreamReader; 63 /// use tokio_stream::StreamExt; 64 /// # #[tokio::main(flavor = "current_thread")] 65 /// # async fn main() -> std::io::Result<()> { 66 /// 67 /// // Create a stream from an iterator, including an error. 68 /// let stream = tokio_stream::iter(vec![ 69 /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), 70 /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), 71 /// Result::Err("Something bad happened!") 72 /// ]); 73 /// 74 /// // Use StreamExt to map the stream and error to a std::io::Error 75 /// let stream = stream.map(|result| result.map_err(|err| { 76 /// std::io::Error::new(std::io::ErrorKind::Other, err) 77 /// })); 78 /// 79 /// // Convert it to an AsyncRead. 80 /// let mut read = StreamReader::new(stream); 81 /// 82 /// // Read five bytes from the stream. 83 /// let mut buf = [0; 5]; 84 /// read.read_exact(&mut buf).await?; 85 /// assert_eq!(buf, [0, 1, 2, 3, 4]); 86 /// 87 /// // Read the rest of the current chunk. 88 /// assert_eq!(read.read(&mut buf).await?, 3); 89 /// assert_eq!(&buf[..3], [5, 6, 7]); 90 /// 91 /// // Reading the next chunk will produce an error 92 /// let error = read.read(&mut buf).await.unwrap_err(); 93 /// assert_eq!(error.kind(), std::io::ErrorKind::Other); 94 /// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!"); 95 /// 96 /// // We have now reached the end. 97 /// assert_eq!(read.read(&mut buf).await?, 0); 98 /// 99 /// # Ok(()) 100 /// # } 101 /// ``` 102 /// 103 /// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks 104 /// line-by-line. Note that you will usually also need to convert the error 105 /// type when doing this. See the second example for an explanation of how 106 /// to do this. 107 /// 108 /// ``` 109 /// use tokio::io::{Result, AsyncBufReadExt}; 110 /// use tokio_util::io::StreamReader; 111 /// # #[tokio::main(flavor = "current_thread")] 112 /// # async fn main() -> std::io::Result<()> { 113 /// 114 /// // Create a stream of byte chunks. 115 /// let stream = tokio_stream::iter(vec![ 116 /// Result::Ok(b"The first line.\n".as_slice()), 117 /// Result::Ok(b"The second line.".as_slice()), 118 /// Result::Ok(b"\nThe third".as_slice()), 119 /// Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()), 120 /// ]); 121 /// 122 /// // Convert it to an AsyncRead. 123 /// let mut read = StreamReader::new(stream); 124 /// 125 /// // Loop through the lines from the `StreamReader`. 126 /// let mut line = String::new(); 127 /// let mut lines = Vec::new(); 128 /// loop { 129 /// line.clear(); 130 /// let len = read.read_line(&mut line).await?; 131 /// if len == 0 { break; } 132 /// lines.push(line.clone()); 133 /// } 134 /// 135 /// // Verify that we got the lines we expected. 136 /// assert_eq!( 137 /// lines, 138 /// vec![ 139 /// "The first line.\n", 140 /// "The second line.\n", 141 /// "The third line.\n", 142 /// "The fourth line.\n", 143 /// "The fifth line.\n", 144 /// ] 145 /// ); 146 /// # Ok(()) 147 /// # } 148 /// ``` 149 /// 150 /// [`AsyncRead`]: tokio::io::AsyncRead 151 /// [`AsyncBufRead`]: tokio::io::AsyncBufRead 152 /// [`Stream`]: futures_core::Stream 153 /// [`ReaderStream`]: crate::io::ReaderStream 154 /// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html 155 #[derive(Debug)] 156 pub struct StreamReader<S, B> { 157 // This field is pinned. 158 inner: S, 159 // This field is not pinned. 160 chunk: Option<B>, 161 } 162 163 impl<S, B, E> StreamReader<S, B> 164 where 165 S: Stream<Item = Result<B, E>>, 166 B: Buf, 167 E: Into<std::io::Error>, 168 { 169 /// Convert a stream of byte chunks into an [`AsyncRead`]. 170 /// 171 /// The item should be a [`Result`] with the ok variant being something that 172 /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error 173 /// should be convertible into an [io error]. 174 /// 175 /// [`Result`]: std::result::Result 176 /// [`Buf`]: bytes::Buf 177 /// [io error]: std::io::Error new(stream: S) -> Self178 pub fn new(stream: S) -> Self { 179 Self { 180 inner: stream, 181 chunk: None, 182 } 183 } 184 185 /// Do we have a chunk and is it non-empty? has_chunk(&self) -> bool186 fn has_chunk(&self) -> bool { 187 if let Some(ref chunk) = self.chunk { 188 chunk.remaining() > 0 189 } else { 190 false 191 } 192 } 193 194 /// Consumes this `StreamReader`, returning a Tuple consisting 195 /// of the underlying stream and an Option of the internal buffer, 196 /// which is Some in case the buffer contains elements. into_inner_with_chunk(self) -> (S, Option<B>)197 pub fn into_inner_with_chunk(self) -> (S, Option<B>) { 198 if self.has_chunk() { 199 (self.inner, self.chunk) 200 } else { 201 (self.inner, None) 202 } 203 } 204 } 205 206 impl<S, B> StreamReader<S, B> { 207 /// Gets a reference to the underlying stream. 208 /// 209 /// It is inadvisable to directly read from the underlying stream. get_ref(&self) -> &S210 pub fn get_ref(&self) -> &S { 211 &self.inner 212 } 213 214 /// Gets a mutable reference to the underlying stream. 215 /// 216 /// It is inadvisable to directly read from the underlying stream. get_mut(&mut self) -> &mut S217 pub fn get_mut(&mut self) -> &mut S { 218 &mut self.inner 219 } 220 221 /// Gets a pinned mutable reference to the underlying stream. 222 /// 223 /// It is inadvisable to directly read from the underlying stream. get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S>224 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { 225 self.project().inner 226 } 227 228 /// Consumes this `BufWriter`, returning the underlying stream. 229 /// 230 /// Note that any leftover data in the internal buffer is lost. 231 /// If you additionally want access to the internal buffer use 232 /// [`into_inner_with_chunk`]. 233 /// 234 /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk into_inner(self) -> S235 pub fn into_inner(self) -> S { 236 self.inner 237 } 238 } 239 240 impl<S, B, E> AsyncRead for StreamReader<S, B> 241 where 242 S: Stream<Item = Result<B, E>>, 243 B: Buf, 244 E: Into<std::io::Error>, 245 { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>246 fn poll_read( 247 mut self: Pin<&mut Self>, 248 cx: &mut Context<'_>, 249 buf: &mut ReadBuf<'_>, 250 ) -> Poll<io::Result<()>> { 251 if buf.remaining() == 0 { 252 return Poll::Ready(Ok(())); 253 } 254 255 let inner_buf = match self.as_mut().poll_fill_buf(cx) { 256 Poll::Ready(Ok(buf)) => buf, 257 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 258 Poll::Pending => return Poll::Pending, 259 }; 260 let len = std::cmp::min(inner_buf.len(), buf.remaining()); 261 buf.put_slice(&inner_buf[..len]); 262 263 self.consume(len); 264 Poll::Ready(Ok(())) 265 } 266 } 267 268 impl<S, B, E> AsyncBufRead for StreamReader<S, B> 269 where 270 S: Stream<Item = Result<B, E>>, 271 B: Buf, 272 E: Into<std::io::Error>, 273 { poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>274 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { 275 loop { 276 if self.as_mut().has_chunk() { 277 // This unwrap is very sad, but it can't be avoided. 278 let buf = self.project().chunk.as_ref().unwrap().chunk(); 279 return Poll::Ready(Ok(buf)); 280 } else { 281 match self.as_mut().project().inner.poll_next(cx) { 282 Poll::Ready(Some(Ok(chunk))) => { 283 // Go around the loop in case the chunk is empty. 284 *self.as_mut().project().chunk = Some(chunk); 285 } 286 Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), 287 Poll::Ready(None) => return Poll::Ready(Ok(&[])), 288 Poll::Pending => return Poll::Pending, 289 } 290 } 291 } 292 } consume(self: Pin<&mut Self>, amt: usize)293 fn consume(self: Pin<&mut Self>, amt: usize) { 294 if amt > 0 { 295 self.project() 296 .chunk 297 .as_mut() 298 .expect("No chunk present") 299 .advance(amt); 300 } 301 } 302 } 303 304 // The code below is a manual expansion of the code that pin-project-lite would 305 // generate. This is done because pin-project-lite fails by hitting the recursion 306 // limit on this struct. (Every line of documentation is handled recursively by 307 // the macro.) 308 309 impl<S: Unpin, B> Unpin for StreamReader<S, B> {} 310 311 struct StreamReaderProject<'a, S, B> { 312 inner: Pin<&'a mut S>, 313 chunk: &'a mut Option<B>, 314 } 315 316 impl<S, B> StreamReader<S, B> { 317 #[inline] project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B>318 fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> { 319 // SAFETY: We define that only `inner` should be pinned when `Self` is 320 // and have an appropriate `impl Unpin` for this. 321 let me = unsafe { Pin::into_inner_unchecked(self) }; 322 StreamReaderProject { 323 inner: unsafe { Pin::new_unchecked(&mut me.inner) }, 324 chunk: &mut me.chunk, 325 } 326 } 327 } 328 329 impl<S: Sink<T, Error = E>, B, E, T> Sink<T> for StreamReader<S, B> { 330 type Error = E; poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>331 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 332 self.project().inner.poll_ready(cx) 333 } 334 start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error>335 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { 336 self.project().inner.start_send(item) 337 } 338 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>339 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 340 self.project().inner.poll_flush(cx) 341 } 342 poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>343 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 344 self.project().inner.poll_close(cx) 345 } 346 } 347