1 //! In-process memory IO types.
2 
3 use crate::io::{split, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
4 use crate::loom::sync::Mutex;
5 
6 use bytes::{Buf, BytesMut};
7 use std::{
8     pin::Pin,
9     sync::Arc,
10     task::{self, ready, Poll, Waker},
11 };
12 
13 /// A bidirectional pipe to read and write bytes in memory.
14 ///
15 /// A pair of `DuplexStream`s are created together, and they act as a "channel"
16 /// that can be used as in-memory IO types. Writing to one of the pairs will
17 /// allow that data to be read from the other, and vice versa.
18 ///
19 /// # Closing a `DuplexStream`
20 ///
21 /// If one end of the `DuplexStream` channel is dropped, any pending reads on
22 /// the other side will continue to read data until the buffer is drained, then
23 /// they will signal EOF by returning 0 bytes. Any writes to the other side,
24 /// including pending ones (that are waiting for free space in the buffer) will
25 /// return `Err(BrokenPipe)` immediately.
26 ///
27 /// # Example
28 ///
29 /// ```
30 /// # async fn ex() -> std::io::Result<()> {
31 /// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
32 /// let (mut client, mut server) = tokio::io::duplex(64);
33 ///
34 /// client.write_all(b"ping").await?;
35 ///
36 /// let mut buf = [0u8; 4];
37 /// server.read_exact(&mut buf).await?;
38 /// assert_eq!(&buf, b"ping");
39 ///
40 /// server.write_all(b"pong").await?;
41 ///
42 /// client.read_exact(&mut buf).await?;
43 /// assert_eq!(&buf, b"pong");
44 /// # Ok(())
45 /// # }
46 /// ```
47 #[derive(Debug)]
48 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
49 pub struct DuplexStream {
50     read: Arc<Mutex<SimplexStream>>,
51     write: Arc<Mutex<SimplexStream>>,
52 }
53 
54 /// A unidirectional pipe to read and write bytes in memory.
55 ///
56 /// It can be constructed by [`simplex`] function which will create a pair of
57 /// reader and writer or by calling [`SimplexStream::new_unsplit`] that will
58 /// create a handle for both reading and writing.
59 ///
60 /// # Example
61 ///
62 /// ```
63 /// # async fn ex() -> std::io::Result<()> {
64 /// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
65 /// let (mut receiver, mut sender) = tokio::io::simplex(64);
66 ///
67 /// sender.write_all(b"ping").await?;
68 ///
69 /// let mut buf = [0u8; 4];
70 /// receiver.read_exact(&mut buf).await?;
71 /// assert_eq!(&buf, b"ping");
72 /// # Ok(())
73 /// # }
74 /// ```
75 #[derive(Debug)]
76 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
77 pub struct SimplexStream {
78     /// The buffer storing the bytes written, also read from.
79     ///
80     /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
81     /// functionality already. Additionally, it can try to copy data in the
82     /// same buffer if there read index has advanced far enough.
83     buffer: BytesMut,
84     /// Determines if the write side has been closed.
85     is_closed: bool,
86     /// The maximum amount of bytes that can be written before returning
87     /// `Poll::Pending`.
88     max_buf_size: usize,
89     /// If the `read` side has been polled and is pending, this is the waker
90     /// for that parked task.
91     read_waker: Option<Waker>,
92     /// If the `write` side has filled the `max_buf_size` and returned
93     /// `Poll::Pending`, this is the waker for that parked task.
94     write_waker: Option<Waker>,
95 }
96 
97 // ===== impl DuplexStream =====
98 
99 /// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
100 ///
101 /// The `max_buf_size` argument is the maximum amount of bytes that can be
102 /// written to a side before the write returns `Poll::Pending`.
103 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream)104 pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
105     let one = Arc::new(Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
106     let two = Arc::new(Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
107 
108     (
109         DuplexStream {
110             read: one.clone(),
111             write: two.clone(),
112         },
113         DuplexStream {
114             read: two,
115             write: one,
116         },
117     )
118 }
119 
120 impl AsyncRead for DuplexStream {
121     // Previous rustc required this `self` to be `mut`, even though newer
122     // versions recognize it isn't needed to call `lock()`. So for
123     // compatibility, we include the `mut` and `allow` the lint.
124     //
125     // See https://github.com/rust-lang/rust/issues/73592
126     #[allow(unused_mut)]
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>127     fn poll_read(
128         mut self: Pin<&mut Self>,
129         cx: &mut task::Context<'_>,
130         buf: &mut ReadBuf<'_>,
131     ) -> Poll<std::io::Result<()>> {
132         Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
133     }
134 }
135 
136 impl AsyncWrite for DuplexStream {
137     #[allow(unused_mut)]
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>138     fn poll_write(
139         mut self: Pin<&mut Self>,
140         cx: &mut task::Context<'_>,
141         buf: &[u8],
142     ) -> Poll<std::io::Result<usize>> {
143         Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
144     }
145 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll<Result<usize, std::io::Error>>146     fn poll_write_vectored(
147         self: Pin<&mut Self>,
148         cx: &mut task::Context<'_>,
149         bufs: &[std::io::IoSlice<'_>],
150     ) -> Poll<Result<usize, std::io::Error>> {
151         Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
152     }
153 
is_write_vectored(&self) -> bool154     fn is_write_vectored(&self) -> bool {
155         true
156     }
157 
158     #[allow(unused_mut)]
poll_flush( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>159     fn poll_flush(
160         mut self: Pin<&mut Self>,
161         cx: &mut task::Context<'_>,
162     ) -> Poll<std::io::Result<()>> {
163         Pin::new(&mut *self.write.lock()).poll_flush(cx)
164     }
165 
166     #[allow(unused_mut)]
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>167     fn poll_shutdown(
168         mut self: Pin<&mut Self>,
169         cx: &mut task::Context<'_>,
170     ) -> Poll<std::io::Result<()>> {
171         Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
172     }
173 }
174 
175 impl Drop for DuplexStream {
drop(&mut self)176     fn drop(&mut self) {
177         // notify the other side of the closure
178         self.write.lock().close_write();
179         self.read.lock().close_read();
180     }
181 }
182 
183 // ===== impl SimplexStream =====
184 
185 /// Creates unidirectional buffer that acts like in memory pipe.
186 ///
187 /// The `max_buf_size` argument is the maximum amount of bytes that can be
188 /// written to a buffer before the it returns `Poll::Pending`.
189 ///
190 /// # Unify reader and writer
191 ///
192 /// The reader and writer half can be unified into a single structure
193 /// of `SimplexStream` that supports both reading and writing or
194 /// the `SimplexStream` can be already created as unified structure
195 /// using [`SimplexStream::new_unsplit()`].
196 ///
197 /// ```
198 /// # async fn ex() -> std::io::Result<()> {
199 /// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
200 /// let (writer, reader) = tokio::io::simplex(64);
201 /// let mut simplex_stream = writer.unsplit(reader);
202 /// simplex_stream.write_all(b"hello").await?;
203 ///
204 /// let mut buf = [0u8; 5];
205 /// simplex_stream.read_exact(&mut buf).await?;
206 /// assert_eq!(&buf, b"hello");
207 /// # Ok(())
208 /// # }
209 /// ```
210 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>)211 pub fn simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>) {
212     split(SimplexStream::new_unsplit(max_buf_size))
213 }
214 
215 impl SimplexStream {
216     /// Creates unidirectional buffer that acts like in memory pipe. To create split
217     /// version with separate reader and writer you can use [`simplex`] function.
218     ///
219     /// The `max_buf_size` argument is the maximum amount of bytes that can be
220     /// written to a buffer before the it returns `Poll::Pending`.
221     #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
new_unsplit(max_buf_size: usize) -> SimplexStream222     pub fn new_unsplit(max_buf_size: usize) -> SimplexStream {
223         SimplexStream {
224             buffer: BytesMut::new(),
225             is_closed: false,
226             max_buf_size,
227             read_waker: None,
228             write_waker: None,
229         }
230     }
231 
close_write(&mut self)232     fn close_write(&mut self) {
233         self.is_closed = true;
234         // needs to notify any readers that no more data will come
235         if let Some(waker) = self.read_waker.take() {
236             waker.wake();
237         }
238     }
239 
close_read(&mut self)240     fn close_read(&mut self) {
241         self.is_closed = true;
242         // needs to notify any writers that they have to abort
243         if let Some(waker) = self.write_waker.take() {
244             waker.wake();
245         }
246     }
247 
poll_read_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>248     fn poll_read_internal(
249         mut self: Pin<&mut Self>,
250         cx: &mut task::Context<'_>,
251         buf: &mut ReadBuf<'_>,
252     ) -> Poll<std::io::Result<()>> {
253         if self.buffer.has_remaining() {
254             let max = self.buffer.remaining().min(buf.remaining());
255             buf.put_slice(&self.buffer[..max]);
256             self.buffer.advance(max);
257             if max > 0 {
258                 // The passed `buf` might have been empty, don't wake up if
259                 // no bytes have been moved.
260                 if let Some(waker) = self.write_waker.take() {
261                     waker.wake();
262                 }
263             }
264             Poll::Ready(Ok(()))
265         } else if self.is_closed {
266             Poll::Ready(Ok(()))
267         } else {
268             self.read_waker = Some(cx.waker().clone());
269             Poll::Pending
270         }
271     }
272 
poll_write_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>273     fn poll_write_internal(
274         mut self: Pin<&mut Self>,
275         cx: &mut task::Context<'_>,
276         buf: &[u8],
277     ) -> Poll<std::io::Result<usize>> {
278         if self.is_closed {
279             return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
280         }
281         let avail = self.max_buf_size - self.buffer.len();
282         if avail == 0 {
283             self.write_waker = Some(cx.waker().clone());
284             return Poll::Pending;
285         }
286 
287         let len = buf.len().min(avail);
288         self.buffer.extend_from_slice(&buf[..len]);
289         if let Some(waker) = self.read_waker.take() {
290             waker.wake();
291         }
292         Poll::Ready(Ok(len))
293     }
294 
poll_write_vectored_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll<Result<usize, std::io::Error>>295     fn poll_write_vectored_internal(
296         mut self: Pin<&mut Self>,
297         cx: &mut task::Context<'_>,
298         bufs: &[std::io::IoSlice<'_>],
299     ) -> Poll<Result<usize, std::io::Error>> {
300         if self.is_closed {
301             return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
302         }
303         let avail = self.max_buf_size - self.buffer.len();
304         if avail == 0 {
305             self.write_waker = Some(cx.waker().clone());
306             return Poll::Pending;
307         }
308 
309         let mut rem = avail;
310         for buf in bufs {
311             if rem == 0 {
312                 break;
313             }
314 
315             let len = buf.len().min(rem);
316             self.buffer.extend_from_slice(&buf[..len]);
317             rem -= len;
318         }
319 
320         if let Some(waker) = self.read_waker.take() {
321             waker.wake();
322         }
323         Poll::Ready(Ok(avail - rem))
324     }
325 }
326 
327 impl AsyncRead for SimplexStream {
328     cfg_coop! {
329         fn poll_read(
330             self: Pin<&mut Self>,
331             cx: &mut task::Context<'_>,
332             buf: &mut ReadBuf<'_>,
333         ) -> Poll<std::io::Result<()>> {
334             ready!(crate::trace::trace_leaf(cx));
335             let coop = ready!(crate::runtime::coop::poll_proceed(cx));
336 
337             let ret = self.poll_read_internal(cx, buf);
338             if ret.is_ready() {
339                 coop.made_progress();
340             }
341             ret
342         }
343     }
344 
345     cfg_not_coop! {
346         fn poll_read(
347             self: Pin<&mut Self>,
348             cx: &mut task::Context<'_>,
349             buf: &mut ReadBuf<'_>,
350         ) -> Poll<std::io::Result<()>> {
351             ready!(crate::trace::trace_leaf(cx));
352             self.poll_read_internal(cx, buf)
353         }
354     }
355 }
356 
357 impl AsyncWrite for SimplexStream {
358     cfg_coop! {
359         fn poll_write(
360             self: Pin<&mut Self>,
361             cx: &mut task::Context<'_>,
362             buf: &[u8],
363         ) -> Poll<std::io::Result<usize>> {
364             ready!(crate::trace::trace_leaf(cx));
365             let coop = ready!(crate::runtime::coop::poll_proceed(cx));
366 
367             let ret = self.poll_write_internal(cx, buf);
368             if ret.is_ready() {
369                 coop.made_progress();
370             }
371             ret
372         }
373     }
374 
375     cfg_not_coop! {
376         fn poll_write(
377             self: Pin<&mut Self>,
378             cx: &mut task::Context<'_>,
379             buf: &[u8],
380         ) -> Poll<std::io::Result<usize>> {
381             ready!(crate::trace::trace_leaf(cx));
382             self.poll_write_internal(cx, buf)
383         }
384     }
385 
386     cfg_coop! {
387         fn poll_write_vectored(
388             self: Pin<&mut Self>,
389             cx: &mut task::Context<'_>,
390             bufs: &[std::io::IoSlice<'_>],
391         ) -> Poll<Result<usize, std::io::Error>> {
392             ready!(crate::trace::trace_leaf(cx));
393             let coop = ready!(crate::runtime::coop::poll_proceed(cx));
394 
395             let ret = self.poll_write_vectored_internal(cx, bufs);
396             if ret.is_ready() {
397                 coop.made_progress();
398             }
399             ret
400         }
401     }
402 
403     cfg_not_coop! {
404         fn poll_write_vectored(
405             self: Pin<&mut Self>,
406             cx: &mut task::Context<'_>,
407             bufs: &[std::io::IoSlice<'_>],
408         ) -> Poll<Result<usize, std::io::Error>> {
409             ready!(crate::trace::trace_leaf(cx));
410             self.poll_write_vectored_internal(cx, bufs)
411         }
412     }
413 
is_write_vectored(&self) -> bool414     fn is_write_vectored(&self) -> bool {
415         true
416     }
417 
poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>>418     fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
419         Poll::Ready(Ok(()))
420     }
421 
poll_shutdown( mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>422     fn poll_shutdown(
423         mut self: Pin<&mut Self>,
424         _: &mut task::Context<'_>,
425     ) -> Poll<std::io::Result<()>> {
426         self.close_write();
427         Poll::Ready(Ok(()))
428     }
429 }
430