1 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
2 #include <caffe2/utils/threadpool/thread_pool_guard.h>
3 #include <c10/util/Exception.h>
4
5 #include <atomic>
6
namespace {

// Flag raised in a freshly forked child (see child_atfork) to tell
// pthreadpool() that the pool object it holds was inherited from the parent.
//
// Background: fork() copies the parent's thread-pool data-structures into the
// child, but the worker threads they describe do not exist there, so the pool
// is corrupt. Instead of destroying it (which can segfault) it is leaked.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
bool leak_corrupted_threadpool{false};

// Child-side fork handler: marks the inherited thread-pool as corrupt so that
// it gets leaked and replaced on the next lookup.
void child_atfork() {
  leak_corrupted_threadpool = true;
}

} // namespace
19
20 namespace caffe2 {
21
PThreadPool(const size_t thread_count)22 PThreadPool::PThreadPool(const size_t thread_count)
23 : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
24
get_thread_count() const25 size_t PThreadPool::get_thread_count() const {
26 std::lock_guard<std::mutex> lock{mutex_};
27
28 TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
29 return pthreadpool_get_threads_count(threadpool_.get());
30 }
31
set_thread_count(const size_t thread_count)32 void PThreadPool::set_thread_count(const size_t thread_count) {
33 // No need to do anything if the count is same
34 if (thread_count == get_thread_count()) {
35 return;
36 }
37
38 std::lock_guard<std::mutex> lock{mutex_};
39
40 // As it stands, pthreadpool is an entirely data parallel framework with no
41 // support for task parallelism. Hence, all functions are blocking, and no
42 // user-provided tasks can be in flight when the control is returned to the
43 // user of the API, which means re-initializing the library, without the
44 // need to wait on any pending tasks, is all one needs to do to re-adjust
45 // the thread count.
46 threadpool_.reset(pthreadpool_create(thread_count));
47 }
48
run(const std::function<void (size_t)> & fn,const size_t range)49 void PThreadPool::run(
50 const std::function<void(size_t)>& fn,
51 const size_t range) {
52 // Run on same thread if _NoPThreadPoolGuard guard is enabled
53 if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
54 for (size_t i = 0; i < range; ++i) {
55 fn(i);
56 }
57 return;
58 }
59
60 std::lock_guard<std::mutex> lock{mutex_};
61
62 TORCH_INTERNAL_ASSERT(!caffe2::_NoPThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
63 TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
64
65 struct Context final {
66 const std::function<void(size_t)>& fn;
67 } context{
68 fn,
69 };
70
71 pthreadpool_parallelize_1d(
72 threadpool_.get(),
73 // Note: pthreadpool_parallelize_1d() is a blocking function. The
74 // function pointer to this lambda passed on to
75 // pthreadpool_parallelize_1d() cannot go out of scope until
76 // pthreadpool_parallelize_1d() returns.
77 [](void* const context, const size_t item) {
78 reinterpret_cast<Context*>(context)->fn(item);
79 },
80 &context,
81 range,
82 0u);
83 }
84
85 // Forward declaration
86 size_t getDefaultNumThreads();
87
pthreadpool()88 PThreadPool* pthreadpool() {
89 static auto threadpool =
90 std::make_unique<PThreadPool>(getDefaultNumThreads());
91 #if !(defined(WIN32))
92 static std::once_flag flag;
93 std::call_once(flag, []() {
94 pthread_atfork(nullptr, nullptr, child_atfork);
95 });
96 #endif
97 if (C10_UNLIKELY(leak_corrupted_threadpool)) {
98 leak_corrupted_threadpool = false;
99 if (auto leaked = threadpool.release()) {
100 auto num_threads = leaked->get_thread_count();
101 // NOLINTNEXTLINE(modernize-make-unique)
102 threadpool.reset(new PThreadPool(num_threads));
103 }
104 }
105 return threadpool.get();
106 }
107
pthreadpool_()108 pthreadpool_t pthreadpool_() {
109 if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
110 return nullptr;
111 }
112 PThreadPool* const threadpool = pthreadpool();
113 TORCH_INTERNAL_ASSERT(
114 threadpool, "Failed to acquire an instance of PThreadPool!");
115 return threadpool->threadpool_.get();
116 }
117
118 } // namespace caffe2
119