#pragma once

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <new>
#include <thread>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/thread_name.h>
#include <c10/util/irange.h>
#include <c10/util/Logging.h>

#if defined(_MSC_VER)
#include <intrin.h> // __nop
#include <malloc.h> // _aligned_malloc / _aligned_free
#elif defined(__ANDROID__)
#include <malloc.h> // memalign
#endif

namespace caffe2 {

// Uses code derived from gemmlowp,
// https://github.com/google/gemmlowp/blob/6c91e1ed0c2eff1182d804310b92911fe9c18019/internal/multi_thread_gemm.h
// Changes:
// - Allocation-free Execute().
// - Use RAII where possible.
// - Run the first task on the main thread (since that is the largest task).
// - Removed the custom allocator.
// - Removed some #ifdefs.
// - Cache-line align Worker.
// - Use std::atomic instead of volatile and custom barriers.
// - Use std::mutex/std::condition_variable instead of raw pthreads.

constexpr size_t kGEMMLOWPCacheLineSize = 64;

template <typename T>
struct AllocAligned {
  // Allocate a T aligned to a kGEMMLOWPCacheLineSize-byte boundary.
  template <typename... Args>
  static T* alloc(Args&&... args) {
    void* p = nullptr;

#if defined(__ANDROID__)
    p = memalign(kGEMMLOWPCacheLineSize, sizeof(T));
#elif defined(_MSC_VER)
    p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize);
#else
    // posix_memalign leaves the output pointer unspecified on failure,
    // so reset it to nullptr in that case.
    if (posix_memalign(&p, kGEMMLOWPCacheLineSize, sizeof(T)) != 0) {
      p = nullptr;
    }
#endif

    if (p) {
      // Construct the T in the aligned storage via placement new.
      return new (p) T(std::forward<Args>(args)...);
    }

    return nullptr;
  }

  // Free a T previously allocated via AllocAligned<T>::alloc().
  static void release(T* p) {
    if (p) {
      p->~T();
#if defined(_MSC_VER)
      _aligned_free((void*)p);
#else
      free((void*)p);
#endif
    }
  }
};

// Deleter object for a unique_ptr to an aligned object.
template <typename T>
struct AlignedDeleter {
  void operator()(T* p) const { AllocAligned<T>::release(p); }
};

// make_unique-style factory that guarantees alignment.
template <typename T>
struct MakeAligned {
  template <typename... Args>
  static std::unique_ptr<T, AlignedDeleter<T>> make(Args&&... args) {
    return std::unique_ptr<T, AlignedDeleter<T>>(
        AllocAligned<T>::alloc(std::forward<Args>(args)...));
  }
};
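
// Usage sketch for the helpers above (Widget is a hypothetical type, not
// part of this file):
//
//   struct alignas(kGEMMLOWPCacheLineSize) Widget { int x; };
//   std::unique_ptr<Widget, AlignedDeleter<Widget>> w =
//       MakeAligned<Widget>::make();
//   w->x = 42;
//   // AlignedDeleter runs ~Widget() and the matching aligned free
//   // automatically when w goes out of scope.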

const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;

#if defined(_MSC_VER)
#define GEMMLOWP_NOP __nop();
#else
#define GEMMLOWP_NOP "nop\n"
#endif

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)

// Executes 64 NOPs (the name is historical) and returns the number
// executed, so callers can count how much busy-waiting they have done.
inline int Do256NOPs() {
#if defined(_MSC_VER)
  GEMMLOWP_NOP64;
#else
  asm volatile(GEMMLOWP_NOP64);
#endif
  return 64;
}

#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP

// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that the
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting on the given condition variable,
// guarded by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting helps ensure that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
template <typename T>
T WaitForVariableChange(std::atomic<T>* var,
                        T initial_value,
                        std::condition_variable* cond,
                        std::mutex* mutex) {
  // If we are on a platform that supports it, spin for some time.
  {
    int nops = 0;
    // First, the trivial case where the variable has already changed value.
    T new_value = var->load(std::memory_order_relaxed);
    if (new_value != initial_value) {
      std::atomic_thread_fence(std::memory_order_acquire);
      return new_value;
    }
    // Then try busy-waiting.
    while (nops < kMaxBusyWaitNOPs) {
      nops += Do256NOPs();
      new_value = var->load(std::memory_order_relaxed);
      if (new_value != initial_value) {
        std::atomic_thread_fence(std::memory_order_acquire);
        return new_value;
      }
    }
  }

  // Finally, do real passive waiting.
  {
    std::unique_lock<std::mutex> g(*mutex);
    T new_value = var->load(std::memory_order_relaxed);
    // The predicate handles spurious wakeups.
    cond->wait(g, [&]() {
      new_value = var->load(std::memory_order_relaxed);
      return new_value != initial_value;
    });
    TORCH_DCHECK_NE(static_cast<size_t>(new_value), static_cast<size_t>(initial_value));
    return new_value;
  }
}
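
// Usage sketch (flag, m, and cv are hypothetical, not part of this file).
// The notifier stores the new value while holding the mutex, so the wakeup
// cannot race with the waiter's check-then-block step:
//
//   std::atomic<int> flag{0};
//   std::mutex m;
//   std::condition_variable cv;
//
//   // Waiter thread: spins briefly, then blocks; returns the changed value.
//   int seen = WaitForVariableChange(&flag, /*initial_value=*/0, &cv, &m);
//
//   // Notifier thread:
//   {
//     std::lock_guard<std::mutex> g(m);
//     flag.store(1, std::memory_order_release);
//   }
//   cv.notify_one();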

// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
 public:
  // Sets/resets the counter; initial_count is the number of
  // decrementing events that the Wait() call will be waiting for.
  void Reset(std::size_t initial_count) {
    std::lock_guard<std::mutex> g(mutex_);
    TORCH_DCHECK_EQ(count_, 0);
    count_ = initial_count;
  }

  // Decrements the counter; if the counter hits zero, signals
  // the thread that was waiting for that, and returns true.
  // Otherwise (if the decremented count is still nonzero),
  // returns false.
  bool DecrementCount() {
    const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1;
    TORCH_DCHECK_GE(count_value, 0);
    if (count_value == 0) {
      // Take the lock so the notification cannot race with Wait()'s
      // check-then-block sequence.
      std::lock_guard<std::mutex> g(mutex_);
      cond_.notify_one();
    }
    bool retval = count_value == 0;
    return retval;
  }

  // Waits for the N other threads (N having been set by Reset())
  // to hit the BlockingCounter.
  void Wait() {
    while (size_t count_value = count_.load(std::memory_order_relaxed)) {
      WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
    }
  }

 private:
  std::condition_variable cond_;
  std::mutex mutex_;
  std::atomic<std::size_t> count_{0};
};
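
// Usage sketch (hypothetical, not part of this file): the master arms the
// counter with the number of outstanding events, each worker decrements it
// exactly once, and Wait() returns only after the count reaches zero.
//
//   BlockingCounter done;
//   done.Reset(2);
//   std::thread t1([&] { /* ...work... */ done.DecrementCount(); });
//   std::thread t2([&] { /* ...work... */ done.DecrementCount(); });
//   done.Wait(); // returns once both threads have decremented
//   t1.join();
//   t2.join();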

// A workload for a worker.
struct Task {
  Task() = default;
  virtual ~Task() = default;
  virtual void Run() = 0;
};
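
// A minimal concrete Task (hypothetical example, not part of this file;
// would additionally require <functional>):
//
//   struct LambdaTask final : public Task {
//     explicit LambdaTask(std::function<void()> fn) : fn_(std::move(fn)) {}
//     void Run() override { fn_(); }
//     std::function<void()> fn_;
//   };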

// A worker thread.
class alignas(kGEMMLOWPCacheLineSize) Worker {
 public:
  enum class State : uint8_t {
    ThreadStartup, // The initial state before the thread main loop runs.
    Ready, // Is not working, has not yet received new work to do.
    HasWork, // Has work to do.
    ExitAsSoonAsPossible // Should exit at earliest convenience.
  };

  explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
      : task_(nullptr),
        state_(State::ThreadStartup),
        counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
    thread_ = std::make_unique<std::thread>([this]() {
      c10::setThreadName("pt_thread_pool");
      this->ThreadFunc();
    });
  }

  ~Worker() {
    ChangeState(State::ExitAsSoonAsPossible);
    thread_->join();
  }

  // Changes State; may be called from either the worker thread
  // or the master thread; however, not all state transitions are legal,
  // which is guarded by assertions.
  void ChangeState(State new_state) {
    std::lock_guard<std::mutex> g(state_mutex_);
    DCHECK(new_state != state_.load(std::memory_order_relaxed));
    switch (state_.load(std::memory_order_relaxed)) {
      case State::ThreadStartup:
        DCHECK(new_state == State::Ready);
        break;
      case State::Ready:
        DCHECK(new_state == State::HasWork || new_state == State::ExitAsSoonAsPossible);
        break;
      case State::HasWork:
        DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible);
        break;
      default:
        abort();
    }
    state_.store(new_state, std::memory_order_relaxed);
    state_cond_.notify_one();
    if (new_state == State::Ready) {
      counter_to_decrement_when_ready_->DecrementCount();
    }
  }

  // Thread entry point.
  void ThreadFunc() {
    c10::setThreadName("CaffeWorkersPool");
    ChangeState(State::Ready);

    // Thread main loop.
    while (true) {
      // Get a state to act on. In the 'Ready' state, we have nothing to
      // do but wait until we switch to another state.
      State state_to_act_upon =
          WaitForVariableChange(&state_, State::Ready, &state_cond_, &state_mutex_);

      // We now have a state to act on, so act.
      switch (state_to_act_upon) {
        case State::HasWork:
          // Got work to do! So do it, and then revert to the 'Ready' state.
          DCHECK(task_.load());
          (*task_).Run();
          task_ = nullptr;
          ChangeState(State::Ready);
          break;
        case State::ExitAsSoonAsPossible:
          return;
        default:
          abort();
      }
    }
  }

  static void* ThreadFunc(void* arg) {
    static_cast<Worker*>(arg)->ThreadFunc();
    return nullptr;
  }

  // Called by the master thread to give this worker work to do.
  // It is only legal to call this if the worker is currently in the
  // Ready state.
  void StartWork(Task* task) {
    DCHECK(!task_.load());
    task_ = task;
    DCHECK(state_.load(std::memory_order_acquire) == State::Ready);
    ChangeState(State::HasWork);
  }

 private:
  // The underlying thread.
  std::unique_ptr<std::thread> thread_;

  // The task to be worked on.
  std::atomic<Task*> task_;

  // The condition variable and mutex guarding state changes.
  std::condition_variable state_cond_;
  std::mutex state_mutex_;

  // The state enum tells if we're currently working, waiting for work, etc.
  std::atomic<State> state_;

  // Pointer to the master thread's BlockingCounter object, used to notify
  // the master thread when this worker switches to the 'Ready' state.
  BlockingCounter* const counter_to_decrement_when_ready_;
};

class WorkersPool {
 public:
  WorkersPool() = default;

  void Execute(const std::vector<std::shared_ptr<Task>>& tasks) {
    CAFFE_ENFORCE_GE(tasks.size(), 1);
    // One of the tasks will be run on the current thread.
    int workers_count = tasks.size() - 1;
    CreateWorkers(workers_count);
    TORCH_DCHECK_LE(workers_count, (int)workers_.size());
    counter_to_decrement_when_ready_.Reset(workers_count);
    for (const auto task : c10::irange(1, tasks.size())) {
      workers_[task - 1]->StartWork(tasks[task].get());
    }
    // Execute the first task immediately on the current thread.
    auto& task = tasks.front();
    task->Run();
    // Wait for the workers submitted above to finish.
    counter_to_decrement_when_ready_.Wait();
  }

 private:
  // Ensures that the pool has at least the given count of workers.
  // If any new worker has to be created, this function waits for it to
  // be ready.
  void CreateWorkers(std::size_t workers_count) {
    if (workers_.size() >= workers_count) {
      return;
    }
    counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
    while (workers_.size() < workers_count) {
      workers_.push_back(MakeAligned<Worker>::make(&counter_to_decrement_when_ready_));
    }
    counter_to_decrement_when_ready_.Wait();
  }

  C10_DISABLE_COPY_AND_ASSIGN(WorkersPool);
  std::vector<std::unique_ptr<Worker, AlignedDeleter<Worker>>> workers_;
  // The BlockingCounter used to wait for the workers.
  BlockingCounter counter_to_decrement_when_ready_;
};
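
// Usage sketch (reusing the hypothetical LambdaTask from the Task example
// above; not part of this file). Execute() runs tasks[0] inline on the
// calling thread and the rest on pool workers, returning when all finish:
//
//   WorkersPool pool;
//   std::vector<std::shared_ptr<Task>> tasks;
//   tasks.push_back(std::make_shared<LambdaTask>([] { /* shard 0 */ }));
//   tasks.push_back(std::make_shared<LambdaTask>([] { /* shard 1 */ }));
//   pool.Execute(tasks); // shard 0 runs here; shard 1 runs on a worker
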
} // namespace caffe2