xref: /aosp_15_r20/external/pytorch/caffe2/utils/threadpool/pthreadpool-cpp.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
2 #include <caffe2/utils/threadpool/thread_pool_guard.h>
3 #include <c10/util/Exception.h>
4 
5 #include <atomic>
6 
namespace {
// After fork(), the child process inherits the data structures of the parent
// process' thread-pool, but the threads behind them do not exist in the child,
// so the thread-pool is corrupt. It is leaked (never destroyed) in order to
// prevent segfaults.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
//
// Written by the pthread_atfork() child handler below and read/cleared by
// caffe2::pthreadpool(), which may be called from several threads of the
// child process. It is atomic so those accesses do not form a data race.
std::atomic<bool> leak_corrupted_threadpool{false};

// pthread_atfork() child handler: marks the inherited thread-pool as corrupt
// so that the next caffe2::pthreadpool() call replaces it.
void child_atfork() {
  leak_corrupted_threadpool.store(true);
}

} // namespace
19 
20 namespace caffe2 {
21 
PThreadPool(const size_t thread_count)22 PThreadPool::PThreadPool(const size_t thread_count)
23     : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
24 
get_thread_count() const25 size_t PThreadPool::get_thread_count() const {
26   std::lock_guard<std::mutex> lock{mutex_};
27 
28   TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
29   return pthreadpool_get_threads_count(threadpool_.get());
30 }
31 
set_thread_count(const size_t thread_count)32 void PThreadPool::set_thread_count(const size_t thread_count) {
33   // No need to do anything if the count is same
34   if (thread_count == get_thread_count()) {
35     return;
36   }
37 
38   std::lock_guard<std::mutex> lock{mutex_};
39 
40   // As it stands, pthreadpool is an entirely data parallel framework with no
41   // support for task parallelism.  Hence, all functions are blocking, and no
42   // user-provided tasks can be in flight when the control is returned to the
43   // user of the API, which means re-initializing the library, without the
44   // need to wait on any pending tasks, is all one needs to do to re-adjust
45   // the thread count.
46   threadpool_.reset(pthreadpool_create(thread_count));
47 }
48 
run(const std::function<void (size_t)> & fn,const size_t range)49 void PThreadPool::run(
50     const std::function<void(size_t)>& fn,
51     const size_t range) {
52   // Run on same thread if _NoPThreadPoolGuard guard is enabled
53   if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
54     for (size_t i = 0; i < range; ++i) {
55       fn(i);
56     }
57     return;
58   }
59 
60   std::lock_guard<std::mutex> lock{mutex_};
61 
62   TORCH_INTERNAL_ASSERT(!caffe2::_NoPThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
63   TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
64 
65   struct Context final {
66     const std::function<void(size_t)>& fn;
67   } context{
68       fn,
69   };
70 
71   pthreadpool_parallelize_1d(
72       threadpool_.get(),
73       // Note: pthreadpool_parallelize_1d() is a blocking function.  The
74       // function pointer to this lambda passed on to
75       // pthreadpool_parallelize_1d() cannot go out of scope until
76       // pthreadpool_parallelize_1d() returns.
77       [](void* const context, const size_t item) {
78         reinterpret_cast<Context*>(context)->fn(item);
79       },
80       &context,
81       range,
82       0u);
83 }
84 
85 // Forward declaration
86 size_t getDefaultNumThreads();
87 
pthreadpool()88 PThreadPool* pthreadpool() {
89   static auto threadpool =
90     std::make_unique<PThreadPool>(getDefaultNumThreads());
91 #if !(defined(WIN32))
92   static std::once_flag flag;
93   std::call_once(flag, []() {
94     pthread_atfork(nullptr, nullptr, child_atfork);
95   });
96 #endif
97   if (C10_UNLIKELY(leak_corrupted_threadpool)) {
98     leak_corrupted_threadpool = false;
99     if (auto leaked = threadpool.release()) {
100       auto num_threads = leaked->get_thread_count();
101       // NOLINTNEXTLINE(modernize-make-unique)
102       threadpool.reset(new PThreadPool(num_threads));
103     }
104   }
105   return threadpool.get();
106 }
107 
pthreadpool_()108 pthreadpool_t pthreadpool_() {
109   if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
110     return nullptr;
111   }
112   PThreadPool* const threadpool = pthreadpool();
113   TORCH_INTERNAL_ASSERT(
114       threadpool, "Failed to acquire an instance of PThreadPool!");
115   return threadpool->threadpool_.get();
116 }
117 
118 } // namespace caffe2
119