1 use bytes::Buf;
2 use futures_core::stream::Stream;
3 use futures_sink::Sink;
4 use std::io;
5 use std::pin::Pin;
6 use std::task::{Context, Poll};
7 use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
8 
9 /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
10 ///
11 /// This type performs the inverse operation of [`ReaderStream`].
12 ///
13 /// This type also implements the [`AsyncBufRead`] trait, so you can use it
14 /// to read a `Stream` of byte chunks line-by-line. See the examples below.
15 ///
16 /// # Example
17 ///
18 /// ```
19 /// use bytes::Bytes;
20 /// use tokio::io::{AsyncReadExt, Result};
21 /// use tokio_util::io::StreamReader;
22 /// # #[tokio::main(flavor = "current_thread")]
23 /// # async fn main() -> std::io::Result<()> {
24 ///
25 /// // Create a stream from an iterator.
26 /// let stream = tokio_stream::iter(vec![
27 ///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
28 ///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
29 ///     Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
30 /// ]);
31 ///
32 /// // Convert it to an AsyncRead.
33 /// let mut read = StreamReader::new(stream);
34 ///
35 /// // Read five bytes from the stream.
36 /// let mut buf = [0; 5];
37 /// read.read_exact(&mut buf).await?;
38 /// assert_eq!(buf, [0, 1, 2, 3, 4]);
39 ///
40 /// // Read the rest of the current chunk.
41 /// assert_eq!(read.read(&mut buf).await?, 3);
42 /// assert_eq!(&buf[..3], [5, 6, 7]);
43 ///
44 /// // Read the next chunk.
45 /// assert_eq!(read.read(&mut buf).await?, 4);
46 /// assert_eq!(&buf[..4], [8, 9, 10, 11]);
47 ///
48 /// // We have now reached the end.
49 /// assert_eq!(read.read(&mut buf).await?, 0);
50 ///
51 /// # Ok(())
52 /// # }
53 /// ```
54 ///
55 /// If the stream produces errors which are not [`std::io::Error`],
56 /// the errors can be converted using [`StreamExt`] to map each
57 /// element.
58 ///
59 /// ```
60 /// use bytes::Bytes;
61 /// use tokio::io::AsyncReadExt;
62 /// use tokio_util::io::StreamReader;
63 /// use tokio_stream::StreamExt;
64 /// # #[tokio::main(flavor = "current_thread")]
65 /// # async fn main() -> std::io::Result<()> {
66 ///
67 /// // Create a stream from an iterator, including an error.
68 /// let stream = tokio_stream::iter(vec![
69 ///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
70 ///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
71 ///     Result::Err("Something bad happened!")
72 /// ]);
73 ///
74 /// // Use StreamExt to map the stream and error to a std::io::Error
75 /// let stream = stream.map(|result| result.map_err(|err| {
76 ///     std::io::Error::new(std::io::ErrorKind::Other, err)
77 /// }));
78 ///
79 /// // Convert it to an AsyncRead.
80 /// let mut read = StreamReader::new(stream);
81 ///
82 /// // Read five bytes from the stream.
83 /// let mut buf = [0; 5];
84 /// read.read_exact(&mut buf).await?;
85 /// assert_eq!(buf, [0, 1, 2, 3, 4]);
86 ///
87 /// // Read the rest of the current chunk.
88 /// assert_eq!(read.read(&mut buf).await?, 3);
89 /// assert_eq!(&buf[..3], [5, 6, 7]);
90 ///
91 /// // Reading the next chunk will produce an error
92 /// let error = read.read(&mut buf).await.unwrap_err();
93 /// assert_eq!(error.kind(), std::io::ErrorKind::Other);
94 /// assert_eq!(error.into_inner().unwrap().to_string(), "Something bad happened!");
95 ///
96 /// // We have now reached the end.
97 /// assert_eq!(read.read(&mut buf).await?, 0);
98 ///
99 /// # Ok(())
100 /// # }
101 /// ```
102 ///
103 /// Using the [`AsyncBufRead`] impl, you can read a `Stream` of byte chunks
104 /// line-by-line. Note that you will usually also need to convert the error
105 /// type when doing this. See the second example for an explanation of how
106 /// to do this.
107 ///
108 /// ```
109 /// use tokio::io::{Result, AsyncBufReadExt};
110 /// use tokio_util::io::StreamReader;
111 /// # #[tokio::main(flavor = "current_thread")]
112 /// # async fn main() -> std::io::Result<()> {
113 ///
114 /// // Create a stream of byte chunks.
115 /// let stream = tokio_stream::iter(vec![
116 ///     Result::Ok(b"The first line.\n".as_slice()),
117 ///     Result::Ok(b"The second line.".as_slice()),
118 ///     Result::Ok(b"\nThe third".as_slice()),
119 ///     Result::Ok(b" line.\nThe fourth line.\nThe fifth line.\n".as_slice()),
120 /// ]);
121 ///
122 /// // Convert it to an AsyncRead.
123 /// let mut read = StreamReader::new(stream);
124 ///
125 /// // Loop through the lines from the `StreamReader`.
126 /// let mut line = String::new();
127 /// let mut lines = Vec::new();
128 /// loop {
129 ///     line.clear();
130 ///     let len = read.read_line(&mut line).await?;
131 ///     if len == 0 { break; }
132 ///     lines.push(line.clone());
133 /// }
134 ///
135 /// // Verify that we got the lines we expected.
136 /// assert_eq!(
137 ///     lines,
138 ///     vec![
139 ///         "The first line.\n",
140 ///         "The second line.\n",
141 ///         "The third line.\n",
142 ///         "The fourth line.\n",
143 ///         "The fifth line.\n",
144 ///     ]
145 /// );
146 /// # Ok(())
147 /// # }
148 /// ```
149 ///
150 /// [`AsyncRead`]: tokio::io::AsyncRead
151 /// [`AsyncBufRead`]: tokio::io::AsyncBufRead
152 /// [`Stream`]: futures_core::Stream
153 /// [`ReaderStream`]: crate::io::ReaderStream
154 /// [`StreamExt`]: https://docs.rs/tokio-stream/latest/tokio_stream/trait.StreamExt.html
#[derive(Debug)]
pub struct StreamReader<S, B> {
    // The wrapped stream of chunks. Structurally pinned: `project` hands it
    // out as `Pin<&mut S>`, so it must never be moved out of a pinned `Self`.
    inner: S,
    // The chunk currently being read from, if any (it may be fully consumed).
    // Not pinned: `project` hands it out as a plain `&mut Option<B>`.
    chunk: Option<B>,
}
162 
163 impl<S, B, E> StreamReader<S, B>
164 where
165     S: Stream<Item = Result<B, E>>,
166     B: Buf,
167     E: Into<std::io::Error>,
168 {
169     /// Convert a stream of byte chunks into an [`AsyncRead`].
170     ///
171     /// The item should be a [`Result`] with the ok variant being something that
172     /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
173     /// should be convertible into an [io error].
174     ///
175     /// [`Result`]: std::result::Result
176     /// [`Buf`]: bytes::Buf
177     /// [io error]: std::io::Error
new(stream: S) -> Self178     pub fn new(stream: S) -> Self {
179         Self {
180             inner: stream,
181             chunk: None,
182         }
183     }
184 
185     /// Do we have a chunk and is it non-empty?
has_chunk(&self) -> bool186     fn has_chunk(&self) -> bool {
187         if let Some(ref chunk) = self.chunk {
188             chunk.remaining() > 0
189         } else {
190             false
191         }
192     }
193 
194     /// Consumes this `StreamReader`, returning a Tuple consisting
195     /// of the underlying stream and an Option of the internal buffer,
196     /// which is Some in case the buffer contains elements.
into_inner_with_chunk(self) -> (S, Option<B>)197     pub fn into_inner_with_chunk(self) -> (S, Option<B>) {
198         if self.has_chunk() {
199             (self.inner, self.chunk)
200         } else {
201             (self.inner, None)
202         }
203     }
204 }
205 
206 impl<S, B> StreamReader<S, B> {
207     /// Gets a reference to the underlying stream.
208     ///
209     /// It is inadvisable to directly read from the underlying stream.
get_ref(&self) -> &S210     pub fn get_ref(&self) -> &S {
211         &self.inner
212     }
213 
214     /// Gets a mutable reference to the underlying stream.
215     ///
216     /// It is inadvisable to directly read from the underlying stream.
get_mut(&mut self) -> &mut S217     pub fn get_mut(&mut self) -> &mut S {
218         &mut self.inner
219     }
220 
221     /// Gets a pinned mutable reference to the underlying stream.
222     ///
223     /// It is inadvisable to directly read from the underlying stream.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S>224     pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
225         self.project().inner
226     }
227 
228     /// Consumes this `BufWriter`, returning the underlying stream.
229     ///
230     /// Note that any leftover data in the internal buffer is lost.
231     /// If you additionally want access to the internal buffer use
232     /// [`into_inner_with_chunk`].
233     ///
234     /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk
into_inner(self) -> S235     pub fn into_inner(self) -> S {
236         self.inner
237     }
238 }
239 
240 impl<S, B, E> AsyncRead for StreamReader<S, B>
241 where
242     S: Stream<Item = Result<B, E>>,
243     B: Buf,
244     E: Into<std::io::Error>,
245 {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>246     fn poll_read(
247         mut self: Pin<&mut Self>,
248         cx: &mut Context<'_>,
249         buf: &mut ReadBuf<'_>,
250     ) -> Poll<io::Result<()>> {
251         if buf.remaining() == 0 {
252             return Poll::Ready(Ok(()));
253         }
254 
255         let inner_buf = match self.as_mut().poll_fill_buf(cx) {
256             Poll::Ready(Ok(buf)) => buf,
257             Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
258             Poll::Pending => return Poll::Pending,
259         };
260         let len = std::cmp::min(inner_buf.len(), buf.remaining());
261         buf.put_slice(&inner_buf[..len]);
262 
263         self.consume(len);
264         Poll::Ready(Ok(()))
265     }
266 }
267 
268 impl<S, B, E> AsyncBufRead for StreamReader<S, B>
269 where
270     S: Stream<Item = Result<B, E>>,
271     B: Buf,
272     E: Into<std::io::Error>,
273 {
poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>274     fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
275         loop {
276             if self.as_mut().has_chunk() {
277                 // This unwrap is very sad, but it can't be avoided.
278                 let buf = self.project().chunk.as_ref().unwrap().chunk();
279                 return Poll::Ready(Ok(buf));
280             } else {
281                 match self.as_mut().project().inner.poll_next(cx) {
282                     Poll::Ready(Some(Ok(chunk))) => {
283                         // Go around the loop in case the chunk is empty.
284                         *self.as_mut().project().chunk = Some(chunk);
285                     }
286                     Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
287                     Poll::Ready(None) => return Poll::Ready(Ok(&[])),
288                     Poll::Pending => return Poll::Pending,
289                 }
290             }
291         }
292     }
consume(self: Pin<&mut Self>, amt: usize)293     fn consume(self: Pin<&mut Self>, amt: usize) {
294         if amt > 0 {
295             self.project()
296                 .chunk
297                 .as_mut()
298                 .expect("No chunk present")
299                 .advance(amt);
300         }
301     }
302 }
303 
304 // The code below is a manual expansion of the code that pin-project-lite would
305 // generate. This is done because pin-project-lite fails by hitting the recursion
306 // limit on this struct. (Every line of documentation is handled recursively by
307 // the macro.)
308 
// `StreamReader` is `Unpin` whenever the stream is. `project` only ever pins
// `inner`; `chunk` is handed out as a plain `&mut`, so `B` needs no `Unpin`
// bound here.
impl<S: Unpin, B> Unpin for StreamReader<S, B> {}
310 
/// Field-wise view of a pinned `StreamReader`, as produced by
/// `StreamReader::project`.
struct StreamReaderProject<'a, S, B> {
    /// The underlying stream, still pinned.
    inner: Pin<&'a mut S>,
    /// The buffered chunk; not structurally pinned, so exposed as `&mut`.
    chunk: &'a mut Option<B>,
}
315 
impl<S, B> StreamReader<S, B> {
    /// Splits a pinned `StreamReader` into a pinned reference to `inner` and
    /// an unpinned mutable reference to `chunk`.
    #[inline]
    fn project(self: Pin<&mut Self>) -> StreamReaderProject<'_, S, B> {
        // SAFETY: We define that only `inner` should be pinned when `Self` is
        // and have an appropriate `impl Unpin` for this.
        //
        // `inner` is re-pinned immediately below and is never moved out through
        // the returned projection. `chunk` is deliberately not pinned, which is
        // why the `Unpin` impl bounds only `S`.
        // NOTE(review): soundness also assumes `StreamReader` has no `Drop` impl
        // that moves `inner` and is not `#[repr(packed)]`. Neither appears in
        // this file; keep it that way.
        let me = unsafe { Pin::into_inner_unchecked(self) };
        StreamReaderProject {
            inner: unsafe { Pin::new_unchecked(&mut me.inner) },
            chunk: &mut me.chunk,
        }
    }
}
328 
329 impl<S: Sink<T, Error = E>, B, E, T> Sink<T> for StreamReader<S, B> {
330     type Error = E;
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>331     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
332         self.project().inner.poll_ready(cx)
333     }
334 
start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error>335     fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
336         self.project().inner.start_send(item)
337     }
338 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>339     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
340         self.project().inner.poll_flush(cx)
341     }
342 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>343     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
344         self.project().inner.poll_close(cx)
345     }
346 }
347