1 use crate::loom::sync::Mutex; 2 use crate::sync::watch; 3 #[cfg(all(tokio_unstable, feature = "tracing"))] 4 use crate::util::trace; 5 6 /// A barrier enables multiple tasks to synchronize the beginning of some computation. 7 /// 8 /// ``` 9 /// # #[tokio::main] 10 /// # async fn main() { 11 /// use tokio::sync::Barrier; 12 /// use std::sync::Arc; 13 /// 14 /// let mut handles = Vec::with_capacity(10); 15 /// let barrier = Arc::new(Barrier::new(10)); 16 /// for _ in 0..10 { 17 /// let c = barrier.clone(); 18 /// // The same messages will be printed together. 19 /// // You will NOT see any interleaving. 20 /// handles.push(tokio::spawn(async move { 21 /// println!("before wait"); 22 /// let wait_result = c.wait().await; 23 /// println!("after wait"); 24 /// wait_result 25 /// })); 26 /// } 27 /// 28 /// // Will not resolve until all "after wait" messages have been printed 29 /// let mut num_leaders = 0; 30 /// for handle in handles { 31 /// let wait_result = handle.await.unwrap(); 32 /// if wait_result.is_leader() { 33 /// num_leaders += 1; 34 /// } 35 /// } 36 /// 37 /// // Exactly one barrier will resolve as the "leader" 38 /// assert_eq!(num_leaders, 1); 39 /// # } 40 /// ``` 41 #[derive(Debug)] 42 pub struct Barrier { 43 state: Mutex<BarrierState>, 44 wait: watch::Receiver<usize>, 45 n: usize, 46 #[cfg(all(tokio_unstable, feature = "tracing"))] 47 resource_span: tracing::Span, 48 } 49 50 #[derive(Debug)] 51 struct BarrierState { 52 waker: watch::Sender<usize>, 53 arrived: usize, 54 generation: usize, 55 } 56 57 impl Barrier { 58 /// Creates a new barrier that can block a given number of tasks. 59 /// 60 /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all 61 /// tasks at once when the `n`th task calls `wait`. 62 #[track_caller] new(mut n: usize) -> Barrier63 pub fn new(mut n: usize) -> Barrier { 64 let (waker, wait) = crate::sync::watch::channel(0); 65 66 if n == 0 { 67 // if n is 0, it's not clear what behavior the user wants. 68 // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every 69 // .wait() immediately unblocks, so we adopt that here as well. 70 n = 1; 71 } 72 73 #[cfg(all(tokio_unstable, feature = "tracing"))] 74 let resource_span = { 75 let location = std::panic::Location::caller(); 76 let resource_span = tracing::trace_span!( 77 parent: None, 78 "runtime.resource", 79 concrete_type = "Barrier", 80 kind = "Sync", 81 loc.file = location.file(), 82 loc.line = location.line(), 83 loc.col = location.column(), 84 ); 85 86 resource_span.in_scope(|| { 87 tracing::trace!( 88 target: "runtime::resource::state_update", 89 size = n, 90 ); 91 92 tracing::trace!( 93 target: "runtime::resource::state_update", 94 arrived = 0, 95 ) 96 }); 97 resource_span 98 }; 99 100 Barrier { 101 state: Mutex::new(BarrierState { 102 waker, 103 arrived: 0, 104 generation: 1, 105 }), 106 n, 107 wait, 108 #[cfg(all(tokio_unstable, feature = "tracing"))] 109 resource_span, 110 } 111 } 112 113 /// Does not resolve until all tasks have rendezvoused here. 114 /// 115 /// Barriers are re-usable after all tasks have rendezvoused once, and can 116 /// be used continuously. 117 /// 118 /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from 119 /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks 120 /// will receive a result that will return `false` from `is_leader`. 121 /// 122 /// # Cancel safety 123 /// 124 /// This method is not cancel safe. wait(&self) -> BarrierWaitResult125 pub async fn wait(&self) -> BarrierWaitResult { 126 #[cfg(all(tokio_unstable, feature = "tracing"))] 127 return trace::async_op( 128 || self.wait_internal(), 129 self.resource_span.clone(), 130 "Barrier::wait", 131 "poll", 132 false, 133 ) 134 .await; 135 136 #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] 137 return self.wait_internal().await; 138 } wait_internal(&self) -> BarrierWaitResult139 async fn wait_internal(&self) -> BarrierWaitResult { 140 crate::trace::async_trace_leaf().await; 141 142 // NOTE: we are taking a _synchronous_ lock here. 143 // It is okay to do so because the critical section is fast and never yields, so it cannot 144 // deadlock even if another future is concurrently holding the lock. 145 // It is _desirable_ to do so as synchronous Mutexes are, at least in theory, faster than 146 // the asynchronous counter-parts, so we should use them where possible [citation needed]. 147 // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across 148 // a yield point, and thus marks the returned future as !Send. 149 let generation = { 150 let mut state = self.state.lock(); 151 let generation = state.generation; 152 state.arrived += 1; 153 #[cfg(all(tokio_unstable, feature = "tracing"))] 154 tracing::trace!( 155 target: "runtime::resource::state_update", 156 arrived = 1, 157 arrived.op = "add", 158 ); 159 #[cfg(all(tokio_unstable, feature = "tracing"))] 160 tracing::trace!( 161 target: "runtime::resource::async_op::state_update", 162 arrived = true, 163 ); 164 if state.arrived == self.n { 165 #[cfg(all(tokio_unstable, feature = "tracing"))] 166 tracing::trace!( 167 target: "runtime::resource::async_op::state_update", 168 is_leader = true, 169 ); 170 // we are the leader for this generation 171 // wake everyone, increment the generation, and return 172 state 173 .waker 174 .send(state.generation) 175 .expect("there is at least one receiver"); 176 state.arrived = 0; 177 state.generation += 1; 178 return BarrierWaitResult(true); 179 } 180 181 generation 182 }; 183 184 // we're going to have to wait for the last of the generation to arrive 185 let mut wait = self.wait.clone(); 186 187 loop { 188 let _ = wait.changed().await; 189 190 // note that the first time through the loop, this _will_ yield a generation 191 // immediately, since we cloned a receiver that has never seen any values. 192 if *wait.borrow() >= generation { 193 break; 194 } 195 } 196 197 BarrierWaitResult(false) 198 } 199 } 200 201 /// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused. 202 #[derive(Debug, Clone)] 203 pub struct BarrierWaitResult(bool); 204 205 impl BarrierWaitResult { 206 /// Returns `true` if this task from wait is the "leader task". 207 /// 208 /// Only one task will have `true` returned from their result, all other tasks will have 209 /// `false` returned. is_leader(&self) -> bool210 pub fn is_leader(&self) -> bool { 211 self.0 212 } 213 } 214