1*da0073e9SAndroid Build Coastguard Worker #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
2*da0073e9SAndroid Build Coastguard Worker #include <caffe2/utils/threadpool/thread_pool_guard.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker #include <atomic>
6*da0073e9SAndroid Build Coastguard Worker
namespace {
// After fork, the child process inherits the data-structures of the parent
// process' thread-pool, but since those threads don't exist, the thread-pool
// is corrupt. It's leaked in order to prevent segfaults.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
//
// Set by the pthread_atfork() child handler below and consumed (read, then
// cleared) by caffe2::pthreadpool() without holding any lock. Once the child
// process starts threads of its own, several of them may call pthreadpool()
// concurrently. The flag is therefore atomic: with a plain bool those
// unsynchronized reads and writes would be a data race.
std::atomic<bool> leak_corrupted_threadpool{false};

// Child handler registered via pthread_atfork(): marks the inherited pool as
// corrupt so that the next pthreadpool() call replaces it.
void child_atfork() {
  leak_corrupted_threadpool.store(true);
}

} // namespace
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker namespace caffe2 {
21*da0073e9SAndroid Build Coastguard Worker
PThreadPool(const size_t thread_count)22*da0073e9SAndroid Build Coastguard Worker PThreadPool::PThreadPool(const size_t thread_count)
23*da0073e9SAndroid Build Coastguard Worker : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
24*da0073e9SAndroid Build Coastguard Worker
get_thread_count() const25*da0073e9SAndroid Build Coastguard Worker size_t PThreadPool::get_thread_count() const {
26*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock{mutex_};
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
29*da0073e9SAndroid Build Coastguard Worker return pthreadpool_get_threads_count(threadpool_.get());
30*da0073e9SAndroid Build Coastguard Worker }
31*da0073e9SAndroid Build Coastguard Worker
set_thread_count(const size_t thread_count)32*da0073e9SAndroid Build Coastguard Worker void PThreadPool::set_thread_count(const size_t thread_count) {
33*da0073e9SAndroid Build Coastguard Worker // No need to do anything if the count is same
34*da0073e9SAndroid Build Coastguard Worker if (thread_count == get_thread_count()) {
35*da0073e9SAndroid Build Coastguard Worker return;
36*da0073e9SAndroid Build Coastguard Worker }
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock{mutex_};
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker // As it stands, pthreadpool is an entirely data parallel framework with no
41*da0073e9SAndroid Build Coastguard Worker // support for task parallelism. Hence, all functions are blocking, and no
42*da0073e9SAndroid Build Coastguard Worker // user-provided tasks can be in flight when the control is returned to the
43*da0073e9SAndroid Build Coastguard Worker // user of the API, which means re-initializing the library, without the
44*da0073e9SAndroid Build Coastguard Worker // need to wait on any pending tasks, is all one needs to do to re-adjust
45*da0073e9SAndroid Build Coastguard Worker // the thread count.
46*da0073e9SAndroid Build Coastguard Worker threadpool_.reset(pthreadpool_create(thread_count));
47*da0073e9SAndroid Build Coastguard Worker }
48*da0073e9SAndroid Build Coastguard Worker
run(const std::function<void (size_t)> & fn,const size_t range)49*da0073e9SAndroid Build Coastguard Worker void PThreadPool::run(
50*da0073e9SAndroid Build Coastguard Worker const std::function<void(size_t)>& fn,
51*da0073e9SAndroid Build Coastguard Worker const size_t range) {
52*da0073e9SAndroid Build Coastguard Worker // Run on same thread if _NoPThreadPoolGuard guard is enabled
53*da0073e9SAndroid Build Coastguard Worker if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
54*da0073e9SAndroid Build Coastguard Worker for (size_t i = 0; i < range; ++i) {
55*da0073e9SAndroid Build Coastguard Worker fn(i);
56*da0073e9SAndroid Build Coastguard Worker }
57*da0073e9SAndroid Build Coastguard Worker return;
58*da0073e9SAndroid Build Coastguard Worker }
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock{mutex_};
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(!caffe2::_NoPThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
63*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker struct Context final {
66*da0073e9SAndroid Build Coastguard Worker const std::function<void(size_t)>& fn;
67*da0073e9SAndroid Build Coastguard Worker } context{
68*da0073e9SAndroid Build Coastguard Worker fn,
69*da0073e9SAndroid Build Coastguard Worker };
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker pthreadpool_parallelize_1d(
72*da0073e9SAndroid Build Coastguard Worker threadpool_.get(),
73*da0073e9SAndroid Build Coastguard Worker // Note: pthreadpool_parallelize_1d() is a blocking function. The
74*da0073e9SAndroid Build Coastguard Worker // function pointer to this lambda passed on to
75*da0073e9SAndroid Build Coastguard Worker // pthreadpool_parallelize_1d() cannot go out of scope until
76*da0073e9SAndroid Build Coastguard Worker // pthreadpool_parallelize_1d() returns.
77*da0073e9SAndroid Build Coastguard Worker [](void* const context, const size_t item) {
78*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<Context*>(context)->fn(item);
79*da0073e9SAndroid Build Coastguard Worker },
80*da0073e9SAndroid Build Coastguard Worker &context,
81*da0073e9SAndroid Build Coastguard Worker range,
82*da0073e9SAndroid Build Coastguard Worker 0u);
83*da0073e9SAndroid Build Coastguard Worker }
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker // Forward declaration
86*da0073e9SAndroid Build Coastguard Worker size_t getDefaultNumThreads();
87*da0073e9SAndroid Build Coastguard Worker
pthreadpool()88*da0073e9SAndroid Build Coastguard Worker PThreadPool* pthreadpool() {
89*da0073e9SAndroid Build Coastguard Worker static auto threadpool =
90*da0073e9SAndroid Build Coastguard Worker std::make_unique<PThreadPool>(getDefaultNumThreads());
91*da0073e9SAndroid Build Coastguard Worker #if !(defined(WIN32))
92*da0073e9SAndroid Build Coastguard Worker static std::once_flag flag;
93*da0073e9SAndroid Build Coastguard Worker std::call_once(flag, []() {
94*da0073e9SAndroid Build Coastguard Worker pthread_atfork(nullptr, nullptr, child_atfork);
95*da0073e9SAndroid Build Coastguard Worker });
96*da0073e9SAndroid Build Coastguard Worker #endif
97*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(leak_corrupted_threadpool)) {
98*da0073e9SAndroid Build Coastguard Worker leak_corrupted_threadpool = false;
99*da0073e9SAndroid Build Coastguard Worker if (auto leaked = threadpool.release()) {
100*da0073e9SAndroid Build Coastguard Worker auto num_threads = leaked->get_thread_count();
101*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(modernize-make-unique)
102*da0073e9SAndroid Build Coastguard Worker threadpool.reset(new PThreadPool(num_threads));
103*da0073e9SAndroid Build Coastguard Worker }
104*da0073e9SAndroid Build Coastguard Worker }
105*da0073e9SAndroid Build Coastguard Worker return threadpool.get();
106*da0073e9SAndroid Build Coastguard Worker }
107*da0073e9SAndroid Build Coastguard Worker
pthreadpool_()108*da0073e9SAndroid Build Coastguard Worker pthreadpool_t pthreadpool_() {
109*da0073e9SAndroid Build Coastguard Worker if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
110*da0073e9SAndroid Build Coastguard Worker return nullptr;
111*da0073e9SAndroid Build Coastguard Worker }
112*da0073e9SAndroid Build Coastguard Worker PThreadPool* const threadpool = pthreadpool();
113*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
114*da0073e9SAndroid Build Coastguard Worker threadpool, "Failed to acquire an instance of PThreadPool!");
115*da0073e9SAndroid Build Coastguard Worker return threadpool->threadpool_.get();
116*da0073e9SAndroid Build Coastguard Worker }
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker } // namespace caffe2
119