#include <ATen/Config.h>
#if AT_PARALLEL_NATIVE
#include <ATen/Parallel.h>
#include <ATen/ParallelFuture.h>
#include <ATen/PTThreadPool.h>

#ifndef C10_MOBILE
#include <c10/core/thread_pool.h>
#include <c10/util/irange.h>
#else
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif // C10_MOBILE

#include <atomic>
#include <utility>

#ifdef _OPENMP
#include <omp.h>
#endif

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif

namespace at {
namespace {
// used with _set_in_parallel_region to mark master thread
// as in parallel region while executing parallel primitives
thread_local bool in_parallel_region_ = false;

// thread number (task_id) set by parallel primitive
thread_local int thread_num_ = 0;

void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}

} // namespace (anonymous)

namespace internal {

void set_thread_num(int thread_num) {
  thread_num_ = thread_num;
}

} // namespace internal

namespace {

void _unset_thread_num() {
  thread_num_ = 0;
}

#ifndef C10_MOBILE

const int NOT_SET = -1;
const int CONSUMED = -2;

// Number of threads set by the user
// NOT_SET -> positive value -> CONSUMED
// or
// NOT_SET -> CONSUMED
// Meaning:
//  - NOT_SET - pool not initialized, user value is not set
//  - positive value - pool not initialized, user value set
//  - CONSUMED - pool is initialized
std::atomic<int> num_intraop_threads{NOT_SET};

int _num_pool_threads(int nthreads) {
  if (nthreads == NOT_SET) {
    nthreads = intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads > 0);
  }
  // minus one because of the master thread
  return nthreads - 1;
}

TaskThreadPoolBase& _get_intraop_pool() {
  static std::shared_ptr<TaskThreadPoolBase> pool =
      ThreadPoolRegistry()->Create(
          "C10",
          /* device_id */ 0,
          /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)),
          /* create_new */ true); // create a separate thread pool for intra-op
  return *pool;
}

#endif // C10_MOBILE

// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
// `fn` will be called with params: (thread_pool_task_id, task_id).
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
  for (const auto i : c10::irange(1, range)) {
    _get_intraop_pool().run([fn, i]() { fn((int)i, i); });
  }
  // Run the first task on the current thread directly.
  fn(0, 0);
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");

  pool->run(
      // PThreadPool::run() is blocking.  A std::function [const] reference to
      // this lambda cannot go out of scope before PThreadPool::run() returns.
      [&fn](const size_t task_id) {
        fn(0 /* unused */, task_id);
      },
      range);
#endif // C10_MOBILE
}

// RAII guard helps to support in_parallel_region() and get_thread_num() API.
struct ParallelRegionGuard {
  ParallelRegionGuard(int task_id) {
    internal::set_thread_num(task_id);
    _set_in_parallel_region(true);
  }

  ~ParallelRegionGuard() {
    _set_in_parallel_region(false);
    _unset_thread_num();
  }
};

} // namespace

namespace internal {

inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
    int64_t begin, int64_t end, int64_t grain_size) {
  if ((end - begin) < grain_size) {
    return std::make_tuple(1, std::max((int64_t)0, end - begin));
  }
  // Choose number of tasks based on grain size and number of threads.
  size_t chunk_size = divup((end - begin), get_num_threads());
  // Make sure each task is at least grain_size size.
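  // Illustrative example (hypothetical values, not from the original source):
  // with begin = 0, end = 100, grain_size = 10 and 8 threads,
  // divup(100, 8) = 13, which is already >= grain_size, so chunk_size
  // stays 13 and num_tasks = divup(100, 13) = 8.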
  chunk_size = std::max((size_t)grain_size, chunk_size);
  size_t num_tasks = divup((end - begin), chunk_size);
  return std::make_tuple(num_tasks, chunk_size);
}

void invoke_parallel(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const std::function<void(int64_t, int64_t)>& f) {
  at::internal::lazy_init_num_threads();

  size_t num_tasks = 0, chunk_size = 0;
  std::tie(num_tasks, chunk_size) =
      internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);

  struct {
    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;
    std::mutex mutex;
    std::atomic_size_t remaining{0};
    std::condition_variable cv;
  } state;

  auto task = [f, &state, begin, end, chunk_size]
      (int /* unused */, size_t task_id) {
    int64_t local_start = begin + task_id * chunk_size;
    if (local_start < end) {
      int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
      try {
        ParallelRegionGuard guard(task_id);
        f(local_start, local_end);
      } catch (...) {
        if (!state.err_flag.test_and_set()) {
          state.eptr = std::current_exception();
        }
      }
    }
    {
      std::unique_lock<std::mutex> lk(state.mutex);
      if (--state.remaining == 0) {
        state.cv.notify_one();
      }
    }
  };
  state.remaining = num_tasks;
  _run_with_pool(std::move(task), num_tasks);

  // Wait for all tasks to finish.
  {
    std::unique_lock<std::mutex> lk(state.mutex);
    if (state.remaining != 0) {
      state.cv.wait(lk);
    }
  }
  if (state.eptr) {
    std::rethrow_exception(state.eptr);
  }
}

} // namespace internal

void init_num_threads() {
#ifdef _OPENMP
  omp_set_num_threads(1);
#endif

#if AT_MKL_ENABLED()
  mkl_set_num_threads(1);
#endif

#ifdef C10_MOBILE
  caffe2::pthreadpool();
#endif
}

void set_num_threads(int nthreads) {
#ifndef C10_MOBILE
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  int no_value = NOT_SET;
  if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
    // num_intraop_threads either stores a positive integer or CONSUMED,
    // check that requested size is the same as the current one
    int stored_nthreads = num_intraop_threads.load();
    if (stored_nthreads <= 0) {
      // plus one because of master thread
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      stored_nthreads = _get_intraop_pool().size() + 1;
    }
    if (stored_nthreads != nthreads) {
      TORCH_WARN(
          "Cannot set number of intraop threads "
          "after parallel work has started or after set_num_threads call "
          "when using native parallel backend");
    }
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  pool->set_thread_count(nthreads);
#endif // C10_MOBILE
}

int get_num_threads() {
  at::internal::lazy_init_num_threads();
#ifndef C10_MOBILE
  // not initializing pool unnecessarily,
  // because pool cannot be resized after initialization
  int nthreads = num_intraop_threads.load();
  if (nthreads > 0) {
    return nthreads;
  } else if (nthreads == NOT_SET) {
    return intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads == CONSUMED);
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    return _get_intraop_pool().size() + 1;
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count();
#endif // C10_MOBILE
}

int get_thread_num() {
  return thread_num_;
}

bool in_parallel_region() {
#ifndef C10_MOBILE
  return in_parallel_region_ || (
      num_intraop_threads.load() == CONSUMED &&
      // Needed as intraop_launch() doesn't set in_parallel_region().
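      // (checking pool membership covers tasks submitted via intraop_launch(),
      // which run on pool threads without a ParallelRegionGuard)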
      _get_intraop_pool().inThreadPool());
#else
  return in_parallel_region_;
#endif // C10_MOBILE
}

void intraop_launch(std::function<void()> func) {
#ifndef C10_MOBILE
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(std::move(func));
  } else {
    // execute inline if we're in parallel region
    func();
  }
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  func();
#endif // C10_MOBILE
}

c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
    std::function<void()> func) {
#ifndef C10_MOBILE
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(
        [func, future]() {
          func();
          future->markCompleted();
        });
  } else {
    func();
    future->markCompleted();
  }
  return future;
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<c10::NoneType>());
  func();
  future->markCompleted();
  return future;
#endif // C10_MOBILE
}

} // namespace at
#endif
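// Usage sketch (illustrative only; at::parallel_for and at::internal::GRAIN_SIZE
// are declared in ATen/Parallel.h, not in this file): when the native backend
// is enabled (AT_PARALLEL_NATIVE), at::parallel_for splits [begin, end) into
// chunks via calc_num_tasks_and_chunk_size() and dispatches them through
// invoke_parallel() above, e.g.:
//
//   at::parallel_for(0, n, at::internal::GRAIN_SIZE,
//       [&](int64_t begin, int64_t end) {
//         for (const auto i : c10::irange(begin, end)) {
//           out[i] = op(in[i]);  // hypothetical element-wise work on the chunk
//         }
//       });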