1 #![warn(rust_2018_idioms)]
2 #![cfg(all(feature = "full", not(target_os = "wasi"), not(miri)))] // Wasi doesn't support bind
3 // No `socket` on miri.
4
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::sync::{mpsc, oneshot};
7 use tokio_test::assert_ok;
8
9 use std::io;
10 use std::net::{IpAddr, SocketAddr};
11
/// Generates one `#[tokio::test]` per `(name, bind_target)` pair.
///
/// Each generated test:
/// 1. binds a `TcpListener` to `bind_target`;
/// 2. accepts a single connection on a spawned task and hands the accepted
///    socket back over a oneshot channel;
/// 3. connects to the listener from the test task;
/// 4. asserts that the client's local address equals the accepted socket's
///    peer address.
macro_rules! test_accept {
    ($(($name:ident, $bind_target:expr),)*) => {
        $(
            #[tokio::test]
            async fn $name() {
                let server_listener = assert_ok!(TcpListener::bind($bind_target).await);
                let listen_addr = server_listener.local_addr().unwrap();

                // Carries the server-side socket back to the test task.
                let (socket_tx, socket_rx) = oneshot::channel();

                tokio::spawn(async move {
                    let (accepted, _) = assert_ok!(server_listener.accept().await);
                    assert_ok!(socket_tx.send(accepted));
                });

                let client = assert_ok!(TcpStream::connect(&listen_addr).await);
                let server = assert_ok!(socket_rx.await);

                let client_local = client.local_addr().unwrap();
                let server_peer = server.peer_addr().unwrap();
                assert_eq!(client_local, server_peer);
            }
        )*
    };
}
35
36 test_accept! {
37 (ip_str, "127.0.0.1:0"),
38 (host_str, "localhost:0"),
39 (socket_addr, "127.0.0.1:0".parse::<SocketAddr>().unwrap()),
40 (str_port_tuple, ("127.0.0.1", 0)),
41 (ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
42 }
43
44 use std::pin::Pin;
45 use std::sync::{
46 atomic::{AtomicUsize, Ordering::SeqCst},
47 Arc,
48 };
49 use std::task::{Context, Poll};
50 use tokio_stream::{Stream, StreamExt};
51
52 struct TrackPolls<'a> {
53 npolls: Arc<AtomicUsize>,
54 listener: &'a mut TcpListener,
55 }
56
57 impl<'a> Stream for TrackPolls<'a> {
58 type Item = io::Result<(TcpStream, SocketAddr)>;
59
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>60 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
61 self.npolls.fetch_add(1, SeqCst);
62 self.listener.poll_accept(cx).map(Some)
63 }
64 }
65
66 #[tokio::test]
no_extra_poll()67 async fn no_extra_poll() {
68 let mut listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
69 let addr = listener.local_addr().unwrap();
70
71 let (tx, rx) = oneshot::channel();
72 let (accepted_tx, mut accepted_rx) = mpsc::unbounded_channel();
73
74 tokio::spawn(async move {
75 let mut incoming = TrackPolls {
76 npolls: Arc::new(AtomicUsize::new(0)),
77 listener: &mut listener,
78 };
79 assert_ok!(tx.send(Arc::clone(&incoming.npolls)));
80 while incoming.next().await.is_some() {
81 accepted_tx.send(()).unwrap();
82 }
83 });
84
85 let npolls = assert_ok!(rx.await);
86 tokio::task::yield_now().await;
87
88 // should have been polled exactly once: the initial poll
89 assert_eq!(npolls.load(SeqCst), 1);
90
91 let _ = assert_ok!(TcpStream::connect(&addr).await);
92 accepted_rx.recv().await.unwrap();
93
94 // should have been polled twice more: once to yield Some(), then once to yield Pending
95 assert_eq!(npolls.load(SeqCst), 1 + 2);
96 }
97
98 #[tokio::test]
accept_many()99 async fn accept_many() {
100 use std::future::{poll_fn, Future};
101 use std::sync::atomic::AtomicBool;
102
103 const N: usize = 50;
104
105 let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
106 let listener = Arc::new(listener);
107 let addr = listener.local_addr().unwrap();
108 let connected = Arc::new(AtomicBool::new(false));
109
110 let (pending_tx, mut pending_rx) = mpsc::unbounded_channel();
111 let (notified_tx, mut notified_rx) = mpsc::unbounded_channel();
112
113 for _ in 0..N {
114 let listener = listener.clone();
115 let connected = connected.clone();
116 let pending_tx = pending_tx.clone();
117 let notified_tx = notified_tx.clone();
118
119 tokio::spawn(async move {
120 let accept = listener.accept();
121 tokio::pin!(accept);
122
123 let mut polled = false;
124
125 poll_fn(|cx| {
126 if !polled {
127 polled = true;
128 assert!(Pin::new(&mut accept).poll(cx).is_pending());
129 pending_tx.send(()).unwrap();
130 Poll::Pending
131 } else if connected.load(SeqCst) {
132 notified_tx.send(()).unwrap();
133 Poll::Ready(())
134 } else {
135 Poll::Pending
136 }
137 })
138 .await;
139
140 pending_tx.send(()).unwrap();
141 });
142 }
143
144 // Wait for all tasks to have polled at least once
145 for _ in 0..N {
146 pending_rx.recv().await.unwrap();
147 }
148
149 // Establish a TCP connection
150 connected.store(true, SeqCst);
151 let _sock = TcpStream::connect(addr).await.unwrap();
152
153 // Wait for all notifications
154 for _ in 0..N {
155 notified_rx.recv().await.unwrap();
156 }
157 }
158