xref: /aosp_15_r20/external/executorch/extension/threadpool/threadpool.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/threadpool/threadpool.h>
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <algorithm>
12*523fa7a6SAndroid Build Coastguard Worker #include <atomic>
13*523fa7a6SAndroid Build Coastguard Worker #include <memory>
14*523fa7a6SAndroid Build Coastguard Worker 
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/threadpool/threadpool_guard.h>
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/assert.h>
17*523fa7a6SAndroid Build Coastguard Worker 
18*523fa7a6SAndroid Build Coastguard Worker #include <cpuinfo.h>
19*523fa7a6SAndroid Build Coastguard Worker 
20*523fa7a6SAndroid Build Coastguard Worker namespace executorch::extension::threadpool {
21*523fa7a6SAndroid Build Coastguard Worker 
#ifndef WIN32
namespace {
// fork() copies the parent's thread-pool bookkeeping into the child process,
// but not the worker threads behind it, so the inherited pool is unusable.
// This flag records that a fork happened; get_threadpool() then abandons
// (intentionally leaks) the inherited pool rather than destroying it, to
// prevent segfaults.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
bool leak_corrupted_threadpool = false;

// Child-side handler registered with pthread_atfork(): mark the inherited
// pool as corrupt so the next get_threadpool() call replaces it.
void child_atfork() {
  leak_corrupted_threadpool = true;
}
} // namespace
#endif // WIN32
36*523fa7a6SAndroid Build Coastguard Worker 
ThreadPool(size_t thread_count)37*523fa7a6SAndroid Build Coastguard Worker ThreadPool::ThreadPool(size_t thread_count)
38*523fa7a6SAndroid Build Coastguard Worker     : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
39*523fa7a6SAndroid Build Coastguard Worker 
get_thread_count() const40*523fa7a6SAndroid Build Coastguard Worker size_t ThreadPool::get_thread_count() const {
41*523fa7a6SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> lock{mutex_};
42*523fa7a6SAndroid Build Coastguard Worker 
43*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
44*523fa7a6SAndroid Build Coastguard Worker   return pthreadpool_get_threads_count(threadpool_.get());
45*523fa7a6SAndroid Build Coastguard Worker }
46*523fa7a6SAndroid Build Coastguard Worker 
_unsafe_reset_threadpool(uint32_t new_thread_count)47*523fa7a6SAndroid Build Coastguard Worker bool ThreadPool::_unsafe_reset_threadpool(uint32_t new_thread_count) {
48*523fa7a6SAndroid Build Coastguard Worker   // No need to do anything if the count is same or 0
49*523fa7a6SAndroid Build Coastguard Worker   if (new_thread_count == get_thread_count() || new_thread_count == 0) {
50*523fa7a6SAndroid Build Coastguard Worker     return true;
51*523fa7a6SAndroid Build Coastguard Worker   }
52*523fa7a6SAndroid Build Coastguard Worker 
53*523fa7a6SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> lock{mutex_};
54*523fa7a6SAndroid Build Coastguard Worker 
55*523fa7a6SAndroid Build Coastguard Worker   threadpool_.reset(pthreadpool_create(new_thread_count));
56*523fa7a6SAndroid Build Coastguard Worker   return true;
57*523fa7a6SAndroid Build Coastguard Worker }
58*523fa7a6SAndroid Build Coastguard Worker 
run(const std::function<void (size_t)> & fn,const size_t range)59*523fa7a6SAndroid Build Coastguard Worker void ThreadPool::run(
60*523fa7a6SAndroid Build Coastguard Worker     const std::function<void(size_t)>& fn,
61*523fa7a6SAndroid Build Coastguard Worker     const size_t range) {
62*523fa7a6SAndroid Build Coastguard Worker   // Run on same thread if NoThreadPoolGuard guard is enabled
63*523fa7a6SAndroid Build Coastguard Worker   if (NoThreadPoolGuard::is_enabled()) {
64*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < range; ++i) {
65*523fa7a6SAndroid Build Coastguard Worker       fn(i);
66*523fa7a6SAndroid Build Coastguard Worker     }
67*523fa7a6SAndroid Build Coastguard Worker     return;
68*523fa7a6SAndroid Build Coastguard Worker   }
69*523fa7a6SAndroid Build Coastguard Worker 
70*523fa7a6SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> lock{mutex_};
71*523fa7a6SAndroid Build Coastguard Worker 
72*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(!NoThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
73*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
74*523fa7a6SAndroid Build Coastguard Worker 
75*523fa7a6SAndroid Build Coastguard Worker   struct Context final {
76*523fa7a6SAndroid Build Coastguard Worker     const std::function<void(size_t)>& fn;
77*523fa7a6SAndroid Build Coastguard Worker   } context{
78*523fa7a6SAndroid Build Coastguard Worker       fn,
79*523fa7a6SAndroid Build Coastguard Worker   };
80*523fa7a6SAndroid Build Coastguard Worker 
81*523fa7a6SAndroid Build Coastguard Worker   pthreadpool_parallelize_1d(
82*523fa7a6SAndroid Build Coastguard Worker       threadpool_.get(),
83*523fa7a6SAndroid Build Coastguard Worker       // Note: pthreadpool_parallelize_1d() is a blocking function.  The
84*523fa7a6SAndroid Build Coastguard Worker       // function pointer to this lambda passed on to
85*523fa7a6SAndroid Build Coastguard Worker       // pthreadpool_parallelize_1d() cannot go out of scope until
86*523fa7a6SAndroid Build Coastguard Worker       // pthreadpool_parallelize_1d() returns.
87*523fa7a6SAndroid Build Coastguard Worker       [](void* const context, const size_t item) {
88*523fa7a6SAndroid Build Coastguard Worker         NoThreadPoolGuard guard;
89*523fa7a6SAndroid Build Coastguard Worker         reinterpret_cast<Context*>(context)->fn(item);
90*523fa7a6SAndroid Build Coastguard Worker       },
91*523fa7a6SAndroid Build Coastguard Worker       &context,
92*523fa7a6SAndroid Build Coastguard Worker       range,
93*523fa7a6SAndroid Build Coastguard Worker       0u);
94*523fa7a6SAndroid Build Coastguard Worker }
95*523fa7a6SAndroid Build Coastguard Worker 
96*523fa7a6SAndroid Build Coastguard Worker // get_threadpool is not thread safe due to leak_corrupted_threadpool
97*523fa7a6SAndroid Build Coastguard Worker // Make this part threadsafe: TODO(kimishpatel)
get_threadpool()98*523fa7a6SAndroid Build Coastguard Worker ThreadPool* get_threadpool() {
99*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(cpuinfo_initialize(), "cpuinfo initialization failed");
100*523fa7a6SAndroid Build Coastguard Worker   int num_threads = cpuinfo_get_processors_count();
101*523fa7a6SAndroid Build Coastguard Worker   /*
102*523fa7a6SAndroid Build Coastguard Worker    * For llvm-tsan, holding limit for the number of locks for a single thread
103*523fa7a6SAndroid Build Coastguard Worker    * is 63 (because of comparison < 64 instead of <=). pthreadpool's worst
104*523fa7a6SAndroid Build Coastguard Worker    * case is the number of threads in a pool. So we want to limit the threadpool
105*523fa7a6SAndroid Build Coastguard Worker    * size to 64 when running with tsan. However, sometimes it is tricky to
106*523fa7a6SAndroid Build Coastguard Worker    * detect if we are running under tsan, for now capping the default
107*523fa7a6SAndroid Build Coastguard Worker    * threadcount to the tsan limit unconditionally.
108*523fa7a6SAndroid Build Coastguard Worker    */
109*523fa7a6SAndroid Build Coastguard Worker   constexpr int tsan_thread_limit = 63;
110*523fa7a6SAndroid Build Coastguard Worker   num_threads = std::min(num_threads, tsan_thread_limit);
111*523fa7a6SAndroid Build Coastguard Worker   static auto threadpool = std::make_unique<ThreadPool>(num_threads);
112*523fa7a6SAndroid Build Coastguard Worker 
113*523fa7a6SAndroid Build Coastguard Worker // Inheriting from old threadpool to get around segfault issue
114*523fa7a6SAndroid Build Coastguard Worker // commented above at child_atfork
115*523fa7a6SAndroid Build Coastguard Worker #if !(defined(WIN32))
116*523fa7a6SAndroid Build Coastguard Worker   // @lint-ignore CLANGTIDY facebook-hte-std::once_flag
117*523fa7a6SAndroid Build Coastguard Worker   static std::once_flag flag;
118*523fa7a6SAndroid Build Coastguard Worker   // @lint-ignore CLANGTIDY facebook-hte-std::call_once
119*523fa7a6SAndroid Build Coastguard Worker   std::call_once(
120*523fa7a6SAndroid Build Coastguard Worker       flag, []() { pthread_atfork(nullptr, nullptr, child_atfork); });
121*523fa7a6SAndroid Build Coastguard Worker   if ET_UNLIKELY (leak_corrupted_threadpool) {
122*523fa7a6SAndroid Build Coastguard Worker     leak_corrupted_threadpool = false;
123*523fa7a6SAndroid Build Coastguard Worker     if (auto leaked = threadpool.release()) {
124*523fa7a6SAndroid Build Coastguard Worker       auto t = leaked->get_thread_count();
125*523fa7a6SAndroid Build Coastguard Worker       threadpool = std::make_unique<ThreadPool>(t);
126*523fa7a6SAndroid Build Coastguard Worker     }
127*523fa7a6SAndroid Build Coastguard Worker   }
128*523fa7a6SAndroid Build Coastguard Worker #endif
129*523fa7a6SAndroid Build Coastguard Worker   return threadpool.get();
130*523fa7a6SAndroid Build Coastguard Worker }
131*523fa7a6SAndroid Build Coastguard Worker 
get_pthreadpool()132*523fa7a6SAndroid Build Coastguard Worker pthreadpool_t get_pthreadpool() {
133*523fa7a6SAndroid Build Coastguard Worker   if (NoThreadPoolGuard::is_enabled()) {
134*523fa7a6SAndroid Build Coastguard Worker     return nullptr;
135*523fa7a6SAndroid Build Coastguard Worker   }
136*523fa7a6SAndroid Build Coastguard Worker   ThreadPool* const threadpool = get_threadpool();
137*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(threadpool, "Failed to acquire an instance of ThreadPool!");
138*523fa7a6SAndroid Build Coastguard Worker   return threadpool->threadpool_.get();
139*523fa7a6SAndroid Build Coastguard Worker }
140*523fa7a6SAndroid Build Coastguard Worker 
141*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch::extension::threadpool
142