1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/threadpool/threadpool.h>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
12*523fa7a6SAndroid Build Coastguard Worker #include <atomic>
13*523fa7a6SAndroid Build Coastguard Worker #include <memory>
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/threadpool/threadpool_guard.h>
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/assert.h>
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker #include <cpuinfo.h>
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker namespace executorch::extension::threadpool {
21*523fa7a6SAndroid Build Coastguard Worker
#if !defined(WIN32)
namespace {

// A forked child process gets a copy of the parent's thread-pool bookkeeping
// but none of the worker threads behind it, so that inherited pool is broken.
// Destroying it can segfault, so it is deliberately leaked instead; this flag
// tells get_threadpool() to do that the next time it is called.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
bool leak_corrupted_threadpool{false};

// Child-side pthread_atfork() handler: mark the inherited pool as corrupt.
void child_atfork() {
  leak_corrupted_threadpool = true;
}

} // namespace
#endif // !defined(WIN32)
36*523fa7a6SAndroid Build Coastguard Worker
ThreadPool(size_t thread_count)37*523fa7a6SAndroid Build Coastguard Worker ThreadPool::ThreadPool(size_t thread_count)
38*523fa7a6SAndroid Build Coastguard Worker : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
39*523fa7a6SAndroid Build Coastguard Worker
get_thread_count() const40*523fa7a6SAndroid Build Coastguard Worker size_t ThreadPool::get_thread_count() const {
41*523fa7a6SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock{mutex_};
42*523fa7a6SAndroid Build Coastguard Worker
43*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
44*523fa7a6SAndroid Build Coastguard Worker return pthreadpool_get_threads_count(threadpool_.get());
45*523fa7a6SAndroid Build Coastguard Worker }
46*523fa7a6SAndroid Build Coastguard Worker
_unsafe_reset_threadpool(uint32_t new_thread_count)47*523fa7a6SAndroid Build Coastguard Worker bool ThreadPool::_unsafe_reset_threadpool(uint32_t new_thread_count) {
48*523fa7a6SAndroid Build Coastguard Worker // No need to do anything if the count is same or 0
49*523fa7a6SAndroid Build Coastguard Worker if (new_thread_count == get_thread_count() || new_thread_count == 0) {
50*523fa7a6SAndroid Build Coastguard Worker return true;
51*523fa7a6SAndroid Build Coastguard Worker }
52*523fa7a6SAndroid Build Coastguard Worker
53*523fa7a6SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock{mutex_};
54*523fa7a6SAndroid Build Coastguard Worker
55*523fa7a6SAndroid Build Coastguard Worker threadpool_.reset(pthreadpool_create(new_thread_count));
56*523fa7a6SAndroid Build Coastguard Worker return true;
57*523fa7a6SAndroid Build Coastguard Worker }
58*523fa7a6SAndroid Build Coastguard Worker
run(const std::function<void (size_t)> & fn,const size_t range)59*523fa7a6SAndroid Build Coastguard Worker void ThreadPool::run(
60*523fa7a6SAndroid Build Coastguard Worker const std::function<void(size_t)>& fn,
61*523fa7a6SAndroid Build Coastguard Worker const size_t range) {
62*523fa7a6SAndroid Build Coastguard Worker // Run on same thread if NoThreadPoolGuard guard is enabled
63*523fa7a6SAndroid Build Coastguard Worker if (NoThreadPoolGuard::is_enabled()) {
64*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < range; ++i) {
65*523fa7a6SAndroid Build Coastguard Worker fn(i);
66*523fa7a6SAndroid Build Coastguard Worker }
67*523fa7a6SAndroid Build Coastguard Worker return;
68*523fa7a6SAndroid Build Coastguard Worker }
69*523fa7a6SAndroid Build Coastguard Worker
70*523fa7a6SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lock{mutex_};
71*523fa7a6SAndroid Build Coastguard Worker
72*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(!NoThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
73*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
74*523fa7a6SAndroid Build Coastguard Worker
75*523fa7a6SAndroid Build Coastguard Worker struct Context final {
76*523fa7a6SAndroid Build Coastguard Worker const std::function<void(size_t)>& fn;
77*523fa7a6SAndroid Build Coastguard Worker } context{
78*523fa7a6SAndroid Build Coastguard Worker fn,
79*523fa7a6SAndroid Build Coastguard Worker };
80*523fa7a6SAndroid Build Coastguard Worker
81*523fa7a6SAndroid Build Coastguard Worker pthreadpool_parallelize_1d(
82*523fa7a6SAndroid Build Coastguard Worker threadpool_.get(),
83*523fa7a6SAndroid Build Coastguard Worker // Note: pthreadpool_parallelize_1d() is a blocking function. The
84*523fa7a6SAndroid Build Coastguard Worker // function pointer to this lambda passed on to
85*523fa7a6SAndroid Build Coastguard Worker // pthreadpool_parallelize_1d() cannot go out of scope until
86*523fa7a6SAndroid Build Coastguard Worker // pthreadpool_parallelize_1d() returns.
87*523fa7a6SAndroid Build Coastguard Worker [](void* const context, const size_t item) {
88*523fa7a6SAndroid Build Coastguard Worker NoThreadPoolGuard guard;
89*523fa7a6SAndroid Build Coastguard Worker reinterpret_cast<Context*>(context)->fn(item);
90*523fa7a6SAndroid Build Coastguard Worker },
91*523fa7a6SAndroid Build Coastguard Worker &context,
92*523fa7a6SAndroid Build Coastguard Worker range,
93*523fa7a6SAndroid Build Coastguard Worker 0u);
94*523fa7a6SAndroid Build Coastguard Worker }
95*523fa7a6SAndroid Build Coastguard Worker
96*523fa7a6SAndroid Build Coastguard Worker // get_threadpool is not thread safe due to leak_corrupted_threadpool
97*523fa7a6SAndroid Build Coastguard Worker // Make this part threadsafe: TODO(kimishpatel)
get_threadpool()98*523fa7a6SAndroid Build Coastguard Worker ThreadPool* get_threadpool() {
99*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(cpuinfo_initialize(), "cpuinfo initialization failed");
100*523fa7a6SAndroid Build Coastguard Worker int num_threads = cpuinfo_get_processors_count();
101*523fa7a6SAndroid Build Coastguard Worker /*
102*523fa7a6SAndroid Build Coastguard Worker * For llvm-tsan, holding limit for the number of locks for a single thread
103*523fa7a6SAndroid Build Coastguard Worker * is 63 (because of comparison < 64 instead of <=). pthreadpool's worst
104*523fa7a6SAndroid Build Coastguard Worker * case is the number of threads in a pool. So we want to limit the threadpool
105*523fa7a6SAndroid Build Coastguard Worker * size to 64 when running with tsan. However, sometimes it is tricky to
106*523fa7a6SAndroid Build Coastguard Worker * detect if we are running under tsan, for now capping the default
107*523fa7a6SAndroid Build Coastguard Worker * threadcount to the tsan limit unconditionally.
108*523fa7a6SAndroid Build Coastguard Worker */
109*523fa7a6SAndroid Build Coastguard Worker constexpr int tsan_thread_limit = 63;
110*523fa7a6SAndroid Build Coastguard Worker num_threads = std::min(num_threads, tsan_thread_limit);
111*523fa7a6SAndroid Build Coastguard Worker static auto threadpool = std::make_unique<ThreadPool>(num_threads);
112*523fa7a6SAndroid Build Coastguard Worker
113*523fa7a6SAndroid Build Coastguard Worker // Inheriting from old threadpool to get around segfault issue
114*523fa7a6SAndroid Build Coastguard Worker // commented above at child_atfork
115*523fa7a6SAndroid Build Coastguard Worker #if !(defined(WIN32))
116*523fa7a6SAndroid Build Coastguard Worker // @lint-ignore CLANGTIDY facebook-hte-std::once_flag
117*523fa7a6SAndroid Build Coastguard Worker static std::once_flag flag;
118*523fa7a6SAndroid Build Coastguard Worker // @lint-ignore CLANGTIDY facebook-hte-std::call_once
119*523fa7a6SAndroid Build Coastguard Worker std::call_once(
120*523fa7a6SAndroid Build Coastguard Worker flag, []() { pthread_atfork(nullptr, nullptr, child_atfork); });
121*523fa7a6SAndroid Build Coastguard Worker if ET_UNLIKELY (leak_corrupted_threadpool) {
122*523fa7a6SAndroid Build Coastguard Worker leak_corrupted_threadpool = false;
123*523fa7a6SAndroid Build Coastguard Worker if (auto leaked = threadpool.release()) {
124*523fa7a6SAndroid Build Coastguard Worker auto t = leaked->get_thread_count();
125*523fa7a6SAndroid Build Coastguard Worker threadpool = std::make_unique<ThreadPool>(t);
126*523fa7a6SAndroid Build Coastguard Worker }
127*523fa7a6SAndroid Build Coastguard Worker }
128*523fa7a6SAndroid Build Coastguard Worker #endif
129*523fa7a6SAndroid Build Coastguard Worker return threadpool.get();
130*523fa7a6SAndroid Build Coastguard Worker }
131*523fa7a6SAndroid Build Coastguard Worker
get_pthreadpool()132*523fa7a6SAndroid Build Coastguard Worker pthreadpool_t get_pthreadpool() {
133*523fa7a6SAndroid Build Coastguard Worker if (NoThreadPoolGuard::is_enabled()) {
134*523fa7a6SAndroid Build Coastguard Worker return nullptr;
135*523fa7a6SAndroid Build Coastguard Worker }
136*523fa7a6SAndroid Build Coastguard Worker ThreadPool* const threadpool = get_threadpool();
137*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(threadpool, "Failed to acquire an instance of ThreadPool!");
138*523fa7a6SAndroid Build Coastguard Worker return threadpool->threadpool_.get();
139*523fa7a6SAndroid Build Coastguard Worker }
140*523fa7a6SAndroid Build Coastguard Worker
141*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch::extension::threadpool
142