xref: /aosp_15_r20/external/eigen/unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Dmitry Vyukov <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
11 #define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
12 
13 namespace Eigen {
14 
// EventCount allows waiting for arbitrary predicates in non-blocking
// algorithms. Think of a condition variable, but the wait predicate does not
// need to be protected by a mutex. Usage:
// Waiting thread does:
//
//   if (predicate)
//     return act();
//   EventCount::Waiter& w = waiters[my_index];
//   ec.Prewait();
//   if (predicate) {
//     ec.CancelWait();
//     return act();
//   }
//   ec.CommitWait(&w);
//
// Notifying thread does:
//
//   predicate = true;
//   ec.Notify(true);
//
// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
// cheap, but they are executed only if the preceding predicate check has
// failed.
//
// Algorithm outline:
// There are two main variables: predicate (managed by user) and state_.
// Operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// The waiting thread sets state_ and then checks the predicate; the notifying
// thread sets the predicate and then checks state_. Due to the seq_cst fences
// between these operations it is guaranteed that either the waiter will see
// the predicate change and won't block, or the notifying thread will see the
// state_ change and will unblock the waiter, or both. But it can't happen
// that both threads fail to see each other's changes, which would lead to
// deadlock.
class EventCount {
 public:
  class Waiter;

  // `waiters` supplies one slot per potentially-waiting thread and must
  // outlive this EventCount; slot indexes are packed into state_, so the
  // vector size must fit in kWaiterBits.
  EventCount(MaxSizeVector<Waiter>& waiters)
      : state_(kStackMask), waiters_(waiters) {
    eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
  }

  ~EventCount() {
    // Ensure there are no waiters.
    eigen_plain_assert(state_.load() == kStackMask);
  }

  // Prewait prepares for waiting.
  // After calling Prewait, the thread must re-check the wait predicate
  // and then call either CancelWait or CommitWait.
  void Prewait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state);
      // Register this thread in the pre-wait counter.
      uint64_t newstate = state + kWaiterInc;
      CheckState(newstate);
      // seq_cst CAS: forms the Dekker-style handshake with the seq_cst fence
      // in Notify() — either this waiter sees the predicate change on its
      // re-check, or the notifier sees the incremented pre-wait counter.
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_seq_cst))
        return;
    }
  }

  // CommitWait commits waiting after Prewait: either consumes a pending
  // signal and returns, or pushes `w` onto the waiter stack and blocks
  // until a Notify() unparks it.
  void CommitWait(Waiter* w) {
    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
    w->state = Waiter::kNotSignaled;
    // Encode this waiter's slot index together with its current ABA epoch.
    const uint64_t me = (w - &waiters_[0]) | w->epoch;
    uint64_t state = state_.load(std::memory_order_seq_cst);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate;
      if ((state & kSignalMask) != 0) {
        // Consume the signal and return immediately.
        newstate = state - kWaiterInc - kSignalInc;
      } else {
        // Remove this thread from pre-wait counter and add to the waiter stack.
        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
        // Link to the current stack top (index + its epoch bits).
        w->next.store(state & (kStackMask | kEpochMask),
                      std::memory_order_relaxed);
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if ((state & kSignalMask) == 0) {
          // Bump the epoch for the next push (ABA protection), then block.
          w->epoch += kEpochInc;
          Park(w);
        }
        return;
      }
    }
  }

  // CancelWait cancels effects of the previous Prewait call.
  void CancelWait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate = state - kWaiterInc;
      // We don't know if the thread was also notified or not,
      // so we should not consume a signal unconditionally.
      // Only if number of waiters is equal to number of signals,
      // we know that the thread was notified and we must take away the signal.
      if (((state & kWaiterMask) >> kWaiterShift) ==
          ((state & kSignalMask) >> kSignalShift))
        newstate -= kSignalInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel))
        return;
    }
  }

  // Notify wakes one (notifyAll == false) or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void Notify(bool notifyAll) {
    // seq_cst fence: pairs with the seq_cst CAS in Prewait() so that the
    // predicate write above and the state_ read below cannot both be
    // reordered past the waiter's counterpart operations.
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = state_.load(std::memory_order_acquire);
    for (;;) {
      CheckState(state);
      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
      // Easy case: no waiters.
      if ((state & kStackMask) == kStackMask && waiters == signals) return;
      uint64_t newstate;
      if (notifyAll) {
        // Empty wait stack and set signal to number of pre-wait threads.
        newstate =
            (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
      } else if (signals < waiters) {
        // There is a thread in pre-wait state, unblock it.
        newstate = state + kSignalInc;
      } else {
        // Pop a waiter from list and unpark it.
        Waiter* w = &waiters_[state & kStackMask];
        uint64_t next = w->next.load(std::memory_order_relaxed);
        newstate = (state & (kWaiterMask | kSignalMask)) | next;
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if (!notifyAll && (signals < waiters))
          return;  // unblocked pre-wait thread
        // Stack was empty: nothing is parked, so nothing to unpark.
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
        // For notify-one, terminate the popped node's next link so Unpark
        // wakes only this waiter; for notify-all, Unpark walks the whole
        // detached list.
        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
    }
  }

  // Per-thread waiter slot. Instances live in the caller-owned vector passed
  // to the EventCount constructor.
  class Waiter {
    friend class EventCount;
    // Align to 128 byte boundary to prevent false sharing with other Waiter
    // objects in the same vector.
    // `next` holds the stack successor's index (or kStackMask for end of
    // list) together with epoch bits, as stored by CommitWait/Notify.
    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
    std::mutex mu;               // guards `state` for Park/Unpark
    std::condition_variable cv;  // blocks in Park, woken by Unpark
    uint64_t epoch = 0;          // ABA counter, incremented on each stack push
    unsigned state = kNotSignaled;
    enum {
      kNotSignaled,  // set by CommitWait before pushing onto the stack
      kWaiting,      // set while blocked inside Park
      kSignaled,     // set by Unpark; Park's exit condition
    };
  };

 private:
  // State_ layout:
  // - low kWaiterBits is a stack of waiters committed wait
  //   (indexes in waiters_ array are used as stack elements,
  //   kStackMask means empty stack).
  // - next kWaiterBits is count of waiters in prewait state.
  // - next kWaiterBits is count of pending signals.
  // - remaining bits are ABA counter for the stack.
  //   (stored in Waiter node and incremented on push).
  static const uint64_t kWaiterBits = 14;
  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
  static const uint64_t kWaiterShift = kWaiterBits;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
  static const uint64_t kSignalShift = 2 * kWaiterBits;
  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
                                      << kSignalShift;
  static const uint64_t kSignalInc = 1ull << kSignalShift;
  static const uint64_t kEpochShift = 3 * kWaiterBits;
  static const uint64_t kEpochBits = 64 - kEpochShift;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> state_;
  MaxSizeVector<Waiter>& waiters_;  // caller-owned; indexed by stack entries

  // Asserts the invariants encoded in `state`; with waiter == true, at least
  // one registered waiter is also required. Pure sanity checking — has no
  // effect when eigen_plain_assert compiles out.
  static void CheckState(uint64_t state, bool waiter = false) {
    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
    eigen_plain_assert(waiters >= signals);
    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
    eigen_plain_assert(!waiter || waiters > 0);
    // Silence unused-variable warnings when asserts are disabled.
    (void)waiters;
    (void)signals;
  }

  // Blocks the calling thread on w->cv until Unpark marks w as kSignaled.
  // The while loop guards against spurious condition-variable wakeups.
  void Park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // Wakes w and every waiter reachable through the `next` links of the
  // stack segment that Notify() detached.
  void Unpark(Waiter* w) {
    for (Waiter* next; w; w = next) {
      // Capture the successor before marking w signaled.
      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
      next = wnext == kStackMask ? nullptr : &waiters_[wnext];
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  EventCount(const EventCount&) = delete;
  void operator=(const EventCount&) = delete;
};
246 
247 }  // namespace Eigen
248 
249 #endif  // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
250