/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include namespace executorch::extension::threadpool { #if !(defined(WIN32)) namespace { // After fork, the child process inherits the data-structures of the parent // process' thread-pool, but since those threads don't exist, the thread-pool // is corrupt. It's leaked in order to prevent segfaults. // Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302 bool leak_corrupted_threadpool = false; void child_atfork() { leak_corrupted_threadpool = true; } } // namespace #endif ThreadPool::ThreadPool(size_t thread_count) : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {} size_t ThreadPool::get_thread_count() const { std::lock_guard lock{mutex_}; ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!"); return pthreadpool_get_threads_count(threadpool_.get()); } bool ThreadPool::_unsafe_reset_threadpool(uint32_t new_thread_count) { // No need to do anything if the count is same or 0 if (new_thread_count == get_thread_count() || new_thread_count == 0) { return true; } std::lock_guard lock{mutex_}; threadpool_.reset(pthreadpool_create(new_thread_count)); return true; } void ThreadPool::run( const std::function& fn, const size_t range) { // Run on same thread if NoThreadPoolGuard guard is enabled if (NoThreadPoolGuard::is_enabled()) { for (size_t i = 0; i < range; ++i) { fn(i); } return; } std::lock_guard lock{mutex_}; ET_CHECK_MSG(!NoThreadPoolGuard::is_enabled(), "Inside a threadpool guard!"); ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!"); struct Context final { const std::function& fn; } context{ fn, }; pthreadpool_parallelize_1d( threadpool_.get(), // Note: pthreadpool_parallelize_1d() is a blocking function. The // function pointer to this lambda passed on to // pthreadpool_parallelize_1d() cannot go out of scope until // pthreadpool_parallelize_1d() returns. [](void* const context, const size_t item) { NoThreadPoolGuard guard; reinterpret_cast(context)->fn(item); }, &context, range, 0u); } // get_threadpool is not thread safe due to leak_corrupted_threadpool // Make this part threadsafe: TODO(kimishpatel) ThreadPool* get_threadpool() { ET_CHECK_MSG(cpuinfo_initialize(), "cpuinfo initialization failed"); int num_threads = cpuinfo_get_processors_count(); /* * For llvm-tsan, holding limit for the number of locks for a single thread * is 63 (because of comparison < 64 instead of <=). pthreadpool's worst * case is the number of threads in a pool. So we want to limit the threadpool * size to 64 when running with tsan. However, sometimes it is tricky to * detect if we are running under tsan, for now capping the default * threadcount to the tsan limit unconditionally. */ constexpr int tsan_thread_limit = 63; num_threads = std::min(num_threads, tsan_thread_limit); static auto threadpool = std::make_unique(num_threads); // Inheriting from old threadpool to get around segfault issue // commented above at child_atfork #if !(defined(WIN32)) // @lint-ignore CLANGTIDY facebook-hte-std::once_flag static std::once_flag flag; // @lint-ignore CLANGTIDY facebook-hte-std::call_once std::call_once( flag, []() { pthread_atfork(nullptr, nullptr, child_atfork); }); if ET_UNLIKELY (leak_corrupted_threadpool) { leak_corrupted_threadpool = false; if (auto leaked = threadpool.release()) { auto t = leaked->get_thread_count(); threadpool = std::make_unique(t); } } #endif return threadpool.get(); } pthreadpool_t get_pthreadpool() { if (NoThreadPoolGuard::is_enabled()) { return nullptr; } ThreadPool* const threadpool = get_threadpool(); ET_CHECK_MSG(threadpool, "Failed to acquire an instance of ThreadPool!"); return threadpool->threadpool_.get(); } } // namespace executorch::extension::threadpool