1 use crate::loom::sync::Mutex;
2 use crate::sync::watch;
3 #[cfg(all(tokio_unstable, feature = "tracing"))]
4 use crate::util::trace;
5 
6 /// A barrier enables multiple tasks to synchronize the beginning of some computation.
7 ///
8 /// ```
9 /// # #[tokio::main]
10 /// # async fn main() {
11 /// use tokio::sync::Barrier;
12 /// use std::sync::Arc;
13 ///
14 /// let mut handles = Vec::with_capacity(10);
15 /// let barrier = Arc::new(Barrier::new(10));
16 /// for _ in 0..10 {
17 ///     let c = barrier.clone();
18 ///     // The same messages will be printed together.
19 ///     // You will NOT see any interleaving.
20 ///     handles.push(tokio::spawn(async move {
21 ///         println!("before wait");
22 ///         let wait_result = c.wait().await;
23 ///         println!("after wait");
24 ///         wait_result
25 ///     }));
26 /// }
27 ///
28 /// // Will not resolve until all "after wait" messages have been printed
29 /// let mut num_leaders = 0;
30 /// for handle in handles {
31 ///     let wait_result = handle.await.unwrap();
32 ///     if wait_result.is_leader() {
33 ///         num_leaders += 1;
34 ///     }
35 /// }
36 ///
37 /// // Exactly one barrier will resolve as the "leader"
38 /// assert_eq!(num_leaders, 1);
39 /// # }
40 /// ```
41 #[derive(Debug)]
42 pub struct Barrier {
43     state: Mutex<BarrierState>,
44     wait: watch::Receiver<usize>,
45     n: usize,
46     #[cfg(all(tokio_unstable, feature = "tracing"))]
47     resource_span: tracing::Span,
48 }
49 
/// Mutable state of a [`Barrier`], always accessed through `Barrier::state`'s lock.
#[derive(Debug)]
struct BarrierState {
    /// Publishes each generation number when that generation completes; non-leader
    /// waiters wake once the published value reaches their own generation.
    waker: watch::Sender<usize>,
    /// Number of `wait` calls that have arrived in the current generation; the leader
    /// resets it to 0.
    arrived: usize,
    /// Number of the current (not yet completed) generation. It starts at 1 and each
    /// leader increments it after publishing.
    generation: usize,
}
56 
57 impl Barrier {
58     /// Creates a new barrier that can block a given number of tasks.
59     ///
60     /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all
61     /// tasks at once when the `n`th task calls `wait`.
62     #[track_caller]
new(mut n: usize) -> Barrier63     pub fn new(mut n: usize) -> Barrier {
64         let (waker, wait) = crate::sync::watch::channel(0);
65 
66         if n == 0 {
67             // if n is 0, it's not clear what behavior the user wants.
68             // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
69             // .wait() immediately unblocks, so we adopt that here as well.
70             n = 1;
71         }
72 
73         #[cfg(all(tokio_unstable, feature = "tracing"))]
74         let resource_span = {
75             let location = std::panic::Location::caller();
76             let resource_span = tracing::trace_span!(
77                 parent: None,
78                 "runtime.resource",
79                 concrete_type = "Barrier",
80                 kind = "Sync",
81                 loc.file = location.file(),
82                 loc.line = location.line(),
83                 loc.col = location.column(),
84             );
85 
86             resource_span.in_scope(|| {
87                 tracing::trace!(
88                     target: "runtime::resource::state_update",
89                     size = n,
90                 );
91 
92                 tracing::trace!(
93                     target: "runtime::resource::state_update",
94                     arrived = 0,
95                 )
96             });
97             resource_span
98         };
99 
100         Barrier {
101             state: Mutex::new(BarrierState {
102                 waker,
103                 arrived: 0,
104                 generation: 1,
105             }),
106             n,
107             wait,
108             #[cfg(all(tokio_unstable, feature = "tracing"))]
109             resource_span,
110         }
111     }
112 
113     /// Does not resolve until all tasks have rendezvoused here.
114     ///
115     /// Barriers are re-usable after all tasks have rendezvoused once, and can
116     /// be used continuously.
117     ///
118     /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
119     /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks
120     /// will receive a result that will return `false` from `is_leader`.
121     ///
122     /// # Cancel safety
123     ///
124     /// This method is not cancel safe.
wait(&self) -> BarrierWaitResult125     pub async fn wait(&self) -> BarrierWaitResult {
126         #[cfg(all(tokio_unstable, feature = "tracing"))]
127         return trace::async_op(
128             || self.wait_internal(),
129             self.resource_span.clone(),
130             "Barrier::wait",
131             "poll",
132             false,
133         )
134         .await;
135 
136         #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
137         return self.wait_internal().await;
138     }
wait_internal(&self) -> BarrierWaitResult139     async fn wait_internal(&self) -> BarrierWaitResult {
140         crate::trace::async_trace_leaf().await;
141 
142         // NOTE: we are taking a _synchronous_ lock here.
143         // It is okay to do so because the critical section is fast and never yields, so it cannot
144         // deadlock even if another future is concurrently holding the lock.
145         // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than
146         // the asynchronous counter-parts, so we should use them where possible [citation needed].
147         // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
148         // a yield point, and thus marks the returned future as !Send.
149         let generation = {
150             let mut state = self.state.lock();
151             let generation = state.generation;
152             state.arrived += 1;
153             #[cfg(all(tokio_unstable, feature = "tracing"))]
154             tracing::trace!(
155                 target: "runtime::resource::state_update",
156                 arrived = 1,
157                 arrived.op = "add",
158             );
159             #[cfg(all(tokio_unstable, feature = "tracing"))]
160             tracing::trace!(
161                 target: "runtime::resource::async_op::state_update",
162                 arrived = true,
163             );
164             if state.arrived == self.n {
165                 #[cfg(all(tokio_unstable, feature = "tracing"))]
166                 tracing::trace!(
167                     target: "runtime::resource::async_op::state_update",
168                     is_leader = true,
169                 );
170                 // we are the leader for this generation
171                 // wake everyone, increment the generation, and return
172                 state
173                     .waker
174                     .send(state.generation)
175                     .expect("there is at least one receiver");
176                 state.arrived = 0;
177                 state.generation += 1;
178                 return BarrierWaitResult(true);
179             }
180 
181             generation
182         };
183 
184         // we're going to have to wait for the last of the generation to arrive
185         let mut wait = self.wait.clone();
186 
187         loop {
188             let _ = wait.changed().await;
189 
190             // note that the first time through the loop, this _will_ yield a generation
191             // immediately, since we cloned a receiver that has never seen any values.
192             if *wait.borrow() >= generation {
193                 break;
194             }
195         }
196 
197         BarrierWaitResult(false)
198     }
199 }
200 
/// The value handed back by [`Barrier::wait`] once every participating task has arrived.
///
/// Call [`BarrierWaitResult::is_leader`] to learn whether this particular call was the one
/// selected as the leader of its rendezvous.
#[derive(Clone, Debug)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Reports whether the `wait` call that produced this result is the "leader task".
    ///
    /// Within one rendezvous, exactly one task observes `true` here; every other task
    /// observes `false`.
    pub fn is_leader(&self) -> bool {
        let Self(leader) = self;
        *leader
    }
}
214