xref: /aosp_15_r20/external/pytorch/aten/src/ATen/Parallel-inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 #include <c10/util/ParallelGuard.h>
5 #include <c10/util/SmallVector.h>
6 
7 namespace at {
8 
9 template <class F>
parallel_for(const int64_t begin,const int64_t end,const int64_t grain_size,const F & f)10 inline void parallel_for(
11     const int64_t begin,
12     const int64_t end,
13     const int64_t grain_size,
14     const F& f) {
15   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
16   if (begin >= end) {
17     return;
18   }
19 
20 #ifdef INTRA_OP_PARALLEL
21   at::internal::lazy_init_num_threads();
22   const auto numiter = end - begin;
23   const bool use_parallel =
24       (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
25        at::get_num_threads() > 1);
26   if (!use_parallel) {
27     internal::ThreadIdGuard tid_guard(0);
28     c10::ParallelGuard guard(true);
29     f(begin, end);
30     return;
31   }
32 
33   internal::invoke_parallel(
34       begin, end, grain_size, [&](int64_t begin, int64_t end) {
35         c10::ParallelGuard guard(true);
36         f(begin, end);
37       });
38 #else
39   internal::ThreadIdGuard tid_guard(0);
40   c10::ParallelGuard guard(true);
41   f(begin, end);
42 #endif
43 }
44 
45 template <class scalar_t, class F, class SF>
parallel_reduce(const int64_t begin,const int64_t end,const int64_t grain_size,const scalar_t ident,const F & f,const SF & sf)46 inline scalar_t parallel_reduce(
47     const int64_t begin,
48     const int64_t end,
49     const int64_t grain_size,
50     const scalar_t ident,
51     const F& f,
52     const SF& sf) {
53   TORCH_CHECK(grain_size >= 0);
54   if (begin >= end) {
55     return ident;
56   }
57 
58 #ifdef INTRA_OP_PARALLEL
59   at::internal::lazy_init_num_threads();
60   const auto max_threads = at::get_num_threads();
61   const bool use_parallel =
62       ((end - begin) > grain_size && !at::in_parallel_region() &&
63        max_threads > 1);
64   if (!use_parallel) {
65     internal::ThreadIdGuard tid_guard(0);
66     c10::ParallelGuard guard(true);
67     return f(begin, end, ident);
68   }
69 
70   c10::SmallVector<scalar_t, 64> results(max_threads, ident);
71   internal::invoke_parallel(
72       begin,
73       end,
74       grain_size,
75       [&](const int64_t my_begin, const int64_t my_end) {
76         const auto tid = at::get_thread_num();
77         c10::ParallelGuard guard(true);
78         results[tid] = f(my_begin, my_end, ident);
79       });
80 
81   scalar_t result = ident;
82   for (auto partial_result : results) {
83     result = sf(result, partial_result);
84   }
85   return result;
86 #else
87   internal::ThreadIdGuard tid_guard(0);
88   c10::ParallelGuard guard(true);
89   return f(begin, end, ident);
90 #endif
91 }
92 
93 } // namespace at
94