xref: /aosp_15_r20/external/pytorch/caffe2/utils/threadpool/ThreadPool.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include "caffe2/utils/threadpool/ThreadPool.h"
#include "WorkersPool.h"

#include <algorithm>
#include <thread>

#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_bool(
11*da0073e9SAndroid Build Coastguard Worker     caffe2_threadpool_force_inline,
12*da0073e9SAndroid Build Coastguard Worker     false,
13*da0073e9SAndroid Build Coastguard Worker     "Force to always run jobs on the calling thread");
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker // Whether or not threadpool caps apply to Android
16*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(caffe2_threadpool_android_cap, true, "");
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker // Whether or not threadpool caps apply to iOS and MacOS
19*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(caffe2_threadpool_ios_cap, true, "");
20*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(caffe2_threadpool_macos_cap, true, "");
21*da0073e9SAndroid Build Coastguard Worker 
22*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_int(pthreadpool_size, 0, "Override the default thread pool size.");
23*da0073e9SAndroid Build Coastguard Worker 
24*da0073e9SAndroid Build Coastguard Worker namespace caffe2 {
25*da0073e9SAndroid Build Coastguard Worker 
26*da0073e9SAndroid Build Coastguard Worker namespace {
27*da0073e9SAndroid Build Coastguard Worker   class ThreadPoolImpl : public ThreadPool {
28*da0073e9SAndroid Build Coastguard Worker   public:
29*da0073e9SAndroid Build Coastguard Worker     explicit ThreadPoolImpl(int numThreads);
30*da0073e9SAndroid Build Coastguard Worker     ~ThreadPoolImpl() override;
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker     // Returns the number of threads currently in use
33*da0073e9SAndroid Build Coastguard Worker     int getNumThreads() const override;
34*da0073e9SAndroid Build Coastguard Worker     void setNumThreads(size_t numThreads) override;
35*da0073e9SAndroid Build Coastguard Worker 
36*da0073e9SAndroid Build Coastguard Worker     void run(const std::function<void(int, size_t)>& fn, size_t range) override;
37*da0073e9SAndroid Build Coastguard Worker     void withPool(const std::function<void(WorkersPool*)>& f) override;
38*da0073e9SAndroid Build Coastguard Worker 
39*da0073e9SAndroid Build Coastguard Worker   private:
40*da0073e9SAndroid Build Coastguard Worker     std::atomic_size_t numThreads_;
41*da0073e9SAndroid Build Coastguard Worker     std::shared_ptr<WorkersPool> workersPool_;
42*da0073e9SAndroid Build Coastguard Worker     std::vector<std::shared_ptr<Task>> tasks_;
43*da0073e9SAndroid Build Coastguard Worker   };
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker 
getDefaultNumThreads()46*da0073e9SAndroid Build Coastguard Worker size_t getDefaultNumThreads() {
47*da0073e9SAndroid Build Coastguard Worker #if !defined(__s390x__) && !defined(__powerpc__)
48*da0073e9SAndroid Build Coastguard Worker   auto numThreads = 1U;
49*da0073e9SAndroid Build Coastguard Worker   if (cpuinfo_initialize()) {
50*da0073e9SAndroid Build Coastguard Worker     numThreads = std::max(cpuinfo_get_processors_count(), 1U);
51*da0073e9SAndroid Build Coastguard Worker   } else {
52*da0073e9SAndroid Build Coastguard Worker     LOG(WARNING) << "cpuinfo initialization failed";
53*da0073e9SAndroid Build Coastguard Worker     numThreads = std::max(std::thread::hardware_concurrency(), 1U);
54*da0073e9SAndroid Build Coastguard Worker   }
55*da0073e9SAndroid Build Coastguard Worker 
56*da0073e9SAndroid Build Coastguard Worker   bool applyCap = false;
57*da0073e9SAndroid Build Coastguard Worker #if defined(C10_ANDROID)
58*da0073e9SAndroid Build Coastguard Worker   applyCap = FLAGS_caffe2_threadpool_android_cap;
59*da0073e9SAndroid Build Coastguard Worker #elif defined(C10_IOS)
60*da0073e9SAndroid Build Coastguard Worker   applyCap = FLAGS_caffe2_threadpool_ios_cap;
61*da0073e9SAndroid Build Coastguard Worker #elif defined(TARGET_OS_MAC)
62*da0073e9SAndroid Build Coastguard Worker   applyCap = FLAGS_caffe2_threadpool_macos_cap;
63*da0073e9SAndroid Build Coastguard Worker #endif
64*da0073e9SAndroid Build Coastguard Worker 
65*da0073e9SAndroid Build Coastguard Worker   if (applyCap) {
66*da0073e9SAndroid Build Coastguard Worker     switch (numThreads) {
67*da0073e9SAndroid Build Coastguard Worker #if defined(C10_ANDROID) && (CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64)
68*da0073e9SAndroid Build Coastguard Worker       case 4:
69*da0073e9SAndroid Build Coastguard Worker         switch (cpuinfo_get_core(0)->midr & UINT32_C(0xFF00FFF0)) {
70*da0073e9SAndroid Build Coastguard Worker           case UINT32_C(0x51002110): /* Snapdragon 820 Kryo Silver */
71*da0073e9SAndroid Build Coastguard Worker           case UINT32_C(0x51002010): /* Snapdragon 821 Kryo Silver */
72*da0073e9SAndroid Build Coastguard Worker           case UINT32_C(0x51002050): /* Snapdragon 820/821 Kryo Gold */
73*da0073e9SAndroid Build Coastguard Worker             /* Kryo: 2+2 big.LITTLE */
74*da0073e9SAndroid Build Coastguard Worker             numThreads = 2;
75*da0073e9SAndroid Build Coastguard Worker             break;
76*da0073e9SAndroid Build Coastguard Worker           default:
77*da0073e9SAndroid Build Coastguard Worker             /* Anything else: assume homogeneous architecture */
78*da0073e9SAndroid Build Coastguard Worker             numThreads = 4;
79*da0073e9SAndroid Build Coastguard Worker             break;
80*da0073e9SAndroid Build Coastguard Worker         }
81*da0073e9SAndroid Build Coastguard Worker         break;
82*da0073e9SAndroid Build Coastguard Worker #endif
83*da0073e9SAndroid Build Coastguard Worker       case 5:
84*da0073e9SAndroid Build Coastguard Worker         /* 4+1 big.LITTLE */
85*da0073e9SAndroid Build Coastguard Worker         numThreads = 4;
86*da0073e9SAndroid Build Coastguard Worker         break;
87*da0073e9SAndroid Build Coastguard Worker       case 6:
88*da0073e9SAndroid Build Coastguard Worker         /* 2+4 big.LITTLE */
89*da0073e9SAndroid Build Coastguard Worker         numThreads = 2;
90*da0073e9SAndroid Build Coastguard Worker         break;
91*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-branch-clone)
92*da0073e9SAndroid Build Coastguard Worker       case 8:
93*da0073e9SAndroid Build Coastguard Worker         /* 4+4 big.LITTLE */
94*da0073e9SAndroid Build Coastguard Worker         numThreads = 4;
95*da0073e9SAndroid Build Coastguard Worker         break;
96*da0073e9SAndroid Build Coastguard Worker       case 10:
97*da0073e9SAndroid Build Coastguard Worker         /* 4+4+2 Min.Med.Max, running on Med cores */
98*da0073e9SAndroid Build Coastguard Worker         numThreads = 4;
99*da0073e9SAndroid Build Coastguard Worker         break;
100*da0073e9SAndroid Build Coastguard Worker       default:
101*da0073e9SAndroid Build Coastguard Worker         if (numThreads > 4) {
102*da0073e9SAndroid Build Coastguard Worker           numThreads = numThreads / 2;
103*da0073e9SAndroid Build Coastguard Worker         }
104*da0073e9SAndroid Build Coastguard Worker         break;
105*da0073e9SAndroid Build Coastguard Worker     }
106*da0073e9SAndroid Build Coastguard Worker   }
107*da0073e9SAndroid Build Coastguard Worker #else
108*da0073e9SAndroid Build Coastguard Worker   auto numThreads = std::max(std::thread::hardware_concurrency(), 1U);
109*da0073e9SAndroid Build Coastguard Worker #endif
110*da0073e9SAndroid Build Coastguard Worker 
111*da0073e9SAndroid Build Coastguard Worker   if (FLAGS_pthreadpool_size) {
112*da0073e9SAndroid Build Coastguard Worker     // Always give precedence to explicit setting.
113*da0073e9SAndroid Build Coastguard Worker     numThreads = FLAGS_pthreadpool_size;
114*da0073e9SAndroid Build Coastguard Worker   }
115*da0073e9SAndroid Build Coastguard Worker 
116*da0073e9SAndroid Build Coastguard Worker   /*
117*da0073e9SAndroid Build Coastguard Worker    * For llvm-tsan, holding limit for the number of locks for a single thread
118*da0073e9SAndroid Build Coastguard Worker    * is 63 (because of comparison < 64 instead of <=). pthreadpool's worst
119*da0073e9SAndroid Build Coastguard Worker    * case is the number of threads in a pool. So we want to limit the threadpool
120*da0073e9SAndroid Build Coastguard Worker    * size to 64 when running with tsan. However, sometimes it is tricky to
121*da0073e9SAndroid Build Coastguard Worker    * detect if we are running under tsan, for now capping the default
122*da0073e9SAndroid Build Coastguard Worker    * threadcount to the tsan limit unconditionally.
123*da0073e9SAndroid Build Coastguard Worker    */
124*da0073e9SAndroid Build Coastguard Worker   auto tsanThreadLimit = 63U;
125*da0073e9SAndroid Build Coastguard Worker   numThreads = std::min(numThreads, tsanThreadLimit);
126*da0073e9SAndroid Build Coastguard Worker 
127*da0073e9SAndroid Build Coastguard Worker   return numThreads;
128*da0073e9SAndroid Build Coastguard Worker }
129*da0073e9SAndroid Build Coastguard Worker 
// Default minimum range size worth splitting across worker threads; smaller
// ranges run inline on the caller. Each pool keeps its effective value in
// minWorkSize_, which is configurable at runtime.
constexpr size_t kDefaultMinWorkSize{1};
133*da0073e9SAndroid Build Coastguard Worker 
134*da0073e9SAndroid Build Coastguard Worker size_t ThreadPool::defaultNumThreads_ = 0;
135*da0073e9SAndroid Build Coastguard Worker 
createThreadPool(int numThreads)136*da0073e9SAndroid Build Coastguard Worker ThreadPool* ThreadPool::createThreadPool(int numThreads) {
137*da0073e9SAndroid Build Coastguard Worker   return new ThreadPoolImpl(numThreads);
138*da0073e9SAndroid Build Coastguard Worker }
139*da0073e9SAndroid Build Coastguard Worker 
defaultThreadPool()140*da0073e9SAndroid Build Coastguard Worker std::unique_ptr<ThreadPool> ThreadPool::defaultThreadPool() {
141*da0073e9SAndroid Build Coastguard Worker   defaultNumThreads_ = getDefaultNumThreads();
142*da0073e9SAndroid Build Coastguard Worker   LOG(INFO) << "Constructing thread pool with " << defaultNumThreads_
143*da0073e9SAndroid Build Coastguard Worker             << " threads";
144*da0073e9SAndroid Build Coastguard Worker   return std::make_unique<ThreadPoolImpl>(defaultNumThreads_);
145*da0073e9SAndroid Build Coastguard Worker }
146*da0073e9SAndroid Build Coastguard Worker 
ThreadPoolImpl(int numThreads)147*da0073e9SAndroid Build Coastguard Worker ThreadPoolImpl::ThreadPoolImpl(int numThreads)
148*da0073e9SAndroid Build Coastguard Worker     : numThreads_(numThreads),
149*da0073e9SAndroid Build Coastguard Worker       workersPool_(std::make_shared<WorkersPool>()) {
150*da0073e9SAndroid Build Coastguard Worker   minWorkSize_ = kDefaultMinWorkSize;
151*da0073e9SAndroid Build Coastguard Worker }
152*da0073e9SAndroid Build Coastguard Worker 
153*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(modernize-use-equals-default)
~ThreadPoolImpl()154*da0073e9SAndroid Build Coastguard Worker ThreadPoolImpl::~ThreadPoolImpl() {}
155*da0073e9SAndroid Build Coastguard Worker 
getNumThreads() const156*da0073e9SAndroid Build Coastguard Worker int ThreadPoolImpl::getNumThreads() const {
157*da0073e9SAndroid Build Coastguard Worker   return numThreads_;
158*da0073e9SAndroid Build Coastguard Worker }
159*da0073e9SAndroid Build Coastguard Worker 
160*da0073e9SAndroid Build Coastguard Worker // Sets the number of threads
161*da0073e9SAndroid Build Coastguard Worker // # of threads should not be bigger than the number of big cores
setNumThreads(size_t numThreads)162*da0073e9SAndroid Build Coastguard Worker void ThreadPoolImpl::setNumThreads(size_t numThreads) {
163*da0073e9SAndroid Build Coastguard Worker   if (defaultNumThreads_ == 0) {
164*da0073e9SAndroid Build Coastguard Worker     defaultNumThreads_ = getDefaultNumThreads();
165*da0073e9SAndroid Build Coastguard Worker   }
166*da0073e9SAndroid Build Coastguard Worker   numThreads_ = std::min(numThreads, defaultNumThreads_);
167*da0073e9SAndroid Build Coastguard Worker }
168*da0073e9SAndroid Build Coastguard Worker 
run(const std::function<void (int,size_t)> & fn,size_t range)169*da0073e9SAndroid Build Coastguard Worker void ThreadPoolImpl::run(const std::function<void(int, size_t)>& fn, size_t range) {
170*da0073e9SAndroid Build Coastguard Worker   const auto numThreads = numThreads_.load(std::memory_order_relaxed);
171*da0073e9SAndroid Build Coastguard Worker 
172*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(executionMutex_);
173*da0073e9SAndroid Build Coastguard Worker   // If there are no worker threads, or if the range is too small (too
174*da0073e9SAndroid Build Coastguard Worker   // little work), just run locally
175*da0073e9SAndroid Build Coastguard Worker   const bool runLocally = range < minWorkSize_ ||
176*da0073e9SAndroid Build Coastguard Worker       FLAGS_caffe2_threadpool_force_inline || (numThreads == 0);
177*da0073e9SAndroid Build Coastguard Worker   if (runLocally) {
178*da0073e9SAndroid Build Coastguard Worker     // Work is small enough to just run locally; multithread overhead
179*da0073e9SAndroid Build Coastguard Worker     // is too high
180*da0073e9SAndroid Build Coastguard Worker     for (size_t i = 0; i < range; ++i) {
181*da0073e9SAndroid Build Coastguard Worker       fn(0, i);
182*da0073e9SAndroid Build Coastguard Worker     }
183*da0073e9SAndroid Build Coastguard Worker     return;
184*da0073e9SAndroid Build Coastguard Worker   }
185*da0073e9SAndroid Build Coastguard Worker 
186*da0073e9SAndroid Build Coastguard Worker   struct FnTask : public Task {
187*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(modernize-use-equals-default,cppcoreguidelines-pro-type-member-init)
188*da0073e9SAndroid Build Coastguard Worker     FnTask(){};
189*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(modernize-use-equals-default)
190*da0073e9SAndroid Build Coastguard Worker     ~FnTask() override{};
191*da0073e9SAndroid Build Coastguard Worker     const std::function<void(int, size_t)>* fn_;
192*da0073e9SAndroid Build Coastguard Worker     int idx_;
193*da0073e9SAndroid Build Coastguard Worker     size_t start_;
194*da0073e9SAndroid Build Coastguard Worker     size_t end_;
195*da0073e9SAndroid Build Coastguard Worker     void Run() override {
196*da0073e9SAndroid Build Coastguard Worker       for (auto i = start_; i < end_; ++i) {
197*da0073e9SAndroid Build Coastguard Worker         (*fn_)(idx_, i);
198*da0073e9SAndroid Build Coastguard Worker       }
199*da0073e9SAndroid Build Coastguard Worker     }
200*da0073e9SAndroid Build Coastguard Worker   };
201*da0073e9SAndroid Build Coastguard Worker 
202*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE_GE(numThreads_, 1);
203*da0073e9SAndroid Build Coastguard Worker   const size_t unitsPerTask = (range + numThreads - 1) / numThreads;
204*da0073e9SAndroid Build Coastguard Worker   tasks_.resize(numThreads);
205*da0073e9SAndroid Build Coastguard Worker   for (size_t i = 0; i < numThreads; ++i) {
206*da0073e9SAndroid Build Coastguard Worker     if (!tasks_[i]) {
207*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(modernize-make-shared)
208*da0073e9SAndroid Build Coastguard Worker       tasks_[i].reset(new FnTask());
209*da0073e9SAndroid Build Coastguard Worker     }
210*da0073e9SAndroid Build Coastguard Worker     auto* task = (FnTask*)tasks_[i].get();
211*da0073e9SAndroid Build Coastguard Worker     task->fn_ = &fn;
212*da0073e9SAndroid Build Coastguard Worker     task->idx_ = i;
213*da0073e9SAndroid Build Coastguard Worker     task->start_ = std::min<size_t>(range, i * unitsPerTask);
214*da0073e9SAndroid Build Coastguard Worker     task->end_ = std::min<size_t>(range, (i + 1) * unitsPerTask);
215*da0073e9SAndroid Build Coastguard Worker     if (task->start_ >= task->end_) {
216*da0073e9SAndroid Build Coastguard Worker       tasks_.resize(i);
217*da0073e9SAndroid Build Coastguard Worker       break;
218*da0073e9SAndroid Build Coastguard Worker     }
219*da0073e9SAndroid Build Coastguard Worker     CAFFE_ENFORCE_LE(task->start_, range);
220*da0073e9SAndroid Build Coastguard Worker     CAFFE_ENFORCE_LE(task->end_, range);
221*da0073e9SAndroid Build Coastguard Worker   }
222*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE_LE(tasks_.size(), numThreads);
223*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE_GE(tasks_.size(), 1);
224*da0073e9SAndroid Build Coastguard Worker   workersPool_->Execute(tasks_);
225*da0073e9SAndroid Build Coastguard Worker }
226*da0073e9SAndroid Build Coastguard Worker 
withPool(const std::function<void (WorkersPool *)> & f)227*da0073e9SAndroid Build Coastguard Worker void ThreadPoolImpl::withPool(const std::function<void(WorkersPool*)>& f) {
228*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(executionMutex_);
229*da0073e9SAndroid Build Coastguard Worker   f(workersPool_.get());
230*da0073e9SAndroid Build Coastguard Worker }
231*da0073e9SAndroid Build Coastguard Worker 
232*da0073e9SAndroid Build Coastguard Worker } // namespace caffe2
233