#pragma once

#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
#include "c10/util/thread_name.h"
#include <c10/util/irange.h>
#include <c10/util/Logging.h>

#if defined(_MSC_VER)
#include <intrin.h>
#endif

namespace caffe2 {

// Uses code derived from gemmlowp,
// https://github.com/google/gemmlowp/blob/6c91e1ed0c2eff1182d804310b92911fe9c18019/internal/multi_thread_gemm.h
// Changes:
// - Allocation-free Execute().
// - Use RAII where possible.
// - Run the first task on the main thread (since that is the largest task).
// - Removed the custom allocator.
// - Removed some ifdefs.
// - Cache-line align Worker.
// - Use std::atomic instead of volatile and custom barriers.
// - Use std::mutex/std::condition_variable instead of raw pthreads.

constexpr size_t kGEMMLOWPCacheLineSize = 64;

template <typename T>
struct AllocAligned {
  // Allocate a T aligned to a kGEMMLOWPCacheLineSize-byte boundary.
  template <typename... Args>
  static T* alloc(Args&&... args) {
    void* p = nullptr;

#if defined(__ANDROID__)
    p = memalign(kGEMMLOWPCacheLineSize, sizeof(T));
#elif defined(_MSC_VER)
    p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize);
#else
    auto res = posix_memalign((void**)&p, kGEMMLOWPCacheLineSize, sizeof(T));
    (void)res;
#endif

    if (p) {
      return new (p) T(std::forward<Args>(args)...);
    }

    return nullptr;
  }

  // Free a T previously allocated via AllocAligned<T>::alloc()
  static void release(T* p) {
    if (p) {
      p->~T();
#if defined(_MSC_VER)
      _aligned_free((void*)p);
#else
      free((void*)p);
#endif
    }
  }
};

// Deleter object for unique_ptr for an aligned object
template <typename T>
struct AlignedDeleter {
  void operator()(T* p) const { AllocAligned<T>::release(p); }
};

// make_unique that guarantees alignment
template <typename T>
struct MakeAligned {
  template <typename... Args>
  static std::unique_ptr<T, AlignedDeleter<T>> make(Args&&... args) {
    return std::unique_ptr<T, AlignedDeleter<T>>(
        AllocAligned<T>::alloc(std::forward<Args>(args)...));
  }
};
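
// A quick illustration of how the helpers above compose (the `Foo` type is
// hypothetical, not part of this file): MakeAligned allocates the object on a
// cache-line boundary and AlignedDeleter destroys and frees it.
//
//   struct Foo { explicit Foo(int v) : value(v) {} int value; };
//   auto foo = MakeAligned<Foo>::make(42);  // placement-new on 64-byte-aligned storage
//   // reinterpret_cast<uintptr_t>(foo.get()) % kGEMMLOWPCacheLineSize == 0
//   foo.reset();                            // runs ~Foo(), then the aligned free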

const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;

#if defined(_MSC_VER)
#define GEMMLOWP_NOP __nop();
#else
#define GEMMLOWP_NOP "nop\n"
#endif

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)

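// Note: despite the name (kept from the gemmlowp original), each call issues
// 64 NOP instructions (GEMMLOWP_NOP64) and returns that count.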
inline int Do256NOPs() {
#if defined(_MSC_VER)
  GEMMLOWP_NOP64;
#else
  asm volatile(GEMMLOWP_NOP64);
#endif
  return 64;
}

#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP

// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that that
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting for the given condvar, guarded
// by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting helps ensure that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
template <typename T>
T WaitForVariableChange(std::atomic<T>* var,
                        T initial_value,
                        std::condition_variable* cond,
                        std::mutex* mutex) {
  // If we are on a platform that supports it, spin for some time.
  {
    int nops = 0;
    // First, trivial case where the variable already changed value.
    T new_value = var->load(std::memory_order_relaxed);
    if (new_value != initial_value) {
      std::atomic_thread_fence(std::memory_order_acquire);
      return new_value;
    }
    // Then try busy-waiting.
    while (nops < kMaxBusyWaitNOPs) {
      nops += Do256NOPs();
      new_value = var->load(std::memory_order_relaxed);
      if (new_value != initial_value) {
        std::atomic_thread_fence(std::memory_order_acquire);
        return new_value;
      }
    }
  }

  // Finally, do real passive waiting.
  {
    std::unique_lock<std::mutex> g(*mutex);
    T new_value = var->load(std::memory_order_relaxed);
    // Handle spurious wakeups.
    cond->wait(g, [&]() {
      new_value = var->load(std::memory_order_relaxed);
      return new_value != initial_value;
    });
    TORCH_DCHECK_NE(static_cast<size_t>(new_value), static_cast<size_t>(initial_value));
    return new_value;
  }
}
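
// A minimal sketch of the intended pairing (the names `flag`, `flag_mutex`,
// and `flag_cond` are illustrative, not part of this file). The waiter calls
// WaitForVariableChange; the notifier updates the atomic and signals the
// condition variable while holding the mutex, so the wakeup cannot be missed.
//
//   std::atomic<int> flag{0};
//   std::mutex flag_mutex;
//   std::condition_variable flag_cond;
//
//   // Waiter thread: returns once flag has moved away from 0.
//   int observed = WaitForVariableChange(&flag, 0, &flag_cond, &flag_mutex);
//
//   // Notifier thread:
//   {
//     std::lock_guard<std::mutex> g(flag_mutex);
//     flag.store(1, std::memory_order_release);
//     flag_cond.notify_one();
//   }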

// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
 public:
  // Sets/resets the counter; initial_count is the number of
  // decrementing events that the Wait() call will be waiting for.
  void Reset(std::size_t initial_count) {
    std::lock_guard<std::mutex> g(mutex_);
    TORCH_DCHECK_EQ(count_, 0);
    count_ = initial_count;
  }

  // Decrements the counter; if the counter hits zero, signals
  // the thread that was waiting for that, and returns true.
  // Otherwise (if the decremented count is still nonzero),
  // returns false.
  bool DecrementCount() {
    const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1;
    TORCH_DCHECK_GE(count_value, 0);
    if (count_value == 0) {
      std::lock_guard<std::mutex> g(mutex_);
      cond_.notify_one();
    }
    bool retval = count_value == 0;
    return retval;
  }

  // Waits for the N other threads (N having been set by Reset())
  // to hit the BlockingCounter.
  void Wait() {
    while (size_t count_value = count_.load(std::memory_order_relaxed)) {
      WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
    }
  }

 private:
  std::condition_variable cond_;
  std::mutex mutex_;
  std::atomic<std::size_t> count_{0};
};
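
// A minimal usage sketch (the `counter` name and the thread bodies are
// illustrative only): the master arms the counter for N events, N workers
// each call DecrementCount() once, and the master's Wait() returns only
// after the last decrement.
//
//   BlockingCounter counter;
//   counter.Reset(2);
//   std::thread t1([&] { /* ... work ... */ counter.DecrementCount(); });
//   std::thread t2([&] { /* ... work ... */ counter.DecrementCount(); });
//   counter.Wait();  // returns once both workers have decremented
//   t1.join();
//   t2.join();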

// A workload for a worker.
struct Task {
  Task() = default;
  virtual ~Task() = default;
  virtual void Run() = 0;
};

// A worker thread.
class alignas(kGEMMLOWPCacheLineSize) Worker {
 public:
  enum class State : uint8_t {
    ThreadStartup, // The initial state before the thread main loop runs.
    Ready, // Is not working, has not yet received new work to do.
    HasWork, // Has work to do.
    ExitAsSoonAsPossible // Should exit at earliest convenience.
  };
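
  // Legal state transitions (enforced by the assertions in ChangeState):
  //   ThreadStartup -> Ready
  //   Ready         -> HasWork or ExitAsSoonAsPossible
  //   HasWork       -> Ready   or ExitAsSoonAsPossible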

  explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
      : task_(nullptr),
        state_(State::ThreadStartup),
        counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
    thread_ = std::make_unique<std::thread>([this]() {
      c10::setThreadName("pt_thread_pool");
      this->ThreadFunc();
    });
  }

  ~Worker() {
    ChangeState(State::ExitAsSoonAsPossible);
    thread_->join();
  }

  // Changes State; may be called from either the worker thread
  // or the master thread. Not all state transitions are legal;
  // the allowed ones are enforced by assertions.
  void ChangeState(State new_state) {
    std::lock_guard<std::mutex> g(state_mutex_);
    DCHECK(new_state != state_.load(std::memory_order_relaxed));
    switch (state_.load(std::memory_order_relaxed)) {
    case State::ThreadStartup:
      DCHECK(new_state == State::Ready);
      break;
    case State::Ready:
      DCHECK(new_state == State::HasWork || new_state == State::ExitAsSoonAsPossible);
      break;
    case State::HasWork:
      DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible);
      break;
    default:
      abort();
    }
    state_.store(new_state, std::memory_order_relaxed);
    state_cond_.notify_one();
    if (new_state == State::Ready) {
      counter_to_decrement_when_ready_->DecrementCount();
    }
  }

  // Thread entry point.
  void ThreadFunc() {
    c10::setThreadName("CaffeWorkersPool");
    ChangeState(State::Ready);

    // Thread main loop
    while (true) {
      // Get a state to act on.
      // In the 'Ready' state, we have nothing to do but to wait until
      // we switch to another state.
      State state_to_act_upon =
          WaitForVariableChange(&state_, State::Ready, &state_cond_, &state_mutex_);

      // We now have a state to act on, so act.
      switch (state_to_act_upon) {
      case State::HasWork:
        // Got work to do! So do it, and then revert to 'Ready' state.
        DCHECK(task_.load());
        (*task_).Run();
        task_ = nullptr;
        ChangeState(State::Ready);
        break;
      case State::ExitAsSoonAsPossible:
        return;
      default:
        abort();
      }
    }
  }

  static void* ThreadFunc(void* arg) {
    static_cast<Worker*>(arg)->ThreadFunc();
    return nullptr;
  }

  // Called by the master thread to give this worker work to do.
  // It is only legal to call this if the worker is currently in the
  // 'Ready' state (i.e. it has no pending task).
  void StartWork(Task* task) {
    DCHECK(!task_.load());
    task_ = task;
    DCHECK(state_.load(std::memory_order_acquire) == State::Ready);
    ChangeState(State::HasWork);
  }

 private:
  // The underlying thread.
  std::unique_ptr<std::thread> thread_;

  // The task to be worked on.
  std::atomic<Task*> task_;

  // The condition variable and mutex guarding state changes.
  std::condition_variable state_cond_;
  std::mutex state_mutex_;

  // The state enum tells if we're currently working, waiting for work, etc.
  std::atomic<State> state_;

  // Pointer to the master thread's BlockingCounter object, used to notify
  // the master thread when this worker switches to the 'Ready' state.
  BlockingCounter* const counter_to_decrement_when_ready_;
};

class WorkersPool {
 public:
  WorkersPool() = default;

  void Execute(const std::vector<std::shared_ptr<Task>>& tasks) {
    CAFFE_ENFORCE_GE(tasks.size(), 1);
    // One of the tasks will be run on the current thread.
    int workers_count = tasks.size() - 1;
    CreateWorkers(workers_count);
    TORCH_DCHECK_LE(workers_count, (int)workers_.size());
    counter_to_decrement_when_ready_.Reset(workers_count);
    for (const auto task : c10::irange(1, tasks.size())) {
      workers_[task - 1]->StartWork(tasks[task].get());
    }
    // Execute the remaining workload immediately on the current thread.
    auto& task = tasks.front();
    task->Run();
    // Wait for the workers submitted above to finish.
    counter_to_decrement_when_ready_.Wait();
  }

 private:
  // Ensures that the pool has at least the given count of workers.
  // If any new worker has to be created, this function waits for it to
  // be ready.
  void CreateWorkers(std::size_t workers_count) {
    if (workers_.size() >= workers_count) {
      return;
    }
    counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
    while (workers_.size() < workers_count) {
      workers_.push_back(MakeAligned<Worker>::make(&counter_to_decrement_when_ready_));
    }
    counter_to_decrement_when_ready_.Wait();
  }

  C10_DISABLE_COPY_AND_ASSIGN(WorkersPool);
  std::vector<std::unique_ptr<Worker, AlignedDeleter<Worker>>> workers_;
  // The BlockingCounter used to wait for the workers.
  BlockingCounter counter_to_decrement_when_ready_;
};
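
// A minimal usage sketch (the `AddTask` type and its fields are hypothetical,
// not part of this header): each workload is a Task subclass, the first task
// in the vector runs on the calling thread, and the rest are dispatched to
// pool workers. Execute() returns only after every task has finished.
//
//   struct AddTask : public Task {
//     AddTask(int a, int b, int* out) : a_(a), b_(b), out_(out) {}
//     void Run() override { *out_ = a_ + b_; }
//     int a_;
//     int b_;
//     int* out_;
//   };
//
//   WorkersPool pool;
//   int r0 = 0, r1 = 0;
//   std::vector<std::shared_ptr<Task>> tasks{
//       std::make_shared<AddTask>(1, 2, &r0),   // runs on the calling thread
//       std::make_shared<AddTask>(3, 4, &r1)};  // runs on a pool worker
//   pool.Execute(tasks);                        // blocks until r0 == 3 and r1 == 7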
} // namespace caffe2