// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
#define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_

namespace Eigen {

// EventCount allows waiting for arbitrary predicates in non-blocking
// algorithms. Think of a condition variable, but the wait predicate does not
// need to be protected by a mutex. Usage:
// Waiting thread does:
//
//   if (predicate)
//     return act();
//   EventCount::Waiter& w = waiters[my_index];
//   ec.Prewait();
//   if (predicate) {
//     ec.CancelWait();
//     return act();
//   }
//   ec.CommitWait(&w);
//
// Notifying thread does:
//
//   predicate = true;
//   ec.Notify(true);
//
// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
// cheap, but they are executed only if the preceding predicate check has
// failed.
//
// Algorithm outline:
// There are two main variables: predicate (managed by the user) and state_.
// The operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// The waiting thread sets state_ and then checks the predicate; the notifying
// thread sets the predicate and then checks state_. Due to the seq_cst fences
// between these operations, it is guaranteed that either the waiter will see
// the predicate change and won't block, or the notifying thread will see the
// state_ change and will unblock the waiter, or both. It cannot happen that
// both threads miss each other's changes, which would lead to deadlock.
class EventCount {
 public:
  class Waiter;

  EventCount(MaxSizeVector<Waiter>& waiters)
      : state_(kStackMask), waiters_(waiters) {
    eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
  }

  ~EventCount() {
    // Ensure there are no waiters.
    eigen_plain_assert(state_.load() == kStackMask);
  }

  // Prewait prepares for waiting.
  // After calling Prewait, the thread must re-check the wait predicate
  // and then call either CancelWait or CommitWait.
  void Prewait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state);
      uint64_t newstate = state + kWaiterInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_seq_cst))
        return;
    }
  }

  // CommitWait commits waiting after Prewait.
  void CommitWait(Waiter* w) {
    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
    w->state = Waiter::kNotSignaled;
    const uint64_t me = (w - &waiters_[0]) | w->epoch;
    uint64_t state = state_.load(std::memory_order_seq_cst);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate;
      if ((state & kSignalMask) != 0) {
        // Consume the signal and return immediately.
        newstate = state - kWaiterInc - kSignalInc;
      } else {
        // Remove this thread from the pre-wait counter and add it to the
        // waiter stack.
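        // An illustrative trace (values assumed for the example, given
        // kWaiterBits == 14 below): a waiter at index 3 whose w->epoch holds
        // 5 * kEpochInc computes me == 3 | (5 * kEpochInc). The push below
        // decrements the pre-wait count, installs me as the new stack head
        // (index plus epoch), and saves the previous head and epoch into
        // w->next, so that Notify can later pop the stack in LIFO order.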
        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
        w->next.store(state & (kStackMask | kEpochMask),
                      std::memory_order_relaxed);
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if ((state & kSignalMask) == 0) {
          w->epoch += kEpochInc;
          Park(w);
        }
        return;
      }
    }
  }

  // CancelWait cancels the effects of the previous Prewait call.
  void CancelWait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate = state - kWaiterInc;
      // We don't know whether the thread was also notified or not,
      // so we should not consume a signal unconditionally.
      // Only if the number of waiters is equal to the number of signals
      // do we know that the thread was notified, and then we must take
      // away the signal.
      if (((state & kWaiterMask) >> kWaiterShift) ==
          ((state & kSignalMask) >> kSignalShift))
        newstate -= kSignalInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel))
        return;
    }
  }

  // Notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void Notify(bool notifyAll) {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = state_.load(std::memory_order_acquire);
    for (;;) {
      CheckState(state);
      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
      // Easy case: no waiters.
      if ((state & kStackMask) == kStackMask && waiters == signals) return;
      uint64_t newstate;
      if (notifyAll) {
        // Empty the wait stack and set the signal count to the number of
        // pre-wait threads.
        newstate =
            (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
      } else if (signals < waiters) {
        // There is a thread in pre-wait state, unblock it.
        newstate = state + kSignalInc;
      } else {
        // Pop a waiter from the list and unpark it.
        Waiter* w = &waiters_[state & kStackMask];
        uint64_t next = w->next.load(std::memory_order_relaxed);
        newstate = (state & (kWaiterMask | kSignalMask)) | next;
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if (!notifyAll && (signals < waiters))
          return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
    }
  }

  class Waiter {
    friend class EventCount;
    // Align to 128 byte boundary to prevent false sharing with other Waiter
    // objects in the same vector.
    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
    std::mutex mu;
    std::condition_variable cv;
    uint64_t epoch = 0;
    unsigned state = kNotSignaled;
    enum {
      kNotSignaled,
      kWaiting,
      kSignaled,
    };
  };

 private:
  // State_ layout:
  // - the low kWaiterBits bits hold a stack of waiters that committed to
  //   waiting (indexes into the waiters_ array are used as stack elements,
  //   kStackMask means an empty stack).
  // - the next kWaiterBits bits hold the count of waiters in pre-wait state.
  // - the next kWaiterBits bits hold the count of pending signals.
  // - the remaining bits are an ABA counter for the stack
  //   (stored in the Waiter node and incremented on push).
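  //
  // A sketch of the resulting layout, derived from the constants below
  // (kWaiterBits == 14, so kEpochShift == 42 and kEpochBits == 22);
  // illustrative, not normative:
  //
  //   bits 63..42  epoch (ABA counter)
  //   bits 41..28  pending signal count
  //   bits 27..14  pre-wait waiter count
  //   bits 13..0   stack head index (kStackMask, 0x3fff, means empty)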
  static const uint64_t kWaiterBits = 14;
  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
  static const uint64_t kWaiterShift = kWaiterBits;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
  static const uint64_t kSignalShift = 2 * kWaiterBits;
  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
                                      << kSignalShift;
  static const uint64_t kSignalInc = 1ull << kSignalShift;
  static const uint64_t kEpochShift = 3 * kWaiterBits;
  static const uint64_t kEpochBits = 64 - kEpochShift;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> state_;
  MaxSizeVector<Waiter>& waiters_;

  static void CheckState(uint64_t state, bool waiter = false) {
    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
    eigen_plain_assert(waiters >= signals);
    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
    eigen_plain_assert(!waiter || waiters > 0);
    (void)waiters;
    (void)signals;
  }

  void Park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  void Unpark(Waiter* w) {
    for (Waiter* next; w; w = next) {
      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
      next = wnext == kStackMask ? nullptr : &waiters_[wnext];
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  EventCount(const EventCount&) = delete;
  void operator=(const EventCount&) = delete;
};

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
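
// Illustrative usage sketch (not part of the interface above; my_index,
// kNumThreads, and the work_available flag are assumptions made for this
// example, following the protocol documented at the top of the header):
//
//   MaxSizeVector<EventCount::Waiter> waiters(kNumThreads);
//   waiters.resize(kNumThreads);
//   EventCount ec(waiters);
//   std::atomic<bool> work_available(false);
//
//   // Waiting side (worker thread my_index):
//   while (!work_available.load()) {
//     ec.Prewait();
//     if (work_available.load()) {
//       ec.CancelWait();
//       break;
//     }
//     ec.CommitWait(&waiters[my_index]);
//   }
//
//   // Notifying side (producer thread); set the predicate, then Notify:
//   work_available.store(true);
//   ec.Notify(false);  // wake one waiter; Notify(true) would wake all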