1 #pragma once
2
3 #include <c10/util/Exception.h>
4 #include <c10/util/ParallelGuard.h>
5 #include <c10/util/SmallVector.h>
6
7 namespace at {
8
9 template <class F>
parallel_for(const int64_t begin,const int64_t end,const int64_t grain_size,const F & f)10 inline void parallel_for(
11 const int64_t begin,
12 const int64_t end,
13 const int64_t grain_size,
14 const F& f) {
15 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
16 if (begin >= end) {
17 return;
18 }
19
20 #ifdef INTRA_OP_PARALLEL
21 at::internal::lazy_init_num_threads();
22 const auto numiter = end - begin;
23 const bool use_parallel =
24 (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
25 at::get_num_threads() > 1);
26 if (!use_parallel) {
27 internal::ThreadIdGuard tid_guard(0);
28 c10::ParallelGuard guard(true);
29 f(begin, end);
30 return;
31 }
32
33 internal::invoke_parallel(
34 begin, end, grain_size, [&](int64_t begin, int64_t end) {
35 c10::ParallelGuard guard(true);
36 f(begin, end);
37 });
38 #else
39 internal::ThreadIdGuard tid_guard(0);
40 c10::ParallelGuard guard(true);
41 f(begin, end);
42 #endif
43 }
44
45 template <class scalar_t, class F, class SF>
parallel_reduce(const int64_t begin,const int64_t end,const int64_t grain_size,const scalar_t ident,const F & f,const SF & sf)46 inline scalar_t parallel_reduce(
47 const int64_t begin,
48 const int64_t end,
49 const int64_t grain_size,
50 const scalar_t ident,
51 const F& f,
52 const SF& sf) {
53 TORCH_CHECK(grain_size >= 0);
54 if (begin >= end) {
55 return ident;
56 }
57
58 #ifdef INTRA_OP_PARALLEL
59 at::internal::lazy_init_num_threads();
60 const auto max_threads = at::get_num_threads();
61 const bool use_parallel =
62 ((end - begin) > grain_size && !at::in_parallel_region() &&
63 max_threads > 1);
64 if (!use_parallel) {
65 internal::ThreadIdGuard tid_guard(0);
66 c10::ParallelGuard guard(true);
67 return f(begin, end, ident);
68 }
69
70 c10::SmallVector<scalar_t, 64> results(max_threads, ident);
71 internal::invoke_parallel(
72 begin,
73 end,
74 grain_size,
75 [&](const int64_t my_begin, const int64_t my_end) {
76 const auto tid = at::get_thread_num();
77 c10::ParallelGuard guard(true);
78 results[tid] = f(my_begin, my_end, ident);
79 });
80
81 scalar_t result = ident;
82 for (auto partial_result : results) {
83 result = sf(result, partial_result);
84 }
85 return result;
86 #else
87 internal::ThreadIdGuard tid_guard(0);
88 c10::ParallelGuard guard(true);
89 return f(begin, end, ident);
90 #endif
91 }
92
93 } // namespace at
94