xref: /aosp_15_r20/external/pytorch/caffe2/utils/threadpool/pthreadpool-cpp.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
2*da0073e9SAndroid Build Coastguard Worker #include <caffe2/utils/threadpool/thread_pool_guard.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <atomic>
6*da0073e9SAndroid Build Coastguard Worker 
namespace {
// After fork, the child process inherits the data-structures of the parent
// process' thread-pool, but since those threads don't exist, the thread-pool
// is corrupt. It's leaked (rather than destroyed) in order to prevent
// segfaults.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
//
// Set by child_atfork() and read/cleared by caffe2::pthreadpool(), which may
// be called from several threads of the child process. An atomic flag makes
// those concurrent reads and writes well-defined; a plain bool would be a
// data race.
std::atomic<bool> leak_corrupted_threadpool{false};

// Handler registered as the pthread_atfork() "child" callback: runs in the
// child process after fork() and marks the inherited thread-pool as corrupt.
void child_atfork() {
  leak_corrupted_threadpool.store(true);
}

} // namespace
19*da0073e9SAndroid Build Coastguard Worker 
20*da0073e9SAndroid Build Coastguard Worker namespace caffe2 {
21*da0073e9SAndroid Build Coastguard Worker 
PThreadPool(const size_t thread_count)22*da0073e9SAndroid Build Coastguard Worker PThreadPool::PThreadPool(const size_t thread_count)
23*da0073e9SAndroid Build Coastguard Worker     : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
24*da0073e9SAndroid Build Coastguard Worker 
get_thread_count() const25*da0073e9SAndroid Build Coastguard Worker size_t PThreadPool::get_thread_count() const {
26*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> lock{mutex_};
27*da0073e9SAndroid Build Coastguard Worker 
28*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
29*da0073e9SAndroid Build Coastguard Worker   return pthreadpool_get_threads_count(threadpool_.get());
30*da0073e9SAndroid Build Coastguard Worker }
31*da0073e9SAndroid Build Coastguard Worker 
set_thread_count(const size_t thread_count)32*da0073e9SAndroid Build Coastguard Worker void PThreadPool::set_thread_count(const size_t thread_count) {
33*da0073e9SAndroid Build Coastguard Worker   // No need to do anything if the count is same
34*da0073e9SAndroid Build Coastguard Worker   if (thread_count == get_thread_count()) {
35*da0073e9SAndroid Build Coastguard Worker     return;
36*da0073e9SAndroid Build Coastguard Worker   }
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> lock{mutex_};
39*da0073e9SAndroid Build Coastguard Worker 
40*da0073e9SAndroid Build Coastguard Worker   // As it stands, pthreadpool is an entirely data parallel framework with no
41*da0073e9SAndroid Build Coastguard Worker   // support for task parallelism.  Hence, all functions are blocking, and no
42*da0073e9SAndroid Build Coastguard Worker   // user-provided tasks can be in flight when the control is returned to the
43*da0073e9SAndroid Build Coastguard Worker   // user of the API, which means re-initializing the library, without the
44*da0073e9SAndroid Build Coastguard Worker   // need to wait on any pending tasks, is all one needs to do to re-adjust
45*da0073e9SAndroid Build Coastguard Worker   // the thread count.
46*da0073e9SAndroid Build Coastguard Worker   threadpool_.reset(pthreadpool_create(thread_count));
47*da0073e9SAndroid Build Coastguard Worker }
48*da0073e9SAndroid Build Coastguard Worker 
run(const std::function<void (size_t)> & fn,const size_t range)49*da0073e9SAndroid Build Coastguard Worker void PThreadPool::run(
50*da0073e9SAndroid Build Coastguard Worker     const std::function<void(size_t)>& fn,
51*da0073e9SAndroid Build Coastguard Worker     const size_t range) {
52*da0073e9SAndroid Build Coastguard Worker   // Run on same thread if _NoPThreadPoolGuard guard is enabled
53*da0073e9SAndroid Build Coastguard Worker   if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
54*da0073e9SAndroid Build Coastguard Worker     for (size_t i = 0; i < range; ++i) {
55*da0073e9SAndroid Build Coastguard Worker       fn(i);
56*da0073e9SAndroid Build Coastguard Worker     }
57*da0073e9SAndroid Build Coastguard Worker     return;
58*da0073e9SAndroid Build Coastguard Worker   }
59*da0073e9SAndroid Build Coastguard Worker 
60*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> lock{mutex_};
61*da0073e9SAndroid Build Coastguard Worker 
62*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(!caffe2::_NoPThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
63*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
64*da0073e9SAndroid Build Coastguard Worker 
65*da0073e9SAndroid Build Coastguard Worker   struct Context final {
66*da0073e9SAndroid Build Coastguard Worker     const std::function<void(size_t)>& fn;
67*da0073e9SAndroid Build Coastguard Worker   } context{
68*da0073e9SAndroid Build Coastguard Worker       fn,
69*da0073e9SAndroid Build Coastguard Worker   };
70*da0073e9SAndroid Build Coastguard Worker 
71*da0073e9SAndroid Build Coastguard Worker   pthreadpool_parallelize_1d(
72*da0073e9SAndroid Build Coastguard Worker       threadpool_.get(),
73*da0073e9SAndroid Build Coastguard Worker       // Note: pthreadpool_parallelize_1d() is a blocking function.  The
74*da0073e9SAndroid Build Coastguard Worker       // function pointer to this lambda passed on to
75*da0073e9SAndroid Build Coastguard Worker       // pthreadpool_parallelize_1d() cannot go out of scope until
76*da0073e9SAndroid Build Coastguard Worker       // pthreadpool_parallelize_1d() returns.
77*da0073e9SAndroid Build Coastguard Worker       [](void* const context, const size_t item) {
78*da0073e9SAndroid Build Coastguard Worker         reinterpret_cast<Context*>(context)->fn(item);
79*da0073e9SAndroid Build Coastguard Worker       },
80*da0073e9SAndroid Build Coastguard Worker       &context,
81*da0073e9SAndroid Build Coastguard Worker       range,
82*da0073e9SAndroid Build Coastguard Worker       0u);
83*da0073e9SAndroid Build Coastguard Worker }
84*da0073e9SAndroid Build Coastguard Worker 
85*da0073e9SAndroid Build Coastguard Worker // Forward declaration
86*da0073e9SAndroid Build Coastguard Worker size_t getDefaultNumThreads();
87*da0073e9SAndroid Build Coastguard Worker 
pthreadpool()88*da0073e9SAndroid Build Coastguard Worker PThreadPool* pthreadpool() {
89*da0073e9SAndroid Build Coastguard Worker   static auto threadpool =
90*da0073e9SAndroid Build Coastguard Worker     std::make_unique<PThreadPool>(getDefaultNumThreads());
91*da0073e9SAndroid Build Coastguard Worker #if !(defined(WIN32))
92*da0073e9SAndroid Build Coastguard Worker   static std::once_flag flag;
93*da0073e9SAndroid Build Coastguard Worker   std::call_once(flag, []() {
94*da0073e9SAndroid Build Coastguard Worker     pthread_atfork(nullptr, nullptr, child_atfork);
95*da0073e9SAndroid Build Coastguard Worker   });
96*da0073e9SAndroid Build Coastguard Worker #endif
97*da0073e9SAndroid Build Coastguard Worker   if (C10_UNLIKELY(leak_corrupted_threadpool)) {
98*da0073e9SAndroid Build Coastguard Worker     leak_corrupted_threadpool = false;
99*da0073e9SAndroid Build Coastguard Worker     if (auto leaked = threadpool.release()) {
100*da0073e9SAndroid Build Coastguard Worker       auto num_threads = leaked->get_thread_count();
101*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(modernize-make-unique)
102*da0073e9SAndroid Build Coastguard Worker       threadpool.reset(new PThreadPool(num_threads));
103*da0073e9SAndroid Build Coastguard Worker     }
104*da0073e9SAndroid Build Coastguard Worker   }
105*da0073e9SAndroid Build Coastguard Worker   return threadpool.get();
106*da0073e9SAndroid Build Coastguard Worker }
107*da0073e9SAndroid Build Coastguard Worker 
pthreadpool_()108*da0073e9SAndroid Build Coastguard Worker pthreadpool_t pthreadpool_() {
109*da0073e9SAndroid Build Coastguard Worker   if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
110*da0073e9SAndroid Build Coastguard Worker     return nullptr;
111*da0073e9SAndroid Build Coastguard Worker   }
112*da0073e9SAndroid Build Coastguard Worker   PThreadPool* const threadpool = pthreadpool();
113*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(
114*da0073e9SAndroid Build Coastguard Worker       threadpool, "Failed to acquire an instance of PThreadPool!");
115*da0073e9SAndroid Build Coastguard Worker   return threadpool->threadpool_.get();
116*da0073e9SAndroid Build Coastguard Worker }
117*da0073e9SAndroid Build Coastguard Worker 
118*da0073e9SAndroid Build Coastguard Worker } // namespace caffe2
119