1 #pragma once
2
3 #include <array>
4 #include <cstring>
5 #include <utility>
6
7 #include <ATen/Parallel.h>
8 #include <ATen/OpMathType.h>
9 #include <ATen/cpu/vec/vec.h>
10 #include <ATen/native/cpu/utils.h>
11 #include <c10/util/SmallVector.h>
12 #include <c10/util/irange.h>
13
14 namespace at::native {
15 inline namespace CPU_CAPABILITY {
16
17 template<typename T> using opmath_t = at::opmath_type<T>;
18
// Maximum number of vectors folded into the running moments by a single
// UpdateMomentsVec call; also the length of the reciprocal table (c_vecs)
// passed to it.
// Declared `inline` so that every translation unit including this header
// shares one definition, even where the constant is bound to a reference.
inline constexpr int64_t kChunkSize = 16;
20
// Merges the moments of one set of samples into a running set of moments.
//
//   m0_add / m1_add / m2_add: count, mean and sum of squared deviations from
//                             the mean of the set being merged in.
//   m0 / m1 / m2:             running count, mean and sum of squared
//                             deviations; updated in place.
//
// When both counts are zero the mixing weight is taken as 0, so no division
// by zero happens and the running mean is left untouched.
template <typename T>
void AddMoments(
    int64_t m0_add,
    const T& m1_add,
    const T& m2_add,
    int64_t& m0,
    T& m1,
    T& m2) {
  const int64_t total = m0 + m0_add;
  // Fraction of the merged population contributed by the incoming set.
  T weight = static_cast<T>(0);
  if (total != 0) {
    weight = static_cast<T>(m0_add) / static_cast<T>(total);
  }
  const T diff = m1_add - m1;
  m1 += weight * diff;
  // Uses the count from before the merge; m0 is only overwritten below.
  m2 += m2_add + diff * diff * weight * static_cast<T>(m0);
  m0 = total;
}
36
37 template <typename T>
AddMomentsVec(int64_t m0_add,const vec::Vectorized<T> & m1_add,const vec::Vectorized<T> & m2_add,int64_t & m0,vec::Vectorized<T> & m1,vec::Vectorized<T> & m2)38 C10_ALWAYS_INLINE void AddMomentsVec(
39 int64_t m0_add,
40 const vec::Vectorized<T>& m1_add,
41 const vec::Vectorized<T>& m2_add,
42 int64_t& m0,
43 vec::Vectorized<T>& m1,
44 vec::Vectorized<T>& m2) {
45 using Vec = vec::Vectorized<T>;
46 const int64_t n = m0 + m0_add;
47 const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
48 const Vec c_vec(c);
49 const Vec delta = m1_add - m1;
50 m1 += c_vec * delta;
51 m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
52 m0 = n;
53 }
54
55 template <typename T>
56 inline std::enable_if_t<std::is_same_v<T, opmath_t<T>>, void>
UpdateMomentsVec(int64_t m0,const T * X_ptr,const std::array<vec::Vectorized<opmath_t<T>>,kChunkSize> & c_vecs,int64_t & m0_stk0,vec::Vectorized<opmath_t<T>> & m1_stk0,vec::Vectorized<opmath_t<T>> & m2_stk0)57 UpdateMomentsVec(
58 int64_t m0,
59 const T* X_ptr,
60 const std::array<vec::Vectorized<opmath_t<T>>, kChunkSize>& c_vecs,
61 int64_t& m0_stk0,
62 vec::Vectorized<opmath_t<T>>& m1_stk0,
63 vec::Vectorized<opmath_t<T>>& m2_stk0) {
64 using Vec = vec::Vectorized<opmath_t<T>>;
65 Vec m1_vec(0);
66 Vec m2_vec(0);
67 for (const auto j : c10::irange(m0)) {
68 const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
69 const Vec delta_vec = x_vec - m1_vec;
70 m1_vec += delta_vec * c_vecs[j];
71 m2_vec += delta_vec * (x_vec - m1_vec);
72 }
73 AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
74 }
75
76 // each bfloat16/half vector will be converted to two float vectors,
77 // and accumulated successively on m1_stk0/m2_stk0.
78 template <typename T>
79 inline std::enable_if_t<!std::is_same_v<T, at::opmath_type<T>>, void>
UpdateMomentsVec(int64_t m0,const T * X_ptr,const std::array<vec::Vectorized<at::opmath_type<T>>,kChunkSize> & c_vecs,int64_t & m0_stk0,vec::Vectorized<at::opmath_type<T>> & m1_stk0,vec::Vectorized<at::opmath_type<T>> & m2_stk0)80 UpdateMomentsVec(
81 int64_t m0,
82 const T* X_ptr,
83 const std::array<vec::Vectorized<at::opmath_type<T>>, kChunkSize>& c_vecs,
84 int64_t& m0_stk0,
85 vec::Vectorized<at::opmath_type<T>>& m1_stk0,
86 vec::Vectorized<at::opmath_type<T>>& m2_stk0) {
87 using Vec = vec::Vectorized<T>;
88 using fVec = vec::Vectorized<at::opmath_type<T>>;
89 fVec m1_fvec0(0), m1_fvec1(0);
90 fVec m2_fvec0(0), m2_fvec1(0);
91 for (const auto j : c10::irange(m0)) {
92 const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size());
93 auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
94 const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
95 const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
96 m1_fvec0 += delta_fvec0 * c_vecs[j];
97 m1_fvec1 += delta_fvec1 * c_vecs[j];
98 m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0);
99 m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1);
100 }
101 AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0);
102 AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0);
103 }
104
105 // Compute rowwise moments by Welford algorithm and cascade sum to improve
106 // numerical stability.
107 // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
108 // https://en.wikipedia.org/wiki/Pairwise_summation
109 template <typename T, int64_t kMaxDepth>
110 std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
111 using math_t = opmath_t<T>;
112
113 constexpr int64_t kVecSize = vec::Vectorized<T>::size();
114 constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
115 const int64_t n = N / kVecSize;
116 const int64_t m = divup(n, kChunkSize);
117 const int64_t depth = utils::CeilLog2(m);
118
119 using Vec = vec::Vectorized<math_t>;
120 const Vec kZeroVec(math_t(0));
121 c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
122 c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
123 c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);
124
125 for (const auto i : c10::irange(m)) {
126 const T* X_ptr = X + i * kChunkSize * kVecSize;
127 const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
128 static std::array<Vec, kChunkSize> c_vecs = ([]() {
129 std::array<Vec, kChunkSize> result;
130 for (const auto i : c10::irange(kChunkSize)) {
131 result[i] = Vec(math_t(1) / static_cast<math_t>(i + 1));
132 }
133 return result;
134 })();
135 UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]);
136
137 int64_t mask = i + 1;
138 for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
139 AddMomentsVec(
140 m0_stk[j - 1],
141 m1_stk[j - 1],
142 m2_stk[j - 1],
143 m0_stk[j],
144 m1_stk[j],
145 m2_stk[j]);
146 m0_stk[j - 1] = 0;
147 m1_stk[j - 1] = kZeroVec;
148 m2_stk[j - 1] = kZeroVec;
149 mask >>= 1;
150 }
151 }
152 for (const auto i : c10::irange(1, depth)) {
153 AddMomentsVec(
154 m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
155 }
156
157 std::array<math_t, kAccVecSize> m1_arr{};
158 std::array<math_t, kAccVecSize> m2_arr{};
159 m1_stk[0].store(m1_arr.data());
160 m2_stk[0].store(m2_arr.data());
161
162 int64_t m0 = 0;
163 math_t m1 = 0;
164 math_t m2 = 0;
165 for (int64_t i = n * kVecSize; i < N; ++i) {
166 math_t x = static_cast<math_t>(X[i]);
167 const math_t delta = x - m1;
168 ++m0;
169 m1 += delta / static_cast<math_t>(m0);
170 m2 += delta * (x - m1);
171 }
172 // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result
173 int64_t m0_add = n * kVecSize / kAccVecSize;
174 for (const auto i : c10::irange(kAccVecSize)) {
175 AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
176 }
177
178 return std::make_pair(m1, m2 / static_cast<math_t>(N - ddof));
179 }
180
181 template <typename T>
182 std::pair<opmath_t<T>, opmath_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
183 using Vec = vec::Vectorized<T>;
184 constexpr int64_t kVecSize = Vec::size();
185 const int64_t n = N / kVecSize;
186 const int64_t m = divup(n, kChunkSize);
187 const int64_t depth = utils::CeilLog2(m);
188 if (depth <= 4) {
189 return RowwiseMomentsImpl<T, 4>(X, N, ddof);
190 } else if (depth <= 8) {
191 return RowwiseMomentsImpl<T, 8>(X, N, ddof);
192 } else if (depth <= 16) {
193 return RowwiseMomentsImpl<T, 16>(X, N, ddof);
194 } else if (depth <= 32) {
195 return RowwiseMomentsImpl<T, 32>(X, N, ddof);
196 } else {
197 return RowwiseMomentsImpl<T, 64>(X, N, ddof);
198 }
199 }
200
201 } // namespace CPU_CAPABILITY
202 } // namespace at::native
203