#![allow(unknown_lints, unexpected_cfgs)]
#![warn(rust_2018_idioms)]
#![cfg(all(feature = "full", not(target_os = "wasi")))]
#![cfg(tokio_unstable)]

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::runtime;
use tokio::sync::oneshot;
use tokio_test::{assert_err, assert_ok};

use std::future::{poll_fn, Future};
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::{mpsc, Arc, Mutex};
use std::task::{Context, Poll, Waker};

macro_rules! cfg_metrics {
    ($($t:tt)*) => {
        #[cfg(all(tokio_unstable, target_has_atomic = "64"))]
        {
            $( $t )*
        }
    }
}

#[test]
fn single_thread() {
    // No panic when starting a runtime w/ a single thread
    let _ = runtime::Builder::new_multi_thread_alt()
        .enable_all()
        .worker_threads(1)
        .build()
        .unwrap();
}

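// Spawns NUM tasks that each bump a shared counter; the task that performs the final
// increment notifies the main thread over a std `mpsc` channel, then the runtime is
// dropped and the whole cycle repeats.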
#[test]
#[ignore] // https://github.com/tokio-rs/tokio/issues/5995
fn many_oneshot_futures() {
    // used for notifying the main thread
    const NUM: usize = 1_000;

    for _ in 0..5 {
        let (tx, rx) = mpsc::channel();

        let rt = rt();
        let cnt = Arc::new(AtomicUsize::new(0));

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Wait for the pool to shut down
        drop(rt);
    }
}

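// The outer task is spawned from `block_on` on the main thread (counted as a remote
// schedule); it spawns the inner task from a worker (a local schedule). The metrics
// block, when enabled, checks both counts.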
#[test]
fn spawn_two() {
    let rt = rt();

    let out = rt.block_on(async {
        let (tx, rx) = oneshot::channel();

        tokio::spawn(async move {
            tokio::spawn(async move {
                tx.send("ZOMG").unwrap();
            });
        });

        assert_ok!(rx.await)
    });

    assert_eq!(out, "ZOMG");

    cfg_metrics! {
        let metrics = rt.metrics();
        drop(rt);
        assert_eq!(1, metrics.remote_schedule_count());

        let mut local = 0;
        for i in 0..metrics.num_workers() {
            local += metrics.worker_local_schedule_count(i);
        }

        assert_eq!(1, local);
    }
}

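// Builds TRACKS independent chains of CHAIN forwarding tasks and cycles a message
// through each chain CYCLES times; completing every cycle without deadlocking is the
// assertion.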
#[test]
fn many_multishot_futures() {
    const CHAIN: usize = 200;
    const CYCLES: usize = 5;
    const TRACKS: usize = 50;

    for _ in 0..50 {
        let rt = rt();
        let mut start_txs = Vec::with_capacity(TRACKS);
        let mut final_rxs = Vec::with_capacity(TRACKS);

        for _ in 0..TRACKS {
            let (start_tx, mut chain_rx) = tokio::sync::mpsc::channel(10);

            for _ in 0..CHAIN {
                let (next_tx, next_rx) = tokio::sync::mpsc::channel(10);

                // Forward all the messages
                rt.spawn(async move {
                    while let Some(v) = chain_rx.recv().await {
                        next_tx.send(v).await.unwrap();
                    }
                });

                chain_rx = next_rx;
            }

            // This final task cycles if needed
            let (final_tx, final_rx) = tokio::sync::mpsc::channel(10);
            let cycle_tx = start_tx.clone();
            let mut rem = CYCLES;

            rt.spawn(async move {
                for _ in 0..CYCLES {
                    let msg = chain_rx.recv().await.unwrap();

                    rem -= 1;

                    if rem == 0 {
                        final_tx.send(msg).await.unwrap();
                    } else {
                        cycle_tx.send(msg).await.unwrap();
                    }
                }
            });

            start_txs.push(start_tx);
            final_rxs.push(final_rx);
        }

        {
            rt.block_on(async move {
                for start_tx in start_txs {
                    start_tx.send("ping").await.unwrap();
                }

                for mut final_rx in final_rxs {
                    final_rx.recv().await.unwrap();
                }
            });
        }
    }
}

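// Each spawned task immediately spawns a successor, which keeps landing in the
// worker's LIFO slot. With a single worker, the runtime must still make progress on
// the oneshot send/receive and shut down; completion is the assertion.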
#[test]
fn lifo_slot_budget() {
    async fn my_fn() {
        spawn_another();
    }

    fn spawn_another() {
        tokio::spawn(my_fn());
    }

    let rt = runtime::Builder::new_multi_thread_alt()
        .enable_all()
        .worker_threads(1)
        .build()
        .unwrap();

    let (send, recv) = oneshot::channel();

    rt.spawn(async move {
        tokio::spawn(my_fn());
        let _ = send.send(());
    });

    let _ = rt.block_on(recv);
}

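// Runs the client/server exchange twice: once spawned from inside `block_on` and once
// via the runtime handle. After `drop(rt)` no further messages can arrive.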
#[test]
fn spawn_shutdown() {
    let rt = rt();
    let (tx, rx) = mpsc::channel();

    rt.block_on(async {
        tokio::spawn(client_server(tx.clone()));
    });

    // Use spawner
    rt.spawn(client_server(tx));

    assert_ok!(rx.recv());
    assert_ok!(rx.recv());

    drop(rt);
    assert_err!(rx.try_recv());
}

async fn client_server(tx: mpsc::Sender<()>) {
    let server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);

    // Get the assigned address
    let addr = assert_ok!(server.local_addr());

    // Spawn the server
    tokio::spawn(async move {
        // Accept a socket
        let (mut socket, _) = server.accept().await.unwrap();

        // Write some data
        socket.write_all(b"hello").await.unwrap();
    });

    let mut client = TcpStream::connect(&addr).await.unwrap();

    let mut buf = vec![];
    client.read_to_end(&mut buf).await.unwrap();

    assert_eq!(buf, b"hello");
    tx.send(()).unwrap();
}

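// Dropping the runtime must drop any still-pending futures and join all worker
// threads: the start/stop counts must match and the `Never` future must be dropped
// exactly once.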
#[test]
fn drop_threadpool_drops_futures() {
    for _ in 0..1_000 {
        let num_inc = Arc::new(AtomicUsize::new(0));
        let num_dec = Arc::new(AtomicUsize::new(0));
        let num_drop = Arc::new(AtomicUsize::new(0));

        struct Never(Arc<AtomicUsize>);

        impl Future for Never {
            type Output = ();

            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
                Poll::Pending
            }
        }

        impl Drop for Never {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }

        let a = num_inc.clone();
        let b = num_dec.clone();

        let rt = runtime::Builder::new_multi_thread_alt()
            .enable_all()
            .on_thread_start(move || {
                a.fetch_add(1, Relaxed);
            })
            .on_thread_stop(move || {
                b.fetch_add(1, Relaxed);
            })
            .build()
            .unwrap();

        rt.spawn(Never(num_drop.clone()));

        // Wait for the pool to shut down
        drop(rt);

        // Assert that at least one thread was spawned
        let a = num_inc.load(Relaxed);
        assert!(a >= 1);

        // Assert that all threads shut down
        let b = num_dec.load(Relaxed);
        assert_eq!(a, b);

        // Assert that the future was dropped
        let c = num_drop.load(Relaxed);
        assert_eq!(c, 1);
    }
}

#[test]
fn start_stop_callbacks_called() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    let after_start = Arc::new(AtomicUsize::new(0));
    let before_stop = Arc::new(AtomicUsize::new(0));

    let after_inner = after_start.clone();
    let before_inner = before_stop.clone();
    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        .enable_all()
        .on_thread_start(move || {
            after_inner.clone().fetch_add(1, Ordering::Relaxed);
        })
        .on_thread_stop(move || {
            before_inner.clone().fetch_add(1, Ordering::Relaxed);
        })
        .build()
        .unwrap();

    let (tx, rx) = oneshot::channel();

    rt.spawn(async move {
        assert_ok!(tx.send(()));
    });

    assert_ok!(rt.block_on(rx));

    drop(rt);

    assert!(after_start.load(Ordering::Relaxed) > 0);
    assert!(before_stop.load(Ordering::Relaxed) > 0);
}

#[test]
fn blocking_task() {
    // used for notifying the main thread
    const NUM: usize = 1_000;

    for _ in 0..10 {
        let (tx, rx) = mpsc::channel();

        let rt = rt();
        let cnt = Arc::new(AtomicUsize::new(0));

        // There are four workers in the pool, so if we run 4 blocking tasks, we know
        // that a handoff must have happened.
        let block = Arc::new(std::sync::Barrier::new(5));
        for _ in 0..4 {
            let block = block.clone();
            rt.spawn(async move {
                tokio::task::block_in_place(move || {
                    block.wait();
                    block.wait();
                })
            });
        }
        block.wait();

        for _ in 0..NUM {
            let cnt = cnt.clone();
            let tx = tx.clone();

            rt.spawn(async move {
                let num = cnt.fetch_add(1, Relaxed) + 1;

                if num == NUM {
                    tx.send(()).unwrap();
                }
            });
        }

        rx.recv().unwrap();

        // Release the blocked tasks so the runtime can shut down
        block.wait();
    }
}

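// Two independent runtimes communicate over a oneshot channel; `done_rx.recv()`
// completing without a deadlock is the assertion.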
#[test]
fn multi_threadpool() {
    use tokio::sync::oneshot;

    let rt1 = rt();
    let rt2 = rt();

    let (tx, rx) = oneshot::channel();
    let (done_tx, done_rx) = mpsc::channel();

    rt2.spawn(async move {
        rx.await.unwrap();
        done_tx.send(()).unwrap();
    });

    rt1.spawn(async move {
        tx.send(()).unwrap();
    });

    done_rx.recv().unwrap();
}

// When `block_in_place` returns, it attempts to reclaim the yielded runtime
// worker. In this case, the remainder of the task is on the runtime worker and
// must take part in the cooperative task budgeting system.
//
// The test ensures that, when this happens, attempting to consume from a
// channel yields occasionally even if there are values ready to receive.
#[test]
fn coop_and_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        // Setting max threads to 1 prevents another thread from claiming the
        // runtime worker yielded as part of `block_in_place` and guarantees the
        // same thread will reclaim the worker at the end of the
        // `block_in_place` call.
        .max_blocking_threads(1)
        .build()
        .unwrap();

    rt.block_on(async move {
        let (tx, mut rx) = tokio::sync::mpsc::channel(1024);

        // Fill the channel
        for _ in 0..1024 {
            tx.send(()).await.unwrap();
        }

        drop(tx);

        tokio::spawn(async move {
            // Block in place without doing anything
            tokio::task::block_in_place(|| {});

            // Receive all the values; this should return `Pending` once the
            // coop budget is exhausted.
            poll_fn(|cx| {
                while let Poll::Ready(v) = {
                    tokio::pin! {
                        let fut = rx.recv();
                    }

                    Pin::new(&mut fut).poll(cx)
                } {
                    if v.is_none() {
                        panic!("did not yield");
                    }
                }

                Poll::Ready(())
            })
            .await
        })
        .await
        .unwrap();
    });
}

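// After `block_in_place` returns (here used to enter a nested current-thread
// runtime), the remainder of the task must still be able to yield and complete on
// the outer runtime.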
#[test]
fn yield_after_block_in_place() {
    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .build()
        .unwrap();

    rt.block_on(async {
        tokio::spawn(async move {
            // Block in place then enter a new runtime
            tokio::task::block_in_place(|| {
                let rt = tokio::runtime::Builder::new_current_thread()
                    .build()
                    .unwrap();

                rt.block_on(async {});
            });

            // Yield, then complete
            tokio::task::yield_now().await;
        })
        .await
        .unwrap()
    });
}

// Testing this does not panic
#[test]
fn max_blocking_threads() {
    let _rt = tokio::runtime::Builder::new_multi_thread_alt()
        .max_blocking_threads(1)
        .build()
        .unwrap();
}

#[test]
#[should_panic]
fn max_blocking_threads_set_to_zero() {
    let _rt = tokio::runtime::Builder::new_multi_thread_alt()
        .max_blocking_threads(0)
        .build()
        .unwrap();
}

/// Regression test for #6445.
///
/// After #6445, setting `global_queue_interval` to 1 is now technically valid.
/// This test confirms that there is no regression in the multi-thread runtime
/// when `global_queue_interval` is set to 1.
#[test]
fn global_queue_interval_set_to_one() {
    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        .global_queue_interval(1)
        .build()
        .unwrap();

    // Perform some simple work.
    let cnt = Arc::new(AtomicUsize::new(0));
    rt.block_on(async {
        let mut set = tokio::task::JoinSet::new();
        for _ in 0..10 {
            let cnt = cnt.clone();
            set.spawn(async move { cnt.fetch_add(1, Relaxed) });
        }
        while let Some(res) = set.join_next().await {
            res.unwrap();
        }
    });
    assert_eq!(cnt.load(Relaxed), 10);
}

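// Shutdown must not hang while a task is blocked in `block_in_place`: dropping the
// sleeping task during shutdown drops `sync_tx`, which closes the channel and
// unblocks the receiver.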
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn hang_on_shutdown() {
    let (sync_tx, sync_rx) = std::sync::mpsc::channel::<()>();
    tokio::spawn(async move {
        tokio::task::block_in_place(|| sync_rx.recv().ok());
    });

    tokio::spawn(async {
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        drop(sync_tx);
    });
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}

/// Demonstrates tokio-rs/tokio#3869
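///
/// One future registers a waker; the other wakes it from its `Drop` impl while the
/// runtime is shutting down. The test passes if shutdown completes cleanly.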
#[test]
fn wake_during_shutdown() {
    struct Shared {
        waker: Option<Waker>,
    }

    struct MyFuture {
        shared: Arc<Mutex<Shared>>,
        put_waker: bool,
    }

    impl MyFuture {
        fn new() -> (Self, Self) {
            let shared = Arc::new(Mutex::new(Shared { waker: None }));
            let f1 = MyFuture {
                shared: shared.clone(),
                put_waker: true,
            };
            let f2 = MyFuture {
                shared,
                put_waker: false,
            };
            (f1, f2)
        }
    }

    impl Future for MyFuture {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let me = Pin::into_inner(self);
            let mut lock = me.shared.lock().unwrap();
            if me.put_waker {
                lock.waker = Some(cx.waker().clone());
            }
            Poll::Pending
        }
    }

    impl Drop for MyFuture {
        fn drop(&mut self) {
            let mut lock = self.shared.lock().unwrap();
            if !self.put_waker {
                lock.waker.take().unwrap().wake();
            }
            drop(lock);
        }
    }

    let rt = tokio::runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();

    let (f1, f2) = MyFuture::new();

    rt.spawn(f1);
    rt.spawn(f2);

    rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await });
}

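// `block_in_place` panics on a current-thread runtime (the `#[tokio::test]` default
// and the explicit `current_thread` flavor below) and succeeds on the multi-thread
// flavors.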
#[should_panic]
#[tokio::test]
async fn test_block_in_place1() {
    tokio::task::block_in_place(|| {});
}

#[tokio::test(flavor = "multi_thread")]
async fn test_block_in_place2() {
    tokio::task::block_in_place(|| {});
}

#[should_panic]
#[tokio::main(flavor = "current_thread")]
#[test]
async fn test_block_in_place3() {
    tokio::task::block_in_place(|| {});
}

#[tokio::main]
#[test]
async fn test_block_in_place4() {
    tokio::task::block_in_place(|| {});
}

// Testing the tuning logic is tricky as it is inherently timing-based, and more
// of a heuristic than an exact behavior. This test checks that the interval
// changes over time based on load factors. There are no assertions; completion
// is sufficient. If there is a regression, this test will hang. In theory, we
// could add limits, but that would be likely to fail on CI.
#[test]
#[cfg(not(tokio_no_tuning_tests))]
fn test_tuning() {
    use std::sync::atomic::AtomicBool;
    use std::time::Duration;

    let rt = runtime::Builder::new_multi_thread_alt()
        .worker_threads(1)
        .build()
        .unwrap();

    fn iter(flag: Arc<AtomicBool>, counter: Arc<AtomicUsize>, stall: bool) {
        if flag.load(Relaxed) {
            if stall {
                std::thread::sleep(Duration::from_micros(5));
            }

            counter.fetch_add(1, Relaxed);
            tokio::spawn(async move { iter(flag, counter, stall) });
        }
    }

    let flag = Arc::new(AtomicBool::new(true));
    let counter = Arc::new(AtomicUsize::new(61));
    let interval = Arc::new(AtomicUsize::new(61));

    {
        let flag = flag.clone();
        let counter = counter.clone();
        rt.spawn(async move { iter(flag, counter, true) });
    }

    // Now, hammer the injection queue until the interval drops.
    let mut n = 0;
    loop {
        let curr = interval.load(Relaxed);

        if curr <= 8 {
            n += 1;
        } else {
            n = 0;
        }

        // Make sure we get a few good rounds. Jitter in the tuning could result
        // in one "good" value without being representative of reaching a good
        // state.
        if n == 3 {
            break;
        }

        if Arc::strong_count(&interval) < 5_000 {
            let counter = counter.clone();
            let interval = interval.clone();

            rt.spawn(async move {
                let prev = counter.swap(0, Relaxed);
                interval.store(prev, Relaxed);
            });

            std::thread::yield_now();
        }
    }

    flag.store(false, Relaxed);

    let w = Arc::downgrade(&interval);
    drop(interval);

    while w.strong_count() > 0 {
        std::thread::sleep(Duration::from_micros(500));
    }

    // Now, run it again with a faster task
    let flag = Arc::new(AtomicBool::new(true));
    // Set it high; we know it shouldn't ever really be this high
    let counter = Arc::new(AtomicUsize::new(10_000));
    let interval = Arc::new(AtomicUsize::new(10_000));

    {
        let flag = flag.clone();
        let counter = counter.clone();
        rt.spawn(async move { iter(flag, counter, false) });
    }

    // Now, hammer the injection queue until the interval reaches the expected range.
    let mut n = 0;
    loop {
        let curr = interval.load(Relaxed);

        if curr <= 1_000 && curr > 32 {
            n += 1;
        } else {
            n = 0;
        }

        if n == 3 {
            break;
        }

        if Arc::strong_count(&interval) <= 5_000 {
            let counter = counter.clone();
            let interval = interval.clone();

            rt.spawn(async move {
                let prev = counter.swap(0, Relaxed);
                interval.store(prev, Relaxed);
            });
        }

        std::thread::yield_now();
    }

    flag.store(false, Relaxed);
}

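// Helper: build a multi-thread runtime using the alternate scheduler with all
// drivers enabled.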
fn rt() -> runtime::Runtime {
    runtime::Builder::new_multi_thread_alt()
        .enable_all()
        .build()
        .unwrap()
}