1 //! In-process memory IO types.
2
3 use crate::io::{split, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
4 use crate::loom::sync::Mutex;
5
6 use bytes::{Buf, BytesMut};
7 use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{self, ready, Poll, Waker},
11 };
12
/// A bidirectional pipe to read and write bytes in memory.
///
/// A pair of `DuplexStream`s is created together by [`duplex`], and the two
/// ends act as a "channel" that can be used as in-memory IO types. Writing to
/// one end of the pair allows that data to be read from the other, and vice
/// versa. Each direction has its own bounded buffer.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer), will
/// return `Err(BrokenPipe)` immediately.
///
/// Shutting down one end (via `AsyncWrite::poll_shutdown`) closes only the
/// direction that end writes into: the peer's reads reach EOF once the buffer
/// is drained, and further writes from the shut-down end fail with
/// `Err(BrokenPipe)`.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct DuplexStream {
    /// Pipe this end reads from; the peer end holds the same `Arc` as its `write`.
    read: Arc<Mutex<SimplexStream>>,
    /// Pipe this end writes into; the peer end holds the same `Arc` as its `read`.
    write: Arc<Mutex<SimplexStream>>,
}
53
/// A unidirectional pipe to read and write bytes in memory.
///
/// It can be constructed by the [`simplex`] function, which creates a pair of
/// reader and writer halves, or by calling [`SimplexStream::new_unsplit`],
/// which creates a single handle for both reading and writing.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut receiver, mut sender) = tokio::io::simplex(64);
///
/// sender.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// receiver.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct SimplexStream {
    /// The buffer storing the bytes written, also read from.
    ///
    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
    /// functionality already. Additionally, it can try to copy data in the
    /// same buffer if its read index has advanced far enough.
    buffer: BytesMut,
    /// Set once either side has been closed (by `close_write` or
    /// `close_read`). Afterwards writes fail with `BrokenPipe`, and reads
    /// return EOF once the buffer has been drained.
    is_closed: bool,
    /// The maximum amount of bytes that can be buffered before writes return
    /// `Poll::Pending`. A write only accepts as many bytes as currently fit.
    max_buf_size: usize,
    /// If the `read` side has been polled and is pending, this is the waker
    /// for that parked task.
    read_waker: Option<Waker>,
    /// If the `write` side has filled the `max_buf_size` and returned
    /// `Poll::Pending`, this is the waker for that parked task.
    write_waker: Option<Waker>,
}
96
97 // ===== impl DuplexStream =====
98
99 /// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
100 ///
101 /// The `max_buf_size` argument is the maximum amount of bytes that can be
102 /// written to a side before the write returns `Poll::Pending`.
103 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream)104 pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
105 let one = Arc::new(Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
106 let two = Arc::new(Mutex::new(SimplexStream::new_unsplit(max_buf_size)));
107
108 (
109 DuplexStream {
110 read: one.clone(),
111 write: two.clone(),
112 },
113 DuplexStream {
114 read: two,
115 write: one,
116 },
117 )
118 }
119
120 impl AsyncRead for DuplexStream {
121 // Previous rustc required this `self` to be `mut`, even though newer
122 // versions recognize it isn't needed to call `lock()`. So for
123 // compatibility, we include the `mut` and `allow` the lint.
124 //
125 // See https://github.com/rust-lang/rust/issues/73592
126 #[allow(unused_mut)]
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>127 fn poll_read(
128 mut self: Pin<&mut Self>,
129 cx: &mut task::Context<'_>,
130 buf: &mut ReadBuf<'_>,
131 ) -> Poll<std::io::Result<()>> {
132 Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
133 }
134 }
135
136 impl AsyncWrite for DuplexStream {
137 #[allow(unused_mut)]
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>138 fn poll_write(
139 mut self: Pin<&mut Self>,
140 cx: &mut task::Context<'_>,
141 buf: &[u8],
142 ) -> Poll<std::io::Result<usize>> {
143 Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
144 }
145
poll_write_vectored( self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll<Result<usize, std::io::Error>>146 fn poll_write_vectored(
147 self: Pin<&mut Self>,
148 cx: &mut task::Context<'_>,
149 bufs: &[std::io::IoSlice<'_>],
150 ) -> Poll<Result<usize, std::io::Error>> {
151 Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs)
152 }
153
is_write_vectored(&self) -> bool154 fn is_write_vectored(&self) -> bool {
155 true
156 }
157
158 #[allow(unused_mut)]
poll_flush( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>159 fn poll_flush(
160 mut self: Pin<&mut Self>,
161 cx: &mut task::Context<'_>,
162 ) -> Poll<std::io::Result<()>> {
163 Pin::new(&mut *self.write.lock()).poll_flush(cx)
164 }
165
166 #[allow(unused_mut)]
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>167 fn poll_shutdown(
168 mut self: Pin<&mut Self>,
169 cx: &mut task::Context<'_>,
170 ) -> Poll<std::io::Result<()>> {
171 Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
172 }
173 }
174
175 impl Drop for DuplexStream {
drop(&mut self)176 fn drop(&mut self) {
177 // notify the other side of the closure
178 self.write.lock().close_write();
179 self.read.lock().close_read();
180 }
181 }
182
183 // ===== impl SimplexStream =====
184
185 /// Creates unidirectional buffer that acts like in memory pipe.
186 ///
187 /// The `max_buf_size` argument is the maximum amount of bytes that can be
188 /// written to a buffer before the it returns `Poll::Pending`.
189 ///
190 /// # Unify reader and writer
191 ///
192 /// The reader and writer half can be unified into a single structure
193 /// of `SimplexStream` that supports both reading and writing or
194 /// the `SimplexStream` can be already created as unified structure
195 /// using [`SimplexStream::new_unsplit()`].
196 ///
197 /// ```
198 /// # async fn ex() -> std::io::Result<()> {
199 /// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
200 /// let (writer, reader) = tokio::io::simplex(64);
201 /// let mut simplex_stream = writer.unsplit(reader);
202 /// simplex_stream.write_all(b"hello").await?;
203 ///
204 /// let mut buf = [0u8; 5];
205 /// simplex_stream.read_exact(&mut buf).await?;
206 /// assert_eq!(&buf, b"hello");
207 /// # Ok(())
208 /// # }
209 /// ```
210 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>)211 pub fn simplex(max_buf_size: usize) -> (ReadHalf<SimplexStream>, WriteHalf<SimplexStream>) {
212 split(SimplexStream::new_unsplit(max_buf_size))
213 }
214
215 impl SimplexStream {
216 /// Creates unidirectional buffer that acts like in memory pipe. To create split
217 /// version with separate reader and writer you can use [`simplex`] function.
218 ///
219 /// The `max_buf_size` argument is the maximum amount of bytes that can be
220 /// written to a buffer before the it returns `Poll::Pending`.
221 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
new_unsplit(max_buf_size: usize) -> SimplexStream222 pub fn new_unsplit(max_buf_size: usize) -> SimplexStream {
223 SimplexStream {
224 buffer: BytesMut::new(),
225 is_closed: false,
226 max_buf_size,
227 read_waker: None,
228 write_waker: None,
229 }
230 }
231
close_write(&mut self)232 fn close_write(&mut self) {
233 self.is_closed = true;
234 // needs to notify any readers that no more data will come
235 if let Some(waker) = self.read_waker.take() {
236 waker.wake();
237 }
238 }
239
close_read(&mut self)240 fn close_read(&mut self) {
241 self.is_closed = true;
242 // needs to notify any writers that they have to abort
243 if let Some(waker) = self.write_waker.take() {
244 waker.wake();
245 }
246 }
247
poll_read_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>248 fn poll_read_internal(
249 mut self: Pin<&mut Self>,
250 cx: &mut task::Context<'_>,
251 buf: &mut ReadBuf<'_>,
252 ) -> Poll<std::io::Result<()>> {
253 if self.buffer.has_remaining() {
254 let max = self.buffer.remaining().min(buf.remaining());
255 buf.put_slice(&self.buffer[..max]);
256 self.buffer.advance(max);
257 if max > 0 {
258 // The passed `buf` might have been empty, don't wake up if
259 // no bytes have been moved.
260 if let Some(waker) = self.write_waker.take() {
261 waker.wake();
262 }
263 }
264 Poll::Ready(Ok(()))
265 } else if self.is_closed {
266 Poll::Ready(Ok(()))
267 } else {
268 self.read_waker = Some(cx.waker().clone());
269 Poll::Pending
270 }
271 }
272
poll_write_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>273 fn poll_write_internal(
274 mut self: Pin<&mut Self>,
275 cx: &mut task::Context<'_>,
276 buf: &[u8],
277 ) -> Poll<std::io::Result<usize>> {
278 if self.is_closed {
279 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
280 }
281 let avail = self.max_buf_size - self.buffer.len();
282 if avail == 0 {
283 self.write_waker = Some(cx.waker().clone());
284 return Poll::Pending;
285 }
286
287 let len = buf.len().min(avail);
288 self.buffer.extend_from_slice(&buf[..len]);
289 if let Some(waker) = self.read_waker.take() {
290 waker.wake();
291 }
292 Poll::Ready(Ok(len))
293 }
294
poll_write_vectored_internal( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll<Result<usize, std::io::Error>>295 fn poll_write_vectored_internal(
296 mut self: Pin<&mut Self>,
297 cx: &mut task::Context<'_>,
298 bufs: &[std::io::IoSlice<'_>],
299 ) -> Poll<Result<usize, std::io::Error>> {
300 if self.is_closed {
301 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
302 }
303 let avail = self.max_buf_size - self.buffer.len();
304 if avail == 0 {
305 self.write_waker = Some(cx.waker().clone());
306 return Poll::Pending;
307 }
308
309 let mut rem = avail;
310 for buf in bufs {
311 if rem == 0 {
312 break;
313 }
314
315 let len = buf.len().min(rem);
316 self.buffer.extend_from_slice(&buf[..len]);
317 rem -= len;
318 }
319
320 if let Some(waker) = self.read_waker.take() {
321 waker.wake();
322 }
323 Poll::Ready(Ok(avail - rem))
324 }
325 }
326
327 impl AsyncRead for SimplexStream {
328 cfg_coop! {
329 fn poll_read(
330 self: Pin<&mut Self>,
331 cx: &mut task::Context<'_>,
332 buf: &mut ReadBuf<'_>,
333 ) -> Poll<std::io::Result<()>> {
334 ready!(crate::trace::trace_leaf(cx));
335 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
336
337 let ret = self.poll_read_internal(cx, buf);
338 if ret.is_ready() {
339 coop.made_progress();
340 }
341 ret
342 }
343 }
344
345 cfg_not_coop! {
346 fn poll_read(
347 self: Pin<&mut Self>,
348 cx: &mut task::Context<'_>,
349 buf: &mut ReadBuf<'_>,
350 ) -> Poll<std::io::Result<()>> {
351 ready!(crate::trace::trace_leaf(cx));
352 self.poll_read_internal(cx, buf)
353 }
354 }
355 }
356
357 impl AsyncWrite for SimplexStream {
358 cfg_coop! {
359 fn poll_write(
360 self: Pin<&mut Self>,
361 cx: &mut task::Context<'_>,
362 buf: &[u8],
363 ) -> Poll<std::io::Result<usize>> {
364 ready!(crate::trace::trace_leaf(cx));
365 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
366
367 let ret = self.poll_write_internal(cx, buf);
368 if ret.is_ready() {
369 coop.made_progress();
370 }
371 ret
372 }
373 }
374
375 cfg_not_coop! {
376 fn poll_write(
377 self: Pin<&mut Self>,
378 cx: &mut task::Context<'_>,
379 buf: &[u8],
380 ) -> Poll<std::io::Result<usize>> {
381 ready!(crate::trace::trace_leaf(cx));
382 self.poll_write_internal(cx, buf)
383 }
384 }
385
386 cfg_coop! {
387 fn poll_write_vectored(
388 self: Pin<&mut Self>,
389 cx: &mut task::Context<'_>,
390 bufs: &[std::io::IoSlice<'_>],
391 ) -> Poll<Result<usize, std::io::Error>> {
392 ready!(crate::trace::trace_leaf(cx));
393 let coop = ready!(crate::runtime::coop::poll_proceed(cx));
394
395 let ret = self.poll_write_vectored_internal(cx, bufs);
396 if ret.is_ready() {
397 coop.made_progress();
398 }
399 ret
400 }
401 }
402
403 cfg_not_coop! {
404 fn poll_write_vectored(
405 self: Pin<&mut Self>,
406 cx: &mut task::Context<'_>,
407 bufs: &[std::io::IoSlice<'_>],
408 ) -> Poll<Result<usize, std::io::Error>> {
409 ready!(crate::trace::trace_leaf(cx));
410 self.poll_write_vectored_internal(cx, bufs)
411 }
412 }
413
is_write_vectored(&self) -> bool414 fn is_write_vectored(&self) -> bool {
415 true
416 }
417
poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>>418 fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
419 Poll::Ready(Ok(()))
420 }
421
poll_shutdown( mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>422 fn poll_shutdown(
423 mut self: Pin<&mut Self>,
424 _: &mut task::Context<'_>,
425 ) -> Poll<std::io::Result<()>> {
426 self.close_write();
427 Poll::Ready(Ok(()))
428 }
429 }
430