1 #![warn(rust_2018_idioms)]
2 #![cfg(all(feature = "full", not(target_os = "wasi"), not(miri)))] // Wasi doesn't support bind
3                                                                    // No `socket` on miri.
4 
5 use std::io::{Error, ErrorKind, Result};
6 use std::io::{Read, Write};
7 use std::{net, thread};
8 
9 use tokio::io::{AsyncReadExt, AsyncWriteExt};
10 use tokio::net::{TcpListener, TcpStream};
11 use tokio::try_join;
12 
13 #[tokio::test]
split() -> Result<()>14 async fn split() -> Result<()> {
15     const MSG: &[u8] = b"split";
16 
17     let listener = TcpListener::bind("127.0.0.1:0").await?;
18     let addr = listener.local_addr()?;
19 
20     let (stream1, (mut stream2, _)) = try_join! {
21         TcpStream::connect(&addr),
22         listener.accept(),
23     }?;
24     let (mut read_half, mut write_half) = stream1.into_split();
25 
26     let ((), (), ()) = try_join! {
27         async {
28             let len = stream2.write(MSG).await?;
29             assert_eq!(len, MSG.len());
30 
31             let mut read_buf = vec![0u8; 32];
32             let read_len = stream2.read(&mut read_buf).await?;
33             assert_eq!(&read_buf[..read_len], MSG);
34             Result::Ok(())
35         },
36         async {
37             let len = write_half.write(MSG).await?;
38             assert_eq!(len, MSG.len());
39             Ok(())
40         },
41         async {
42             let mut read_buf = [0u8; 32];
43             let peek_len1 = read_half.peek(&mut read_buf[..]).await?;
44             let peek_len2 = read_half.peek(&mut read_buf[..]).await?;
45             assert_eq!(peek_len1, peek_len2);
46 
47             let read_len = read_half.read(&mut read_buf[..]).await?;
48             assert_eq!(peek_len1, read_len);
49             assert_eq!(&read_buf[..read_len], MSG);
50             Ok(())
51         },
52     }?;
53 
54     Ok(())
55 }
56 
57 #[tokio::test]
reunite() -> Result<()>58 async fn reunite() -> Result<()> {
59     let listener = net::TcpListener::bind("127.0.0.1:0")?;
60     let addr = listener.local_addr()?;
61 
62     let handle = thread::spawn(move || {
63         drop(listener.accept().unwrap());
64         drop(listener.accept().unwrap());
65     });
66 
67     let stream1 = TcpStream::connect(&addr).await?;
68     let (read1, write1) = stream1.into_split();
69 
70     let stream2 = TcpStream::connect(&addr).await?;
71     let (_, write2) = stream2.into_split();
72 
73     let read1 = match read1.reunite(write2) {
74         Ok(_) => panic!("Reunite should not succeed"),
75         Err(err) => err.0,
76     };
77 
78     read1.reunite(write1).expect("Reunite should succeed");
79 
80     handle.join().unwrap();
81     Ok(())
82 }
83 
84 /// Test that dropping the write half actually closes the stream.
85 #[tokio::test]
drop_write() -> Result<()>86 async fn drop_write() -> Result<()> {
87     const MSG: &[u8] = b"split";
88 
89     let listener = net::TcpListener::bind("127.0.0.1:0")?;
90     let addr = listener.local_addr()?;
91 
92     let handle = thread::spawn(move || {
93         let (mut stream, _) = listener.accept().unwrap();
94         stream.write_all(MSG).unwrap();
95 
96         let mut read_buf = [0u8; 32];
97         let res = match stream.read(&mut read_buf) {
98             Ok(0) => Ok(()),
99             Ok(len) => Err(Error::new(
100                 ErrorKind::Other,
101                 format!("Unexpected read: {len} bytes."),
102             )),
103             Err(err) => Err(err),
104         };
105 
106         drop(stream);
107 
108         res
109     });
110 
111     let stream = TcpStream::connect(&addr).await?;
112     let (mut read_half, write_half) = stream.into_split();
113 
114     let mut read_buf = [0u8; 32];
115     let read_len = read_half.read(&mut read_buf[..]).await?;
116     assert_eq!(&read_buf[..read_len], MSG);
117 
118     // drop it while the read is in progress
119     std::thread::spawn(move || {
120         thread::sleep(std::time::Duration::from_millis(10));
121         drop(write_half);
122     });
123 
124     match read_half.read(&mut read_buf[..]).await {
125         Ok(0) => {}
126         Ok(len) => panic!("Unexpected read: {len} bytes."),
127         Err(err) => panic!("Unexpected error: {err}."),
128     }
129 
130     handle.join().unwrap().unwrap();
131     Ok(())
132 }
133