1 #![allow(unknown_lints, unexpected_cfgs)]
2 #![warn(rust_2018_idioms)]
3 #![cfg(all(feature = "rt", tokio_unstable))]
4 
5 use tokio::sync::oneshot;
6 use tokio::time::Duration;
7 use tokio_util::task::JoinMap;
8 
9 use futures::future::FutureExt;
10 
rt() -> tokio::runtime::Runtime11 fn rt() -> tokio::runtime::Runtime {
12     tokio::runtime::Builder::new_current_thread()
13         .build()
14         .unwrap()
15 }
16 
17 #[tokio::test(start_paused = true)]
test_with_sleep()18 async fn test_with_sleep() {
19     let mut map = JoinMap::new();
20 
21     for i in 0..10 {
22         map.spawn(i, async move { i });
23         assert_eq!(map.len(), 1 + i);
24     }
25     map.detach_all();
26     assert_eq!(map.len(), 0);
27 
28     assert!(matches!(map.join_next().await, None));
29 
30     for i in 0..10 {
31         map.spawn(i, async move {
32             tokio::time::sleep(Duration::from_secs(i as u64)).await;
33             i
34         });
35         assert_eq!(map.len(), 1 + i);
36     }
37 
38     let mut seen = [false; 10];
39     while let Some((k, res)) = map.join_next().await {
40         seen[k] = true;
41         assert_eq!(res.expect("task should have completed successfully"), k);
42     }
43 
44     for was_seen in &seen {
45         assert!(was_seen);
46     }
47     assert!(matches!(map.join_next().await, None));
48 
49     // Do it again.
50     for i in 0..10 {
51         map.spawn(i, async move {
52             tokio::time::sleep(Duration::from_secs(i as u64)).await;
53             i
54         });
55     }
56 
57     let mut seen = [false; 10];
58     while let Some((k, res)) = map.join_next().await {
59         seen[k] = true;
60         assert_eq!(res.expect("task should have completed successfully"), k);
61     }
62 
63     for was_seen in &seen {
64         assert!(was_seen);
65     }
66     assert!(matches!(map.join_next().await, None));
67 }
68 
69 #[tokio::test]
test_abort_on_drop()70 async fn test_abort_on_drop() {
71     let mut map = JoinMap::new();
72 
73     let mut recvs = Vec::new();
74 
75     for i in 0..16 {
76         let (send, recv) = oneshot::channel::<()>();
77         recvs.push(recv);
78 
79         map.spawn(i, async {
80             // This task will never complete on its own.
81             futures::future::pending::<()>().await;
82             drop(send);
83         });
84     }
85 
86     drop(map);
87 
88     for recv in recvs {
89         // The task is aborted soon and we will receive an error.
90         assert!(recv.await.is_err());
91     }
92 }
93 
94 #[tokio::test]
alternating()95 async fn alternating() {
96     let mut map = JoinMap::new();
97 
98     assert_eq!(map.len(), 0);
99     map.spawn(1, async {});
100     assert_eq!(map.len(), 1);
101     map.spawn(2, async {});
102     assert_eq!(map.len(), 2);
103 
104     for i in 0..16 {
105         let (_, res) = map.join_next().await.unwrap();
106         assert!(res.is_ok());
107         assert_eq!(map.len(), 1);
108         map.spawn(i, async {});
109         assert_eq!(map.len(), 2);
110     }
111 }
112 
113 #[tokio::test]
test_keys()114 async fn test_keys() {
115     use std::collections::HashSet;
116 
117     let mut map = JoinMap::new();
118 
119     assert_eq!(map.len(), 0);
120     map.spawn(1, async {});
121     assert_eq!(map.len(), 1);
122     map.spawn(2, async {});
123     assert_eq!(map.len(), 2);
124 
125     let keys = map.keys().collect::<HashSet<&u32>>();
126     assert!(keys.contains(&1));
127     assert!(keys.contains(&2));
128 
129     let _ = map.join_next().await.unwrap();
130     let _ = map.join_next().await.unwrap();
131 
132     assert_eq!(map.len(), 0);
133     let keys = map.keys().collect::<HashSet<&u32>>();
134     assert!(keys.is_empty());
135 }
136 
137 #[tokio::test(start_paused = true)]
abort_by_key()138 async fn abort_by_key() {
139     let mut map = JoinMap::new();
140     let mut num_canceled = 0;
141     let mut num_completed = 0;
142     for i in 0..16 {
143         map.spawn(i, async move {
144             tokio::time::sleep(Duration::from_secs(i as u64)).await;
145         });
146     }
147 
148     for i in 0..16 {
149         if i % 2 != 0 {
150             // abort odd-numbered tasks.
151             map.abort(&i);
152         }
153     }
154 
155     while let Some((key, res)) = map.join_next().await {
156         match res {
157             Ok(()) => {
158                 num_completed += 1;
159                 assert_eq!(key % 2, 0);
160                 assert!(!map.contains_key(&key));
161             }
162             Err(e) => {
163                 num_canceled += 1;
164                 assert!(e.is_cancelled());
165                 assert_ne!(key % 2, 0);
166                 assert!(!map.contains_key(&key));
167             }
168         }
169     }
170 
171     assert_eq!(num_canceled, 8);
172     assert_eq!(num_completed, 8);
173 }
174 
175 #[tokio::test(start_paused = true)]
abort_by_predicate()176 async fn abort_by_predicate() {
177     let mut map = JoinMap::new();
178     let mut num_canceled = 0;
179     let mut num_completed = 0;
180     for i in 0..16 {
181         map.spawn(i, async move {
182             tokio::time::sleep(Duration::from_secs(i as u64)).await;
183         });
184     }
185 
186     // abort odd-numbered tasks.
187     map.abort_matching(|key| key % 2 != 0);
188 
189     while let Some((key, res)) = map.join_next().await {
190         match res {
191             Ok(()) => {
192                 num_completed += 1;
193                 assert_eq!(key % 2, 0);
194                 assert!(!map.contains_key(&key));
195             }
196             Err(e) => {
197                 num_canceled += 1;
198                 assert!(e.is_cancelled());
199                 assert_ne!(key % 2, 0);
200                 assert!(!map.contains_key(&key));
201             }
202         }
203     }
204 
205     assert_eq!(num_canceled, 8);
206     assert_eq!(num_completed, 8);
207 }
208 
/// A task spawned onto a runtime that is dropped before ever being driven is
/// reported as cancelled when joined from a different runtime.
#[test]
fn runtime_gone() {
    let mut tasks = JoinMap::new();

    let doomed = rt();
    tasks.spawn_on("key", async { 1 }, doomed.handle());
    drop(doomed);

    let joined = rt().block_on(tasks.join_next());
    let (key, outcome) = joined.unwrap();
    assert_eq!(key, "key");
    assert!(outcome.unwrap_err().is_cancelled());
}
222 
223 // This ensures that `join_next` works correctly when the coop budget is
224 // exhausted.
225 #[tokio::test(flavor = "current_thread")]
join_map_coop()226 async fn join_map_coop() {
227     // Large enough to trigger coop.
228     const TASK_NUM: u32 = 1000;
229 
230     static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0);
231 
232     let mut map = JoinMap::new();
233 
234     for i in 0..TASK_NUM {
235         map.spawn(i, async move {
236             SEM.add_permits(1);
237             i
238         });
239     }
240 
241     // Wait for all tasks to complete.
242     //
243     // Since this is a `current_thread` runtime, there's no race condition
244     // between the last permit being added and the task completing.
245     let _ = SEM.acquire_many(TASK_NUM).await.unwrap();
246 
247     let mut count = 0;
248     let mut coop_count = 0;
249     loop {
250         match map.join_next().now_or_never() {
251             Some(Some((key, Ok(i)))) => assert_eq!(key, i),
252             Some(Some((key, Err(err)))) => panic!("failed[{}]: {}", key, err),
253             None => {
254                 coop_count += 1;
255                 tokio::task::yield_now().await;
256                 continue;
257             }
258             Some(None) => break,
259         }
260 
261         count += 1;
262     }
263     assert!(coop_count >= 1);
264     assert_eq!(count, TASK_NUM);
265 }
266 
267 #[tokio::test(start_paused = true)]
abort_all()268 async fn abort_all() {
269     let mut map: JoinMap<usize, ()> = JoinMap::new();
270 
271     for i in 0..5 {
272         map.spawn(i, futures::future::pending());
273     }
274     for i in 5..10 {
275         map.spawn(i, async {
276             tokio::time::sleep(Duration::from_secs(1)).await;
277         });
278     }
279 
280     // The join map will now have 5 pending tasks and 5 ready tasks.
281     tokio::time::sleep(Duration::from_secs(2)).await;
282 
283     map.abort_all();
284     assert_eq!(map.len(), 10);
285 
286     let mut count = 0;
287     let mut seen = [false; 10];
288     while let Some((k, res)) = map.join_next().await {
289         seen[k] = true;
290         if let Err(err) = res {
291             assert!(err.is_cancelled());
292         }
293         count += 1;
294     }
295     assert_eq!(count, 10);
296     assert_eq!(map.len(), 0);
297     for was_seen in &seen {
298         assert!(was_seen);
299     }
300 }
301