1 #![allow(unknown_lints, unexpected_cfgs)]
2 #![warn(rust_2018_idioms)]
3 #![cfg(feature = "full")]
4 
5 use futures::{
6     future::{pending, ready},
7     FutureExt,
8 };
9 
10 use tokio::runtime;
11 use tokio::sync::{mpsc, oneshot};
12 use tokio::task::{self, LocalSet};
13 use tokio::time;
14 
15 #[cfg(not(target_os = "wasi"))]
16 use std::cell::Cell;
17 use std::sync::atomic::AtomicBool;
18 #[cfg(not(target_os = "wasi"))]
19 use std::sync::atomic::AtomicUsize;
20 use std::sync::atomic::Ordering;
21 #[cfg(not(target_os = "wasi"))]
22 use std::sync::atomic::Ordering::SeqCst;
23 use std::time::Duration;
24 
25 #[tokio::test(flavor = "current_thread")]
local_current_thread_scheduler()26 async fn local_current_thread_scheduler() {
27     LocalSet::new()
28         .run_until(async {
29             task::spawn_local(async {}).await.unwrap();
30         })
31         .await;
32 }
33 
34 #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
35 #[tokio::test(flavor = "multi_thread")]
local_threadpool()36 async fn local_threadpool() {
37     thread_local! {
38         static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
39     }
40 
41     ON_RT_THREAD.with(|cell| cell.set(true));
42 
43     LocalSet::new()
44         .run_until(async {
45             assert!(ON_RT_THREAD.with(|cell| cell.get()));
46             task::spawn_local(async {
47                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
48             })
49             .await
50             .unwrap();
51         })
52         .await;
53 }
54 
55 #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
56 #[tokio::test(flavor = "multi_thread")]
localset_future_threadpool()57 async fn localset_future_threadpool() {
58     thread_local! {
59         static ON_LOCAL_THREAD: Cell<bool> = const { Cell::new(false) };
60     }
61 
62     ON_LOCAL_THREAD.with(|cell| cell.set(true));
63 
64     let local = LocalSet::new();
65     local.spawn_local(async move {
66         assert!(ON_LOCAL_THREAD.with(|cell| cell.get()));
67     });
68     local.await;
69 }
70 
71 #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
72 #[tokio::test(flavor = "multi_thread")]
localset_future_timers()73 async fn localset_future_timers() {
74     static RAN1: AtomicBool = AtomicBool::new(false);
75     static RAN2: AtomicBool = AtomicBool::new(false);
76 
77     let local = LocalSet::new();
78     local.spawn_local(async move {
79         time::sleep(Duration::from_millis(5)).await;
80         RAN1.store(true, Ordering::SeqCst);
81     });
82     local.spawn_local(async move {
83         time::sleep(Duration::from_millis(10)).await;
84         RAN2.store(true, Ordering::SeqCst);
85     });
86     local.await;
87     assert!(RAN1.load(Ordering::SeqCst));
88     assert!(RAN2.load(Ordering::SeqCst));
89 }
90 
91 #[tokio::test]
localset_future_drives_all_local_futs()92 async fn localset_future_drives_all_local_futs() {
93     static RAN1: AtomicBool = AtomicBool::new(false);
94     static RAN2: AtomicBool = AtomicBool::new(false);
95     static RAN3: AtomicBool = AtomicBool::new(false);
96 
97     let local = LocalSet::new();
98     local.spawn_local(async move {
99         task::spawn_local(async {
100             task::yield_now().await;
101             RAN3.store(true, Ordering::SeqCst);
102         });
103         task::yield_now().await;
104         RAN1.store(true, Ordering::SeqCst);
105     });
106     local.spawn_local(async move {
107         task::yield_now().await;
108         RAN2.store(true, Ordering::SeqCst);
109     });
110     local.await;
111     assert!(RAN1.load(Ordering::SeqCst));
112     assert!(RAN2.load(Ordering::SeqCst));
113     assert!(RAN3.load(Ordering::SeqCst));
114 }
115 
116 #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
117 #[tokio::test(flavor = "multi_thread")]
local_threadpool_timer()118 async fn local_threadpool_timer() {
119     // This test ensures that runtime services like the timer are properly
120     // set for the local task set.
121     thread_local! {
122         static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
123     }
124 
125     ON_RT_THREAD.with(|cell| cell.set(true));
126 
127     LocalSet::new()
128         .run_until(async {
129             assert!(ON_RT_THREAD.with(|cell| cell.get()));
130             let join = task::spawn_local(async move {
131                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
132                 time::sleep(Duration::from_millis(10)).await;
133                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
134             });
135             join.await.unwrap();
136         })
137         .await;
138 }
#[test]
fn enter_guard_spawn() {
    let local = LocalSet::new();

    // With the set entered, `spawn_local` is called before any runtime
    // exists; the task is then expected to run once the set is driven.
    let _guard = local.enter();
    let join = task::spawn_local(async { true });

    let current_thread_rt = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();

    // Run the local task set and check the early-spawned task completed.
    local.block_on(&current_thread_rt, async move {
        assert!(join.await.unwrap());
    });
}
154 
// Expected to panic: `block_in_place` is called from a local task driven by
// `LocalSet::block_on` on a current-thread runtime, where in-place blocking
// is not permitted.
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support panic recovery
#[test]
#[should_panic]
fn local_threadpool_blocking_in_place() {
    thread_local! {
        static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
    }

    ON_RT_THREAD.with(|cell| cell.set(true));

    let current_thread_rt = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    let local_set = LocalSet::new();
    local_set.block_on(&current_thread_rt, async {
        assert!(ON_RT_THREAD.with(|cell| cell.get()));
        let blocking_task = async move {
            assert!(ON_RT_THREAD.with(|cell| cell.get()));
            task::block_in_place(|| {});
            assert!(ON_RT_THREAD.with(|cell| cell.get()));
        };
        task::spawn_local(blocking_task).await.unwrap();
    });
}
181 
182 #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
183 #[tokio::test(flavor = "multi_thread")]
local_threadpool_blocking_run()184 async fn local_threadpool_blocking_run() {
185     thread_local! {
186         static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
187     }
188 
189     ON_RT_THREAD.with(|cell| cell.set(true));
190 
191     LocalSet::new()
192         .run_until(async {
193             assert!(ON_RT_THREAD.with(|cell| cell.get()));
194             let join = task::spawn_local(async move {
195                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
196                 task::spawn_blocking(|| {
197                     assert!(
198                         !ON_RT_THREAD.with(|cell| cell.get()),
199                         "blocking must not run on the local task set's thread"
200                     );
201                 })
202                 .await
203                 .unwrap();
204                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
205             });
206             join.await.unwrap();
207         })
208         .await;
209 }
210 
211 #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
212 #[tokio::test(flavor = "multi_thread")]
all_spawns_are_local()213 async fn all_spawns_are_local() {
214     use futures::future;
215     thread_local! {
216         static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
217     }
218 
219     ON_RT_THREAD.with(|cell| cell.set(true));
220 
221     LocalSet::new()
222         .run_until(async {
223             assert!(ON_RT_THREAD.with(|cell| cell.get()));
224             let handles = (0..128)
225                 .map(|_| {
226                     task::spawn_local(async {
227                         assert!(ON_RT_THREAD.with(|cell| cell.get()));
228                     })
229                 })
230                 .collect::<Vec<_>>();
231             for joined in future::join_all(handles).await {
232                 joined.unwrap();
233             }
234         })
235         .await;
236 }
237 
238 #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
239 #[tokio::test(flavor = "multi_thread")]
nested_spawn_is_local()240 async fn nested_spawn_is_local() {
241     thread_local! {
242         static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
243     }
244 
245     ON_RT_THREAD.with(|cell| cell.set(true));
246 
247     LocalSet::new()
248         .run_until(async {
249             assert!(ON_RT_THREAD.with(|cell| cell.get()));
250             task::spawn_local(async {
251                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
252                 task::spawn_local(async {
253                     assert!(ON_RT_THREAD.with(|cell| cell.get()));
254                     task::spawn_local(async {
255                         assert!(ON_RT_THREAD.with(|cell| cell.get()));
256                         task::spawn_local(async {
257                             assert!(ON_RT_THREAD.with(|cell| cell.get()));
258                         })
259                         .await
260                         .unwrap();
261                     })
262                     .await
263                     .unwrap();
264                 })
265                 .await
266                 .unwrap();
267             })
268             .await
269             .unwrap();
270         })
271         .await;
272 }
273 
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
#[test]
fn join_local_future_elsewhere() {
    thread_local! {
        static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
    }

    ON_RT_THREAD.with(|cell| cell.set(true));

    let multi_thread_rt = runtime::Runtime::new().unwrap();
    let local = LocalSet::new();
    local.block_on(&multi_thread_rt, async move {
        let (release_tx, release_rx) = oneshot::channel();

        // Local task that waits for a release signal. It must execute on this
        // (tagged) thread even though its handle is awaited from a worker.
        let local_join = task::spawn_local(async move {
            assert!(
                ON_RT_THREAD.with(|cell| cell.get()),
                "local task must run on local thread, no matter where it is awaited"
            );
            release_rx.await.unwrap();

            "hello world"
        });

        // Ordinary task on a worker thread: releases the local task, then
        // joins it from there.
        let worker_join = task::spawn(async move {
            assert!(
                !ON_RT_THREAD.with(|cell| cell.get()),
                "spawned task should be on a worker"
            );

            release_tx.send(()).expect("task shouldn't have ended yet");

            local_join.await.expect("task should complete successfully");
        });
        worker_join.await.unwrap()
    });
}
309 
310 // Tests for <https://github.com/tokio-rs/tokio/issues/4973>
311 #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
312 #[tokio::test(flavor = "multi_thread")]
localset_in_thread_local()313 async fn localset_in_thread_local() {
314     thread_local! {
315         static LOCAL_SET: LocalSet = LocalSet::new();
316     }
317 
318     // holds runtime thread until end of main fn.
319     let (_tx, rx) = oneshot::channel::<()>();
320     let handle = tokio::runtime::Handle::current();
321 
322     std::thread::spawn(move || {
323         LOCAL_SET.with(|local_set| {
324             handle.block_on(local_set.run_until(async move {
325                 let _ = rx.await;
326             }))
327         });
328     });
329 }
330 
#[test]
fn drop_cancels_tasks() {
    use std::rc::Rc;

    // This test reproduces issue #1842: dropping a `LocalSet` must drop the
    // tasks it still owns, releasing everything they captured.
    let rt = rt();
    let tracker = Rc::new(());
    let captured = Rc::clone(&tracker);

    let (started_tx, started_rx) = oneshot::channel();

    let local = LocalSet::new();
    local.spawn_local(async move {
        // Take ownership of the clone so it lives exactly as long as the task.
        let _captured = captured;

        started_tx.send(()).unwrap();
        futures::future::pending::<()>().await;
    });

    // Drive the set until the task has started, and therefore owns `captured`.
    local.block_on(&rt, async {
        started_rx.await.unwrap();
    });
    drop(local);
    drop(rt);

    // The never-finishing task was dropped with the set, so only `tracker`
    // still references the allocation.
    assert_eq!(1, Rc::strong_count(&tracker));
}
359 
/// Runs a test function in a separate thread, and panics if the test does not
/// complete within the specified timeout, or if the test function panics.
///
/// This is intended for running tests whose failure mode is a hang or infinite
/// loop that cannot be detected otherwise.
///
/// # Panics
///
/// - If `f` does not finish within `timeout`. A thread cannot be cancelled, so
///   in that case the worker thread is left running, detached.
/// - If `f` itself panics. That panic is reported when the worker is joined.
fn with_timeout(timeout: Duration, f: impl FnOnce() + Send + 'static) {
    use std::sync::mpsc::RecvTimeoutError;

    let (done_tx, done_rx) = std::sync::mpsc::channel();
    let thread = std::thread::spawn(move || {
        f();

        // Send a message on the channel so that the test thread can
        // determine if we have entered an infinite loop:
        done_tx.send(()).unwrap();
    });

    // The failure mode of tests using this helper is an infinite loop, which
    // is hard to assert on directly, so `f` runs in its own thread. When it
    // finishes, it sends a message on a channel to this thread. We wait at
    // most `timeout` for that message; if it does not arrive, we assume the
    // test thread has hung. Callers should pass a generous timeout so that a
    // slow CI machine does not cause a spurious failure.
    match done_rx.recv_timeout(timeout) {
        // `Duration`'s `Debug` output already includes its unit (e.g. `60s`,
        // `20ms`), so the message must not append "seconds" itself.
        Err(RecvTimeoutError::Timeout) => panic!(
            "test did not complete within {timeout:?}, \
             we have (probably) entered an infinite loop!",
        ),
        // The sender was dropped without sending, which happens when `f`
        // panicked. That panic is reported by the `join` below.
        Err(RecvTimeoutError::Disconnected) => {}
        // Test completed successfully!
        Ok(()) => {}
    }

    thread.join().expect("test thread should not panic!")
}
400 
#[cfg_attr(
    target_os = "wasi",
    ignore = "`unwrap()` in `with_timeout()` panics on Wasi"
)]
#[test]
fn drop_cancels_remote_tasks() {
    // This test reproduces issue #1885. Its failure mode is a hang while
    // dropping the `LocalSet`, so the body runs under `with_timeout`.
    with_timeout(Duration::from_secs(60), || {
        let (tx, mut rx) = mpsc::channel::<()>(1024);
        let rt = rt();
        let local = LocalSet::new();

        // Receiver task that only exits once every sender is gone.
        let drain = async move { while rx.recv().await.is_some() {} };
        local.spawn_local(drain);

        // Drive the set briefly so the receiver task gets polled.
        local.block_on(&rt, async {
            time::sleep(Duration::from_millis(1)).await;
        });

        // Closing the channel notifies the receiver task from outside of
        // `block_on`.
        drop(tx);

        // This enters an infinite loop if the remote notified tasks are not
        // properly cancelled.
        drop(local);
    });
}
426 
#[cfg_attr(
    target_os = "wasi",
    ignore = "FIXME: `task::spawn_local().await.unwrap()` panics on Wasi"
)]
#[test]
fn local_tasks_wake_join_all() {
    // This test reproduces issue #2460. Its failure mode is a hang, hence
    // `with_timeout`.
    with_timeout(Duration::from_secs(60), || {
        use futures::future::join_all;
        use tokio::task::LocalSet;

        let rt = rt();
        let set = LocalSet::new();

        // 128 outer local tasks. Each one spawns an inner local task and
        // awaits it.
        let handles: Vec<_> = (1..=128)
            .map(|_| {
                set.spawn_local(async move {
                    tokio::task::spawn_local(async move {}).await.unwrap();
                })
            })
            .collect();

        rt.block_on(set.run_until(join_all(handles)));
    });
}
451 
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support panic recovery
#[test]
fn local_tasks_are_polled_after_tick() {
    // The inner test depends on timing. It gets four attempts with panics
    // caught, then a fifth attempt whose panic (if any) fails this test.
    const CAUGHT_ATTEMPTS: usize = 4;

    let passed = (0..CAUGHT_ATTEMPTS)
        .any(|_| std::panic::catch_unwind(local_tasks_are_polled_after_tick_inner).is_ok());

    if !passed {
        local_tasks_are_polled_after_tick_inner();
    }
}
468 
// One attempt of the timing-sensitive test that `local_tasks_are_polled_after_tick`
// retries. `#[tokio::main]` turns this into a plain synchronous fn (the caller
// invokes it without `.await`), which runs the body on a current-thread
// runtime.
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support panic recovery
#[tokio::main(flavor = "current_thread")]
async fn local_tasks_are_polled_after_tick_inner() {
    // Reproduces issues #1899 and #1900

    // RX1: oneshot receivers taken off the mpsc channel by the local loop.
    // RX2: local tasks whose oneshot has completed.
    static RX1: AtomicUsize = AtomicUsize::new(0);
    static RX2: AtomicUsize = AtomicUsize::new(0);
    const EXPECTED: usize = 500;

    // The statics outlive a single attempt, so reset them before each retry.
    RX1.store(0, SeqCst);
    RX2.store(0, SeqCst);

    let (tx, mut rx) = mpsc::unbounded_channel();

    let local = LocalSet::new();

    local
        .run_until(async {
            // A regular (non-local) task that feeds the local side. It sends
            // EXPECTED oneshot receivers, later completes all of them, then
            // polls until both counters reach EXPECTED. `tx` is moved in here,
            // so the channel closes when this task finishes.
            let task2 = task::spawn(async move {
                // Wait a bit
                time::sleep(Duration::from_millis(10)).await;

                let mut oneshots = Vec::with_capacity(EXPECTED);

                // Send values
                for _ in 0..EXPECTED {
                    let (oneshot_tx, oneshot_rx) = oneshot::channel();
                    oneshots.push(oneshot_tx);
                    tx.send(oneshot_rx).unwrap();
                }

                time::sleep(Duration::from_millis(10)).await;

                // Complete every oneshot, waking the local tasks awaiting them.
                for tx in oneshots.drain(..) {
                    tx.send(()).unwrap();
                }

                // Spin (with sleeps) until the local side has processed
                // everything. If local tasks were never polled, this would not
                // terminate and the attempt would hang.
                loop {
                    time::sleep(Duration::from_millis(20)).await;
                    let rx1 = RX1.load(SeqCst);
                    let rx2 = RX2.load(SeqCst);

                    if rx1 == EXPECTED && rx2 == EXPECTED {
                        break;
                    }
                }
            });

            // Local side: for each received oneshot, spawn a local task that
            // waits on it. Those tasks must keep being polled while this loop
            // is also running. The loop ends once `task2` drops `tx`.
            while let Some(oneshot) = rx.recv().await {
                RX1.fetch_add(1, SeqCst);

                task::spawn_local(async move {
                    oneshot.await.unwrap();
                    RX2.fetch_add(1, SeqCst);
                });
            }

            task2.await.unwrap();
        })
        .await;
}
530 
531 #[tokio::test]
acquire_mutex_in_drop()532 async fn acquire_mutex_in_drop() {
533     use futures::future::pending;
534 
535     let (tx1, rx1) = oneshot::channel();
536     let (tx2, rx2) = oneshot::channel();
537     let local = LocalSet::new();
538 
539     local.spawn_local(async move {
540         let _ = rx2.await;
541         unreachable!();
542     });
543 
544     local.spawn_local(async move {
545         let _ = rx1.await;
546         tx2.send(()).unwrap();
547         unreachable!();
548     });
549 
550     // Spawn a task that will never notify
551     local.spawn_local(async move {
552         pending::<()>().await;
553         tx1.send(()).unwrap();
554     });
555 
556     // Tick the loop
557     local
558         .run_until(async {
559             task::yield_now().await;
560         })
561         .await;
562 
563     // Drop the LocalSet
564     drop(local);
565 }
566 
567 #[tokio::test]
spawn_wakes_localset()568 async fn spawn_wakes_localset() {
569     let local = LocalSet::new();
570     futures::select! {
571         _ = local.run_until(pending::<()>()).fuse() => unreachable!(),
572         ret = async { local.spawn_local(ready(())).await.unwrap()}.fuse() => ret
573     }
574 }
575 
576 /// Checks that the task wakes up with `enter`.
577 /// Reproduces <https://github.com/tokio-rs/tokio/issues/5020>.
578 #[tokio::test]
sleep_with_local_enter_guard()579 async fn sleep_with_local_enter_guard() {
580     let local = LocalSet::new();
581     let _guard = local.enter();
582 
583     let (tx, rx) = oneshot::channel();
584 
585     local
586         .run_until(async move {
587             tokio::task::spawn_local(async move {
588                 time::sleep(Duration::ZERO).await;
589 
590                 tx.send(()).expect("failed to send");
591             });
592             assert_eq!(rx.await, Ok(()));
593         })
594         .await;
595 }
596 
#[test]
fn store_local_set_in_thread_local_with_runtime() {
    use tokio::runtime::Runtime;

    // A runtime and a `LocalSet` stored side by side in a thread-local. The
    // set must be usable from a method called through that thread-local.
    struct RtAndLocalSet {
        rt: Runtime,
        local: LocalSet,
    }

    impl RtAndLocalSet {
        fn new() -> RtAndLocalSet {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            let local = LocalSet::new();
            RtAndLocalSet { rt, local }
        }

        async fn inner_method(&self) {
            let spawn_one = async move {
                tokio::task::spawn_local(async {});
            };
            self.local.run_until(spawn_one).await
        }

        fn method(&self) {
            self.rt.block_on(self.inner_method());
        }
    }

    thread_local! {
        static CURRENT: RtAndLocalSet = RtAndLocalSet::new();
    }

    CURRENT.with(RtAndLocalSet::method);
}
638 
// Tests for `LocalSet` APIs that are only compiled with `--cfg tokio_unstable`.
#[cfg(tokio_unstable)]
mod unstable {
    use tokio::runtime::UnhandledPanic;
    use tokio::task::LocalSet;

    // With `UnhandledPanic::ShutdownRuntime`, a panic in a task spawned on the
    // set must make `run_until` panic with the message below, even though the
    // driving future (`pending`) would otherwise never complete.
    #[tokio::test]
    #[should_panic(
        expected = "a spawned task panicked and the LocalSet is configured to shutdown on unhandled panic"
    )]
    async fn shutdown_on_panic() {
        LocalSet::new()
            .unhandled_panic(UnhandledPanic::ShutdownRuntime)
            .run_until(async {
                tokio::task::spawn_local(async {
                    panic!("boom");
                });

                futures::future::pending::<()>().await;
            })
            .await;
    }

    // This test checks that, when the task driving `run_until` has already
    // consumed budget, the `run_until` future has less budget than a "spawned"
    // task.
    //
    // "Budget" is a fuzzy metric, because the Tokio runtime may change the
    // values internally. For this reason the test only compares the two counts
    // with each other and does not assert absolute numbers.
    #[tokio::test]
    async fn run_until_does_not_get_own_budget() {
        // Consume some budget
        tokio::task::consume_budget().await;

        LocalSet::new()
            .run_until(async {
                // A fresh (non-local) task counts how many `consume_budget`
                // calls it completes in a single poll.
                let spawned = tokio::spawn(async {
                    let mut spawned_n = 0;

                    {
                        let mut spawned = tokio_test::task::spawn(async {
                            loop {
                                spawned_n += 1;
                                tokio::task::consume_budget().await;
                            }
                        });
                        // Poll once. The loop never returns, so the only way
                        // this poll ends is with `Pending` (presumably once the
                        // budget runs out).
                        assert!(!spawned.poll().is_ready());
                    }

                    spawned_n
                });

                // The same measurement, made inside the `run_until` future,
                // which should share the already-reduced budget of the outer
                // test task.
                let mut run_until_n = 0;
                {
                    let mut run_until = tokio_test::task::spawn(async {
                        loop {
                            run_until_n += 1;
                            tokio::task::consume_budget().await;
                        }
                    });
                    // Poll once
                    assert!(!run_until.poll().is_ready());
                }

                let spawned_n = spawned.await.unwrap();
                assert_ne!(spawned_n, 0);
                assert_ne!(run_until_n, 0);
                assert!(spawned_n > run_until_n);
            })
            .await
    }
}
711 
rt() -> runtime::Runtime712 fn rt() -> runtime::Runtime {
713     tokio::runtime::Builder::new_current_thread()
714         .enable_all()
715         .build()
716         .unwrap()
717 }
718