1 #![warn(rust_2018_idioms)]
2 #![cfg(all(feature = "full", not(target_os = "wasi"), not(miri)))] // Wasi doesn't support bind
3                                                                    // No `socket` on miri.
4 
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::sync::{mpsc, oneshot};
7 use tokio_test::assert_ok;
8 
9 use std::io;
10 use std::net::{IpAddr, SocketAddr};
11 
/// Generates one `#[tokio::test]` per `(test_name, bind_target)` entry.
///
/// Every generated test:
/// 1. binds a `TcpListener` to `bind_target`,
/// 2. accepts a single connection on a spawned task and sends the accepted
///    socket back over a oneshot channel,
/// 3. connects a client to the listener's actual local address, and
/// 4. asserts that the client's local address equals the peer address seen
///    by the accepted server-side socket.
macro_rules! test_accept {
    ($(($name:ident, $bind_to:expr),)*) => {
        $(
            #[tokio::test]
            async fn $name() {
                let server = assert_ok!(TcpListener::bind($bind_to).await);
                // The bind target may use port 0, so read back the real address.
                let server_addr = server.local_addr().unwrap();

                // Carries the accepted socket from the acceptor task back here
                // so both ends of the connection can be compared.
                let (accepted_tx, accepted_rx) = oneshot::channel();

                tokio::spawn(async move {
                    let (stream, _) = assert_ok!(server.accept().await);
                    assert_ok!(accepted_tx.send(stream));
                });

                let client = assert_ok!(TcpStream::connect(&server_addr).await);
                let accepted = assert_ok!(accepted_rx.await);

                let client_local = client.local_addr().unwrap();
                let accepted_peer = accepted.peer_addr().unwrap();
                assert_eq!(client_local, accepted_peer);
            }
        )*
    };
}
35 
36 test_accept! {
37     (ip_str, "127.0.0.1:0"),
38     (host_str, "localhost:0"),
39     (socket_addr, "127.0.0.1:0".parse::<SocketAddr>().unwrap()),
40     (str_port_tuple, ("127.0.0.1", 0)),
41     (ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
42 }
43 
44 use std::pin::Pin;
45 use std::sync::{
46     atomic::{AtomicUsize, Ordering::SeqCst},
47     Arc,
48 };
49 use std::task::{Context, Poll};
50 use tokio_stream::{Stream, StreamExt};
51 
52 struct TrackPolls<'a> {
53     npolls: Arc<AtomicUsize>,
54     listener: &'a mut TcpListener,
55 }
56 
57 impl<'a> Stream for TrackPolls<'a> {
58     type Item = io::Result<(TcpStream, SocketAddr)>;
59 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>60     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
61         self.npolls.fetch_add(1, SeqCst);
62         self.listener.poll_accept(cx).map(Some)
63     }
64 }
65 
66 #[tokio::test]
no_extra_poll()67 async fn no_extra_poll() {
68     let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
69     let addr = listener.local_addr().unwrap();
70 
71     let (tx, rx) = oneshot::channel();
72     let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel();
73 
74     tokio::spawn(async move {
75         let mut incoming = TrackPolls {
76             npolls: Arc::new(AtomicUsize::new(0)),
77             listener: &mut listener,
78         };
79         assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
80         while incoming.next().await.is_some() {
81             accepted_tx.send(()).unwrap();
82         }
83     });
84 
85     let npolls = assert_ok!(rx.await);
86     tokio::task::yield_now().await;
87 
88     // should have been polled exactly once: the initial poll
89     assert_eq!(npolls.load(SeqCst), 1);
90 
91     let _ = assert_ok!(TcpStream::connect(&addr).await);
92     accepted_rx.recv().await.unwrap();
93 
94     // should have been polled twice more: once to yield Some(), then once to yield Pending
95     assert_eq!(npolls.load(SeqCst), 1 + 2);
96 }
97 
98 #[tokio::test]
accept_many()99 async fn accept_many() {
100     use std::future::{poll_fn, Future};
101     use std::sync::atomic::AtomicBool;
102 
103     const N: usize = 50;
104 
105     let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
106     let listener = Arc::new(listener);
107     let addr = listener.local_addr().unwrap();
108     let connected = Arc::new(AtomicBool::new(false));
109 
110     let (pending_tx, mut pending_rx) = mpsc::unbounded_channel();
111     let (notified_tx, mut notified_rx) = mpsc::unbounded_channel();
112 
113     for _ in 0..N {
114         let listener = listener.clone();
115         let connected = connected.clone();
116         let pending_tx = pending_tx.clone();
117         let notified_tx = notified_tx.clone();
118 
119         tokio::spawn(async move {
120             let accept = listener.accept();
121             tokio::pin!(accept);
122 
123             let mut polled = false;
124 
125             poll_fn(|cx| {
126                 if !polled {
127                     polled = true;
128                     assert!(Pin::new(&mut accept).poll(cx).is_pending());
129                     pending_tx.send(()).unwrap();
130                     Poll::Pending
131                 } else if connected.load(SeqCst) {
132                     notified_tx.send(()).unwrap();
133                     Poll::Ready(())
134                 } else {
135                     Poll::Pending
136                 }
137             })
138             .await;
139 
140             pending_tx.send(()).unwrap();
141         });
142     }
143 
144     // Wait for all tasks to have polled at least once
145     for _ in 0..N {
146         pending_rx.recv().await.unwrap();
147     }
148 
149     // Establish a TCP connection
150     connected.store(true, SeqCst);
151     let _sock = TcpStream::connect(addr).await.unwrap();
152 
153     // Wait for all notifications
154     for _ in 0..N {
155         notified_rx.recv().await.unwrap();
156     }
157 }
158