1 use super::{BlockingRegionGuard, SetCurrentGuard, CONTEXT};
2 
3 use crate::runtime::scheduler;
4 use crate::util::rand::{FastRand, RngSeed};
5 
6 use std::fmt;
7 
/// Records whether the current thread is running inside a runtime context.
///
/// Stored in the thread-local `CONTEXT.runtime` cell and flipped by
/// `enter_runtime` / `EnterRuntimeGuard::drop`.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub(crate) enum EnterRuntime {
    /// The thread is inside a runtime context.
    ///
    /// `allow_block_in_place` is the flag handed to `enter_runtime`; by its
    /// name it presumably gates `block_in_place` — confirm at the use site.
    #[cfg_attr(not(feature = "rt"), allow(dead_code))]
    Entered { allow_block_in_place: bool },

    /// The thread is neither in a runtime context **nor** in a blocking region.
    NotEntered,
}
18 
/// Guard tracking that a caller has entered a runtime context.
///
/// Constructed by `enter_runtime`. Its `Drop` impl clears the thread-local
/// "entered" flag and restores the RNG seed that was active before entry.
/// Rust drops fields in declaration order after `Drop::drop` has run, so the
/// field order below also fixes the order in which the blocking-region guard
/// and the current-handle guard are released; reorder only deliberately.
#[must_use]
pub(crate) struct EnterRuntimeGuard {
    /// Tracks that the current thread has entered a blocking function call.
    /// `enter_runtime` lends this to its closure as `&mut BlockingRegionGuard`.
    pub(crate) blocking: BlockingRegionGuard,

    /// Guard returned by `set_current(handle)`; never read, only held so it is
    /// dropped together with this guard.
    #[allow(dead_code)] // Only tracking the guard.
    pub(crate) handle: SetCurrentGuard,

    /// The RNG seed that was in effect before entering; restored on drop.
    old_seed: RngSeed,
}
31 
32 /// Marks the current thread as being within the dynamic extent of an
33 /// executor.
34 #[track_caller]
enter_runtime<F, R>(handle: &scheduler::Handle, allow_block_in_place: bool, f: F) -> R where F: FnOnce(&mut BlockingRegionGuard) -> R,35 pub(crate) fn enter_runtime<F, R>(handle: &scheduler::Handle, allow_block_in_place: bool, f: F) -> R
36 where
37     F: FnOnce(&mut BlockingRegionGuard) -> R,
38 {
39     let maybe_guard = CONTEXT.with(|c| {
40         if c.runtime.get().is_entered() {
41             None
42         } else {
43             // Set the entered flag
44             c.runtime.set(EnterRuntime::Entered {
45                 allow_block_in_place,
46             });
47 
48             // Generate a new seed
49             let rng_seed = handle.seed_generator().next_seed();
50 
51             // Swap the RNG seed
52             let mut rng = c.rng.get().unwrap_or_else(FastRand::new);
53             let old_seed = rng.replace_seed(rng_seed);
54             c.rng.set(Some(rng));
55 
56             Some(EnterRuntimeGuard {
57                 blocking: BlockingRegionGuard::new(),
58                 handle: c.set_current(handle),
59                 old_seed,
60             })
61         }
62     });
63 
64     if let Some(mut guard) = maybe_guard {
65         return f(&mut guard.blocking);
66     }
67 
68     panic!(
69         "Cannot start a runtime from within a runtime. This happens \
70             because a function (like `block_on`) attempted to block the \
71             current thread while the thread is being used to drive \
72             asynchronous tasks."
73     );
74 }
75 
76 impl fmt::Debug for EnterRuntimeGuard {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result77     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78         f.debug_struct("Enter").finish()
79     }
80 }
81 
82 impl Drop for EnterRuntimeGuard {
drop(&mut self)83     fn drop(&mut self) {
84         CONTEXT.with(|c| {
85             assert!(c.runtime.get().is_entered());
86             c.runtime.set(EnterRuntime::NotEntered);
87             // Replace the previous RNG seed
88             let mut rng = c.rng.get().unwrap_or_else(FastRand::new);
89             rng.replace_seed(self.old_seed.clone());
90             c.rng.set(Some(rng));
91         });
92     }
93 }
94 
95 impl EnterRuntime {
is_entered(self) -> bool96     pub(crate) fn is_entered(self) -> bool {
97         matches!(self, EnterRuntime::Entered { .. })
98     }
99 }
100