1 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2 
3 use bytes::{Buf, BufMut};
4 use std::io::{self, IoSlice};
5 use std::mem::MaybeUninit;
6 use std::pin::Pin;
7 use std::task::{ready, Context, Poll};
8 
9 /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
10 ///
11 /// [`BufMut`]: bytes::Buf
12 ///
13 /// # Example
14 ///
15 /// ```
16 /// use bytes::{Bytes, BytesMut};
17 /// use tokio_stream as stream;
18 /// use tokio::io::Result;
19 /// use tokio_util::io::{StreamReader, poll_read_buf};
20 /// use std::future::poll_fn;
21 /// use std::pin::Pin;
22 /// # #[tokio::main]
23 /// # async fn main() -> std::io::Result<()> {
24 ///
25 /// // Create a reader from an iterator. This particular reader will always be
26 /// // ready.
27 /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
28 ///
29 /// let mut buf = BytesMut::new();
30 /// let mut reads = 0;
31 ///
32 /// loop {
33 ///     reads += 1;
34 ///     let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
35 ///
36 ///     if n == 0 {
37 ///         break;
38 ///     }
39 /// }
40 ///
41 /// // one or more reads might be necessary.
42 /// assert!(reads >= 1);
43 /// assert_eq!(&buf[..], &[0, 1, 2, 3]);
44 /// # Ok(())
45 /// # }
46 /// ```
47 #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
poll_read_buf<T: AsyncRead + ?Sized, B: BufMut>( io: Pin<&mut T>, cx: &mut Context<'_>, buf: &mut B, ) -> Poll<io::Result<usize>>48 pub fn poll_read_buf<T: AsyncRead + ?Sized, B: BufMut>(
49     io: Pin<&mut T>,
50     cx: &mut Context<'_>,
51     buf: &mut B,
52 ) -> Poll<io::Result<usize>> {
53     if !buf.has_remaining_mut() {
54         return Poll::Ready(Ok(0));
55     }
56 
57     let n = {
58         let dst = buf.chunk_mut();
59 
60         // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
61         // transparent wrapper around `[MaybeUninit<u8>]`.
62         let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
63         let mut buf = ReadBuf::uninit(dst);
64         let ptr = buf.filled().as_ptr();
65         ready!(io.poll_read(cx, &mut buf)?);
66 
67         // Ensure the pointer does not change from under us
68         assert_eq!(ptr, buf.filled().as_ptr());
69         buf.filled().len()
70     };
71 
72     // Safety: This is guaranteed to be the number of initialized (and read)
73     // bytes due to the invariants provided by `ReadBuf::filled`.
74     unsafe {
75         buf.advance_mut(n);
76     }
77 
78     Poll::Ready(Ok(n))
79 }
80 
81 /// Try to write data from an implementer of the [`Buf`] trait to an
82 /// [`AsyncWrite`], advancing the buffer's internal cursor.
83 ///
84 /// This function will use [vectored writes] when the [`AsyncWrite`] supports
85 /// vectored writes.
86 ///
87 /// # Examples
88 ///
89 /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
90 /// [`Buf`]:
91 ///
92 /// ```no_run
93 /// use tokio_util::io::poll_write_buf;
94 /// use tokio::io;
95 /// use tokio::fs::File;
96 ///
97 /// use bytes::Buf;
98 /// use std::future::poll_fn;
99 /// use std::io::Cursor;
100 /// use std::pin::Pin;
101 ///
102 /// #[tokio::main]
103 /// async fn main() -> io::Result<()> {
104 ///     let mut file = File::create("foo.txt").await?;
105 ///     let mut buf = Cursor::new(b"data to write");
106 ///
107 ///     // Loop until the entire contents of the buffer are written to
108 ///     // the file.
109 ///     while buf.has_remaining() {
110 ///         poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
111 ///     }
112 ///
113 ///     Ok(())
114 /// }
115 /// ```
116 ///
117 /// [`Buf`]: bytes::Buf
118 /// [`AsyncWrite`]: tokio::io::AsyncWrite
119 /// [`File`]: tokio::fs::File
120 /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
121 #[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
poll_write_buf<T: AsyncWrite + ?Sized, B: Buf>( io: Pin<&mut T>, cx: &mut Context<'_>, buf: &mut B, ) -> Poll<io::Result<usize>>122 pub fn poll_write_buf<T: AsyncWrite + ?Sized, B: Buf>(
123     io: Pin<&mut T>,
124     cx: &mut Context<'_>,
125     buf: &mut B,
126 ) -> Poll<io::Result<usize>> {
127     const MAX_BUFS: usize = 64;
128 
129     if !buf.has_remaining() {
130         return Poll::Ready(Ok(0));
131     }
132 
133     let n = if io.is_write_vectored() {
134         let mut slices = [IoSlice::new(&[]); MAX_BUFS];
135         let cnt = buf.chunks_vectored(&mut slices);
136         ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
137     } else {
138         ready!(io.poll_write(cx, buf.chunk()))?
139     };
140 
141     buf.advance(n);
142 
143     Poll::Ready(Ok(n))
144 }
145