// Source: external/pytorch/aten/src/ATen/Parallel-inl.h (AOSP aosp_15_r20, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ParallelGuard.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/SmallVector.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker namespace at {
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker template <class F>
parallel_for(const int64_t begin,const int64_t end,const int64_t grain_size,const F & f)10*da0073e9SAndroid Build Coastguard Worker inline void parallel_for(
11*da0073e9SAndroid Build Coastguard Worker     const int64_t begin,
12*da0073e9SAndroid Build Coastguard Worker     const int64_t end,
13*da0073e9SAndroid Build Coastguard Worker     const int64_t grain_size,
14*da0073e9SAndroid Build Coastguard Worker     const F& f) {
15*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
16*da0073e9SAndroid Build Coastguard Worker   if (begin >= end) {
17*da0073e9SAndroid Build Coastguard Worker     return;
18*da0073e9SAndroid Build Coastguard Worker   }
19*da0073e9SAndroid Build Coastguard Worker 
20*da0073e9SAndroid Build Coastguard Worker #ifdef INTRA_OP_PARALLEL
21*da0073e9SAndroid Build Coastguard Worker   at::internal::lazy_init_num_threads();
22*da0073e9SAndroid Build Coastguard Worker   const auto numiter = end - begin;
23*da0073e9SAndroid Build Coastguard Worker   const bool use_parallel =
24*da0073e9SAndroid Build Coastguard Worker       (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
25*da0073e9SAndroid Build Coastguard Worker        at::get_num_threads() > 1);
26*da0073e9SAndroid Build Coastguard Worker   if (!use_parallel) {
27*da0073e9SAndroid Build Coastguard Worker     internal::ThreadIdGuard tid_guard(0);
28*da0073e9SAndroid Build Coastguard Worker     c10::ParallelGuard guard(true);
29*da0073e9SAndroid Build Coastguard Worker     f(begin, end);
30*da0073e9SAndroid Build Coastguard Worker     return;
31*da0073e9SAndroid Build Coastguard Worker   }
32*da0073e9SAndroid Build Coastguard Worker 
33*da0073e9SAndroid Build Coastguard Worker   internal::invoke_parallel(
34*da0073e9SAndroid Build Coastguard Worker       begin, end, grain_size, [&](int64_t begin, int64_t end) {
35*da0073e9SAndroid Build Coastguard Worker         c10::ParallelGuard guard(true);
36*da0073e9SAndroid Build Coastguard Worker         f(begin, end);
37*da0073e9SAndroid Build Coastguard Worker       });
38*da0073e9SAndroid Build Coastguard Worker #else
39*da0073e9SAndroid Build Coastguard Worker   internal::ThreadIdGuard tid_guard(0);
40*da0073e9SAndroid Build Coastguard Worker   c10::ParallelGuard guard(true);
41*da0073e9SAndroid Build Coastguard Worker   f(begin, end);
42*da0073e9SAndroid Build Coastguard Worker #endif
43*da0073e9SAndroid Build Coastguard Worker }
44*da0073e9SAndroid Build Coastguard Worker 
45*da0073e9SAndroid Build Coastguard Worker template <class scalar_t, class F, class SF>
parallel_reduce(const int64_t begin,const int64_t end,const int64_t grain_size,const scalar_t ident,const F & f,const SF & sf)46*da0073e9SAndroid Build Coastguard Worker inline scalar_t parallel_reduce(
47*da0073e9SAndroid Build Coastguard Worker     const int64_t begin,
48*da0073e9SAndroid Build Coastguard Worker     const int64_t end,
49*da0073e9SAndroid Build Coastguard Worker     const int64_t grain_size,
50*da0073e9SAndroid Build Coastguard Worker     const scalar_t ident,
51*da0073e9SAndroid Build Coastguard Worker     const F& f,
52*da0073e9SAndroid Build Coastguard Worker     const SF& sf) {
53*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(grain_size >= 0);
54*da0073e9SAndroid Build Coastguard Worker   if (begin >= end) {
55*da0073e9SAndroid Build Coastguard Worker     return ident;
56*da0073e9SAndroid Build Coastguard Worker   }
57*da0073e9SAndroid Build Coastguard Worker 
58*da0073e9SAndroid Build Coastguard Worker #ifdef INTRA_OP_PARALLEL
59*da0073e9SAndroid Build Coastguard Worker   at::internal::lazy_init_num_threads();
60*da0073e9SAndroid Build Coastguard Worker   const auto max_threads = at::get_num_threads();
61*da0073e9SAndroid Build Coastguard Worker   const bool use_parallel =
62*da0073e9SAndroid Build Coastguard Worker       ((end - begin) > grain_size && !at::in_parallel_region() &&
63*da0073e9SAndroid Build Coastguard Worker        max_threads > 1);
64*da0073e9SAndroid Build Coastguard Worker   if (!use_parallel) {
65*da0073e9SAndroid Build Coastguard Worker     internal::ThreadIdGuard tid_guard(0);
66*da0073e9SAndroid Build Coastguard Worker     c10::ParallelGuard guard(true);
67*da0073e9SAndroid Build Coastguard Worker     return f(begin, end, ident);
68*da0073e9SAndroid Build Coastguard Worker   }
69*da0073e9SAndroid Build Coastguard Worker 
70*da0073e9SAndroid Build Coastguard Worker   c10::SmallVector<scalar_t, 64> results(max_threads, ident);
71*da0073e9SAndroid Build Coastguard Worker   internal::invoke_parallel(
72*da0073e9SAndroid Build Coastguard Worker       begin,
73*da0073e9SAndroid Build Coastguard Worker       end,
74*da0073e9SAndroid Build Coastguard Worker       grain_size,
75*da0073e9SAndroid Build Coastguard Worker       [&](const int64_t my_begin, const int64_t my_end) {
76*da0073e9SAndroid Build Coastguard Worker         const auto tid = at::get_thread_num();
77*da0073e9SAndroid Build Coastguard Worker         c10::ParallelGuard guard(true);
78*da0073e9SAndroid Build Coastguard Worker         results[tid] = f(my_begin, my_end, ident);
79*da0073e9SAndroid Build Coastguard Worker       });
80*da0073e9SAndroid Build Coastguard Worker 
81*da0073e9SAndroid Build Coastguard Worker   scalar_t result = ident;
82*da0073e9SAndroid Build Coastguard Worker   for (auto partial_result : results) {
83*da0073e9SAndroid Build Coastguard Worker     result = sf(result, partial_result);
84*da0073e9SAndroid Build Coastguard Worker   }
85*da0073e9SAndroid Build Coastguard Worker   return result;
86*da0073e9SAndroid Build Coastguard Worker #else
87*da0073e9SAndroid Build Coastguard Worker   internal::ThreadIdGuard tid_guard(0);
88*da0073e9SAndroid Build Coastguard Worker   c10::ParallelGuard guard(true);
89*da0073e9SAndroid Build Coastguard Worker   return f(begin, end, ident);
90*da0073e9SAndroid Build Coastguard Worker #endif
91*da0073e9SAndroid Build Coastguard Worker }
92*da0073e9SAndroid Build Coastguard Worker 
93*da0073e9SAndroid Build Coastguard Worker } // namespace at
94