xref: /aosp_15_r20/external/pytorch/caffe2/utils/threadpool/ThreadPool.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #ifndef CAFFE2_UTILS_THREADPOOL_H_
2*da0073e9SAndroid Build Coastguard Worker #define CAFFE2_UTILS_THREADPOOL_H_
3*da0073e9SAndroid Build Coastguard Worker 
4*da0073e9SAndroid Build Coastguard Worker #include "ThreadPoolCommon.h"
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <atomic>
7*da0073e9SAndroid Build Coastguard Worker #include <functional>
8*da0073e9SAndroid Build Coastguard Worker #include <memory>
9*da0073e9SAndroid Build Coastguard Worker #include <mutex>
10*da0073e9SAndroid Build Coastguard Worker #include <vector>
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker #include "caffe2/core/common.h"
13*da0073e9SAndroid Build Coastguard Worker 
//
// A work-stealing threadpool loosely based on pthreadpool
//
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker namespace caffe2 {
19*da0073e9SAndroid Build Coastguard Worker 
// Forward declarations of types whose definitions live outside this header.
class WorkersPool;
struct Task;

// Alignment, in bytes, named by the (currently commented-out) alignas
// annotation on ThreadPool.
constexpr size_t kCacheLineSize{64};
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker // A threadpool with the given number of threads.
26*da0073e9SAndroid Build Coastguard Worker // NOTE: the kCacheLineSize alignment is present only for cache
27*da0073e9SAndroid Build Coastguard Worker // performance, and is not strictly enforced (for example, when
28*da0073e9SAndroid Build Coastguard Worker // the object is created on the heap). Thus, in order to avoid
29*da0073e9SAndroid Build Coastguard Worker // misaligned intrinsics, no SSE instructions shall be involved in
30*da0073e9SAndroid Build Coastguard Worker // the ThreadPool implementation.
31*da0073e9SAndroid Build Coastguard Worker // Note: alignas is disabled because some compilers do not deal with
32*da0073e9SAndroid Build Coastguard Worker // TORCH_API and alignas annotations at the same time.
33*da0073e9SAndroid Build Coastguard Worker class TORCH_API /*alignas(kCacheLineSize)*/ ThreadPool {
34*da0073e9SAndroid Build Coastguard Worker  public:
35*da0073e9SAndroid Build Coastguard Worker   static ThreadPool* createThreadPool(int numThreads);
36*da0073e9SAndroid Build Coastguard Worker   static std::unique_ptr<ThreadPool> defaultThreadPool();
37*da0073e9SAndroid Build Coastguard Worker   virtual ~ThreadPool() = default;
38*da0073e9SAndroid Build Coastguard Worker   // Returns the number of threads currently in use
39*da0073e9SAndroid Build Coastguard Worker   virtual int getNumThreads() const = 0;
40*da0073e9SAndroid Build Coastguard Worker   virtual void setNumThreads(size_t numThreads) = 0;
41*da0073e9SAndroid Build Coastguard Worker 
42*da0073e9SAndroid Build Coastguard Worker   // Sets the minimum work size (range) for which to invoke the
43*da0073e9SAndroid Build Coastguard Worker   // threadpool; work sizes smaller than this will just be run on the
44*da0073e9SAndroid Build Coastguard Worker   // main (calling) thread
setMinWorkSize(size_t size)45*da0073e9SAndroid Build Coastguard Worker   void setMinWorkSize(size_t size) {
46*da0073e9SAndroid Build Coastguard Worker     std::lock_guard<std::mutex> guard(executionMutex_);
47*da0073e9SAndroid Build Coastguard Worker     minWorkSize_ = size;
48*da0073e9SAndroid Build Coastguard Worker   }
49*da0073e9SAndroid Build Coastguard Worker 
getMinWorkSize()50*da0073e9SAndroid Build Coastguard Worker   size_t getMinWorkSize() const {
51*da0073e9SAndroid Build Coastguard Worker     return minWorkSize_;
52*da0073e9SAndroid Build Coastguard Worker   }
53*da0073e9SAndroid Build Coastguard Worker   virtual void run(const std::function<void(int, size_t)>& fn, size_t range) = 0;
54*da0073e9SAndroid Build Coastguard Worker 
55*da0073e9SAndroid Build Coastguard Worker   // Run an arbitrary function in a thread-safe manner accessing the Workers
56*da0073e9SAndroid Build Coastguard Worker   // Pool
57*da0073e9SAndroid Build Coastguard Worker   virtual void withPool(const std::function<void(WorkersPool*)>& fn) = 0;
58*da0073e9SAndroid Build Coastguard Worker 
59*da0073e9SAndroid Build Coastguard Worker  protected:
60*da0073e9SAndroid Build Coastguard Worker   static size_t defaultNumThreads_;
61*da0073e9SAndroid Build Coastguard Worker   mutable std::mutex executionMutex_;
62*da0073e9SAndroid Build Coastguard Worker   size_t minWorkSize_;
63*da0073e9SAndroid Build Coastguard Worker };
64*da0073e9SAndroid Build Coastguard Worker 
// Declared here; the definition lives outside this header.
// NOTE(review): presumably returns the thread count used when building the
// default pool (see ThreadPool::defaultThreadPool) -- confirm against the
// definition.
size_t getDefaultNumThreads();
66*da0073e9SAndroid Build Coastguard Worker } // namespace caffe2
67*da0073e9SAndroid Build Coastguard Worker 
68*da0073e9SAndroid Build Coastguard Worker #endif // CAFFE2_UTILS_THREADPOOL_H_
69