1 #![allow(unknown_lints, unexpected_cfgs)]
2 #![warn(rust_2018_idioms)]
3 #![cfg(feature = "full")]
4 
5 use futures::future::FutureExt;
6 use tokio::sync::oneshot;
7 use tokio::task::JoinSet;
8 use tokio::time::Duration;
9 
rt() -> tokio::runtime::Runtime10 fn rt() -> tokio::runtime::Runtime {
11     tokio::runtime::Builder::new_current_thread()
12         .build()
13         .unwrap()
14 }
15 
16 #[tokio::test(start_paused = true)]
test_with_sleep()17 async fn test_with_sleep() {
18     let mut set = JoinSet::new();
19 
20     for i in 0..10 {
21         set.spawn(async move { i });
22         assert_eq!(set.len(), 1 + i);
23     }
24     set.detach_all();
25     assert_eq!(set.len(), 0);
26 
27     assert!(set.join_next().await.is_none());
28 
29     for i in 0..10 {
30         set.spawn(async move {
31             tokio::time::sleep(Duration::from_secs(i as u64)).await;
32             i
33         });
34         assert_eq!(set.len(), 1 + i);
35     }
36 
37     let mut seen = [false; 10];
38     while let Some(res) = set.join_next().await.transpose().unwrap() {
39         seen[res] = true;
40     }
41 
42     for was_seen in &seen {
43         assert!(was_seen);
44     }
45     assert!(set.join_next().await.is_none());
46 
47     // Do it again.
48     for i in 0..10 {
49         set.spawn(async move {
50             tokio::time::sleep(Duration::from_secs(i as u64)).await;
51             i
52         });
53     }
54 
55     let mut seen = [false; 10];
56     while let Some(res) = set.join_next().await.transpose().unwrap() {
57         seen[res] = true;
58     }
59 
60     for was_seen in &seen {
61         assert!(was_seen);
62     }
63     assert!(set.join_next().await.is_none());
64 }
65 
66 #[tokio::test]
test_abort_on_drop()67 async fn test_abort_on_drop() {
68     let mut set = JoinSet::new();
69 
70     let mut recvs = Vec::new();
71 
72     for _ in 0..16 {
73         let (send, recv) = oneshot::channel::<()>();
74         recvs.push(recv);
75 
76         set.spawn(async {
77             // This task will never complete on its own.
78             futures::future::pending::<()>().await;
79             drop(send);
80         });
81     }
82 
83     drop(set);
84 
85     for recv in recvs {
86         // The task is aborted soon and we will receive an error.
87         assert!(recv.await.is_err());
88     }
89 }
90 
91 #[tokio::test]
alternating()92 async fn alternating() {
93     let mut set = JoinSet::new();
94 
95     assert_eq!(set.len(), 0);
96     set.spawn(async {});
97     assert_eq!(set.len(), 1);
98     set.spawn(async {});
99     assert_eq!(set.len(), 2);
100 
101     for _ in 0..16 {
102         let () = set.join_next().await.unwrap().unwrap();
103         assert_eq!(set.len(), 1);
104         set.spawn(async {});
105         assert_eq!(set.len(), 2);
106     }
107 }
108 
109 #[tokio::test(start_paused = true)]
abort_tasks()110 async fn abort_tasks() {
111     let mut set = JoinSet::new();
112     let mut num_canceled = 0;
113     let mut num_completed = 0;
114     for i in 0..16 {
115         let abort = set.spawn(async move {
116             tokio::time::sleep(Duration::from_secs(i as u64)).await;
117             i
118         });
119 
120         if i % 2 != 0 {
121             // abort odd-numbered tasks.
122             abort.abort();
123         }
124     }
125     loop {
126         match set.join_next().await {
127             Some(Ok(res)) => {
128                 num_completed += 1;
129                 assert_eq!(res % 2, 0);
130             }
131             Some(Err(e)) => {
132                 assert!(e.is_cancelled());
133                 num_canceled += 1;
134             }
135             None => break,
136         }
137     }
138 
139     assert_eq!(num_canceled, 8);
140     assert_eq!(num_completed, 8);
141 }
142 
/// A task whose runtime has been dropped must be reported as cancelled when
/// the `JoinSet` is later polled from a different runtime.
#[test]
fn runtime_gone() {
    let mut set = JoinSet::new();

    // The spawning runtime is never driven, so the task has not run by the
    // time the runtime is dropped.
    let spawning_rt = rt();
    set.spawn_on(async { 1 }, spawning_rt.handle());
    drop(spawning_rt);

    let polling_rt = rt();
    let err = polling_rt
        .block_on(set.join_next())
        .unwrap()
        .unwrap_err();
    assert!(err.is_cancelled());
}
158 
159 #[tokio::test]
join_all()160 async fn join_all() {
161     let mut set: JoinSet<i32> = JoinSet::new();
162 
163     for _ in 0..5 {
164         set.spawn(async { 1 });
165     }
166     let res: Vec<i32> = set.join_all().await;
167 
168     assert_eq!(res.len(), 5);
169     for itm in res.into_iter() {
170         assert_eq!(itm, 1)
171     }
172 }
173 
174 #[cfg(panic = "unwind")]
175 #[tokio::test(start_paused = true)]
task_panics()176 async fn task_panics() {
177     let mut set: JoinSet<()> = JoinSet::new();
178 
179     let (tx, mut rx) = oneshot::channel();
180     assert_eq!(set.len(), 0);
181 
182     set.spawn(async move {
183         tokio::time::sleep(Duration::from_secs(2)).await;
184         tx.send(()).unwrap();
185     });
186     assert_eq!(set.len(), 1);
187 
188     set.spawn(async {
189         tokio::time::sleep(Duration::from_secs(1)).await;
190         panic!();
191     });
192     assert_eq!(set.len(), 2);
193 
194     let panic = tokio::spawn(set.join_all()).await.unwrap_err();
195     assert!(rx.try_recv().is_err());
196     assert!(panic.is_panic());
197 }
198 
199 #[tokio::test(start_paused = true)]
abort_all()200 async fn abort_all() {
201     let mut set: JoinSet<()> = JoinSet::new();
202 
203     for _ in 0..5 {
204         set.spawn(futures::future::pending());
205     }
206     for _ in 0..5 {
207         set.spawn(async {
208             tokio::time::sleep(Duration::from_secs(1)).await;
209         });
210     }
211 
212     // The join set will now have 5 pending tasks and 5 ready tasks.
213     tokio::time::sleep(Duration::from_secs(2)).await;
214 
215     set.abort_all();
216     assert_eq!(set.len(), 10);
217 
218     let mut count = 0;
219     while let Some(res) = set.join_next().await {
220         if let Err(err) = res {
221             assert!(err.is_cancelled());
222         }
223         count += 1;
224     }
225     assert_eq!(count, 10);
226     assert_eq!(set.len(), 0);
227 }
228 
229 // This ensures that `join_next` works correctly when the coop budget is
230 // exhausted.
231 #[tokio::test(flavor = "current_thread")]
join_set_coop()232 async fn join_set_coop() {
233     // Large enough to trigger coop.
234     const TASK_NUM: u32 = 1000;
235 
236     static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0);
237 
238     let mut set = JoinSet::new();
239 
240     for _ in 0..TASK_NUM {
241         set.spawn(async {
242             SEM.add_permits(1);
243         });
244     }
245 
246     // Wait for all tasks to complete.
247     //
248     // Since this is a `current_thread` runtime, there's no race condition
249     // between the last permit being added and the task completing.
250     let _ = SEM.acquire_many(TASK_NUM).await.unwrap();
251 
252     let mut count = 0;
253     let mut coop_count = 0;
254     loop {
255         match set.join_next().now_or_never() {
256             Some(Some(Ok(()))) => {}
257             Some(Some(Err(err))) => panic!("failed: {err}"),
258             None => {
259                 coop_count += 1;
260                 tokio::task::yield_now().await;
261                 continue;
262             }
263             Some(None) => break,
264         }
265 
266         count += 1;
267     }
268     assert!(coop_count >= 1);
269     assert_eq!(count, TASK_NUM);
270 }
271 
272 #[tokio::test(flavor = "current_thread")]
try_join_next()273 async fn try_join_next() {
274     const TASK_NUM: u32 = 1000;
275 
276     let (send, recv) = tokio::sync::watch::channel(());
277 
278     let mut set = JoinSet::new();
279 
280     for _ in 0..TASK_NUM {
281         let mut recv = recv.clone();
282         set.spawn(async move { recv.changed().await.unwrap() });
283     }
284     drop(recv);
285 
286     assert!(set.try_join_next().is_none());
287 
288     send.send_replace(());
289     send.closed().await;
290 
291     let mut count = 0;
292     loop {
293         match set.try_join_next() {
294             Some(Ok(())) => {
295                 count += 1;
296             }
297             Some(Err(err)) => panic!("failed: {err}"),
298             None => {
299                 break;
300             }
301         }
302     }
303 
304     assert_eq!(count, TASK_NUM);
305 }
306 
/// Like `try_join_next`, but also checks that `try_join_next_with_id`
/// reports exactly the set of task ids returned by `spawn`.
#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn try_join_next_with_id() {
    const TASK_NUM: u32 = 1000;

    let (send, recv) = tokio::sync::watch::channel(());

    let mut set = JoinSet::new();
    let mut spawned = std::collections::HashSet::with_capacity(TASK_NUM as usize);

    for _ in 0..TASK_NUM {
        let mut recv = recv.clone();
        let handle = set.spawn(async move { recv.changed().await.unwrap() });

        spawned.insert(handle.id());
    }
    // Only the per-task clones should keep the channel open.
    drop(recv);

    // No task has been polled yet, and each one waits for a change.
    assert!(set.try_join_next_with_id().is_none());

    // Wake every task, then wait until all receivers (owned by the tasks)
    // have been dropped.
    send.send_replace(());
    send.closed().await;

    let mut count = 0;
    let mut joined = std::collections::HashSet::with_capacity(TASK_NUM as usize);
    while let Some(res) = set.try_join_next_with_id() {
        match res {
            Ok((id, ())) => {
                count += 1;
                joined.insert(id);
            }
            // Inline format argument, consistent with the sibling tests.
            Err(err) => panic!("failed: {err}"),
        }
    }

    assert_eq!(count, TASK_NUM);
    assert_eq!(joined, spawned);
}
348