1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ParallelGuard.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/SmallVector.h>
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker namespace at {
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker template <class F>
parallel_for(const int64_t begin,const int64_t end,const int64_t grain_size,const F & f)10*da0073e9SAndroid Build Coastguard Worker inline void parallel_for(
11*da0073e9SAndroid Build Coastguard Worker const int64_t begin,
12*da0073e9SAndroid Build Coastguard Worker const int64_t end,
13*da0073e9SAndroid Build Coastguard Worker const int64_t grain_size,
14*da0073e9SAndroid Build Coastguard Worker const F& f) {
15*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
16*da0073e9SAndroid Build Coastguard Worker if (begin >= end) {
17*da0073e9SAndroid Build Coastguard Worker return;
18*da0073e9SAndroid Build Coastguard Worker }
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker #ifdef INTRA_OP_PARALLEL
21*da0073e9SAndroid Build Coastguard Worker at::internal::lazy_init_num_threads();
22*da0073e9SAndroid Build Coastguard Worker const auto numiter = end - begin;
23*da0073e9SAndroid Build Coastguard Worker const bool use_parallel =
24*da0073e9SAndroid Build Coastguard Worker (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
25*da0073e9SAndroid Build Coastguard Worker at::get_num_threads() > 1);
26*da0073e9SAndroid Build Coastguard Worker if (!use_parallel) {
27*da0073e9SAndroid Build Coastguard Worker internal::ThreadIdGuard tid_guard(0);
28*da0073e9SAndroid Build Coastguard Worker c10::ParallelGuard guard(true);
29*da0073e9SAndroid Build Coastguard Worker f(begin, end);
30*da0073e9SAndroid Build Coastguard Worker return;
31*da0073e9SAndroid Build Coastguard Worker }
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker internal::invoke_parallel(
34*da0073e9SAndroid Build Coastguard Worker begin, end, grain_size, [&](int64_t begin, int64_t end) {
35*da0073e9SAndroid Build Coastguard Worker c10::ParallelGuard guard(true);
36*da0073e9SAndroid Build Coastguard Worker f(begin, end);
37*da0073e9SAndroid Build Coastguard Worker });
38*da0073e9SAndroid Build Coastguard Worker #else
39*da0073e9SAndroid Build Coastguard Worker internal::ThreadIdGuard tid_guard(0);
40*da0073e9SAndroid Build Coastguard Worker c10::ParallelGuard guard(true);
41*da0073e9SAndroid Build Coastguard Worker f(begin, end);
42*da0073e9SAndroid Build Coastguard Worker #endif
43*da0073e9SAndroid Build Coastguard Worker }
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker template <class scalar_t, class F, class SF>
parallel_reduce(const int64_t begin,const int64_t end,const int64_t grain_size,const scalar_t ident,const F & f,const SF & sf)46*da0073e9SAndroid Build Coastguard Worker inline scalar_t parallel_reduce(
47*da0073e9SAndroid Build Coastguard Worker const int64_t begin,
48*da0073e9SAndroid Build Coastguard Worker const int64_t end,
49*da0073e9SAndroid Build Coastguard Worker const int64_t grain_size,
50*da0073e9SAndroid Build Coastguard Worker const scalar_t ident,
51*da0073e9SAndroid Build Coastguard Worker const F& f,
52*da0073e9SAndroid Build Coastguard Worker const SF& sf) {
53*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(grain_size >= 0);
54*da0073e9SAndroid Build Coastguard Worker if (begin >= end) {
55*da0073e9SAndroid Build Coastguard Worker return ident;
56*da0073e9SAndroid Build Coastguard Worker }
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker #ifdef INTRA_OP_PARALLEL
59*da0073e9SAndroid Build Coastguard Worker at::internal::lazy_init_num_threads();
60*da0073e9SAndroid Build Coastguard Worker const auto max_threads = at::get_num_threads();
61*da0073e9SAndroid Build Coastguard Worker const bool use_parallel =
62*da0073e9SAndroid Build Coastguard Worker ((end - begin) > grain_size && !at::in_parallel_region() &&
63*da0073e9SAndroid Build Coastguard Worker max_threads > 1);
64*da0073e9SAndroid Build Coastguard Worker if (!use_parallel) {
65*da0073e9SAndroid Build Coastguard Worker internal::ThreadIdGuard tid_guard(0);
66*da0073e9SAndroid Build Coastguard Worker c10::ParallelGuard guard(true);
67*da0073e9SAndroid Build Coastguard Worker return f(begin, end, ident);
68*da0073e9SAndroid Build Coastguard Worker }
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker c10::SmallVector<scalar_t, 64> results(max_threads, ident);
71*da0073e9SAndroid Build Coastguard Worker internal::invoke_parallel(
72*da0073e9SAndroid Build Coastguard Worker begin,
73*da0073e9SAndroid Build Coastguard Worker end,
74*da0073e9SAndroid Build Coastguard Worker grain_size,
75*da0073e9SAndroid Build Coastguard Worker [&](const int64_t my_begin, const int64_t my_end) {
76*da0073e9SAndroid Build Coastguard Worker const auto tid = at::get_thread_num();
77*da0073e9SAndroid Build Coastguard Worker c10::ParallelGuard guard(true);
78*da0073e9SAndroid Build Coastguard Worker results[tid] = f(my_begin, my_end, ident);
79*da0073e9SAndroid Build Coastguard Worker });
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker scalar_t result = ident;
82*da0073e9SAndroid Build Coastguard Worker for (auto partial_result : results) {
83*da0073e9SAndroid Build Coastguard Worker result = sf(result, partial_result);
84*da0073e9SAndroid Build Coastguard Worker }
85*da0073e9SAndroid Build Coastguard Worker return result;
86*da0073e9SAndroid Build Coastguard Worker #else
87*da0073e9SAndroid Build Coastguard Worker internal::ThreadIdGuard tid_guard(0);
88*da0073e9SAndroid Build Coastguard Worker c10::ParallelGuard guard(true);
89*da0073e9SAndroid Build Coastguard Worker return f(begin, end, ident);
90*da0073e9SAndroid Build Coastguard Worker #endif
91*da0073e9SAndroid Build Coastguard Worker }
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker } // namespace at
94