xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/moments_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <algorithm>
#include <array>
#include <cstring>
#include <type_traits>
#include <utility>

#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/SmallVector.h>
#include <c10/util/irange.h>
13 
14 namespace at::native {
15 inline namespace CPU_CAPABILITY {
16 
17 template<typename T> using opmath_t = at::opmath_type<T>;
18 
19 constexpr int64_t kChunkSize = 16;
20 
// Merges the moments of a second sample set into running moments, in place.
//
// Inputs describe the set being added: m0_add (count), m1_add (mean) and
// m2_add (sum of squared deviations from its mean). The running moments
// (m0, m1, m2) use the same convention and are updated with the
// parallel-variance combination rule (Chan et al.):
//   n     = m0 + m0_add
//   m1'   = m1 + (m0_add / n) * (m1_add - m1)
//   m2'   = m2 + m2_add + (m1_add - m1)^2 * (m0_add / n) * m0
// When both counts are zero the weight m0_add / n is taken as 0 so that no
// division by zero happens.
template <typename T>
void AddMoments(
    int64_t m0_add,
    const T& m1_add,
    const T& m2_add,
    int64_t& m0,
    T& m1,
    T& m2) {
  const int64_t total = m0 + m0_add;
  T weight = static_cast<T>(0);
  if (total != 0) {
    weight = static_cast<T>(m0_add) / static_cast<T>(total);
  }
  const T diff = m1_add - m1;
  m1 += weight * diff;
  // m0 still holds the count *before* the merge here, as the formula needs.
  m2 += m2_add + diff * diff * weight * static_cast<T>(m0);
  m0 = total;
}
36 
37 template <typename T>
AddMomentsVec(int64_t m0_add,const vec::Vectorized<T> & m1_add,const vec::Vectorized<T> & m2_add,int64_t & m0,vec::Vectorized<T> & m1,vec::Vectorized<T> & m2)38 C10_ALWAYS_INLINE void AddMomentsVec(
39     int64_t m0_add,
40     const vec::Vectorized<T>& m1_add,
41     const vec::Vectorized<T>& m2_add,
42     int64_t& m0,
43     vec::Vectorized<T>& m1,
44     vec::Vectorized<T>& m2) {
45   using Vec = vec::Vectorized<T>;
46   const int64_t n = m0 + m0_add;
47   const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
48   const Vec c_vec(c);
49   const Vec delta = m1_add - m1;
50   m1 += c_vec * delta;
51   m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
52   m0 = n;
53 }
54 
55 template <typename T>
56 inline std::enable_if_t<std::is_same_v<T, opmath_t<T>>, void>
UpdateMomentsVec(int64_t m0,const T * X_ptr,const std::array<vec::Vectorized<opmath_t<T>>,kChunkSize> & c_vecs,int64_t & m0_stk0,vec::Vectorized<opmath_t<T>> & m1_stk0,vec::Vectorized<opmath_t<T>> & m2_stk0)57 UpdateMomentsVec(
58     int64_t m0,
59     const T* X_ptr,
60     const std::array<vec::Vectorized<opmath_t<T>>, kChunkSize>& c_vecs,
61     int64_t& m0_stk0,
62     vec::Vectorized<opmath_t<T>>& m1_stk0,
63     vec::Vectorized<opmath_t<T>>& m2_stk0) {
64   using Vec = vec::Vectorized<opmath_t<T>>;
65   Vec m1_vec(0);
66   Vec m2_vec(0);
67   for (const auto j : c10::irange(m0)) {
68     const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
69     const Vec delta_vec = x_vec - m1_vec;
70     m1_vec += delta_vec * c_vecs[j];
71     m2_vec += delta_vec * (x_vec - m1_vec);
72   }
73   AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
74 }
75 
76 // each bfloat16/half vector will be converted to two float vectors,
77 // and accumulated successively on m1_stk0/m2_stk0.
78 template <typename T>
79 inline std::enable_if_t<!std::is_same_v<T, at::opmath_type<T>>, void>
UpdateMomentsVec(int64_t m0,const T * X_ptr,const std::array<vec::Vectorized<at::opmath_type<T>>,kChunkSize> & c_vecs,int64_t & m0_stk0,vec::Vectorized<at::opmath_type<T>> & m1_stk0,vec::Vectorized<at::opmath_type<T>> & m2_stk0)80 UpdateMomentsVec(
81     int64_t m0,
82     const T* X_ptr,
83     const std::array<vec::Vectorized<at::opmath_type<T>>, kChunkSize>& c_vecs,
84     int64_t& m0_stk0,
85     vec::Vectorized<at::opmath_type<T>>& m1_stk0,
86     vec::Vectorized<at::opmath_type<T>>& m2_stk0) {
87   using Vec = vec::Vectorized<T>;
88   using fVec = vec::Vectorized<at::opmath_type<T>>;
89   fVec m1_fvec0(0), m1_fvec1(0);
90   fVec m2_fvec0(0), m2_fvec1(0);
91   for (const auto j : c10::irange(m0)) {
92     const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size());
93     auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
94     const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
95     const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
96     m1_fvec0 += delta_fvec0 * c_vecs[j];
97     m1_fvec1 += delta_fvec1 * c_vecs[j];
98     m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0);
99     m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1);
100   }
101   AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0);
102   AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0);
103 }
104 
105 // Compute rowwise moments by Welford algorithm and cascade sum to improve
106 // numerical stability.
107 // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
108 // https://en.wikipedia.org/wiki/Pairwise_summation
109 template <typename T, int64_t kMaxDepth>
110 std::pair<opmath_t<T>, opmath_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
111   using math_t = opmath_t<T>;
112 
113   constexpr int64_t kVecSize = vec::Vectorized<T>::size();
114   constexpr int64_t kAccVecSize = vec::Vectorized<math_t>::size();
115   const int64_t n = N / kVecSize;
116   const int64_t m = divup(n, kChunkSize);
117   const int64_t depth = utils::CeilLog2(m);
118 
119   using Vec = vec::Vectorized<math_t>;
120   const Vec kZeroVec(math_t(0));
121   c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
122   c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
123   c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);
124 
125   for (const auto i : c10::irange(m)) {
126     const T* X_ptr = X + i * kChunkSize * kVecSize;
127     const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
128     static std::array<Vec, kChunkSize> c_vecs = ([]() {
129       std::array<Vec, kChunkSize> result;
130       for (const auto i : c10::irange(kChunkSize)) {
131         result[i] = Vec(math_t(1) / static_cast<math_t>(i + 1));
132       }
133       return result;
134     })();
135     UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]);
136 
137     int64_t mask = i + 1;
138     for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
139       AddMomentsVec(
140           m0_stk[j - 1],
141           m1_stk[j - 1],
142           m2_stk[j - 1],
143           m0_stk[j],
144           m1_stk[j],
145           m2_stk[j]);
146       m0_stk[j - 1] = 0;
147       m1_stk[j - 1] = kZeroVec;
148       m2_stk[j - 1] = kZeroVec;
149       mask >>= 1;
150     }
151   }
152   for (const auto i : c10::irange(1, depth)) {
153     AddMomentsVec(
154         m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
155   }
156 
157   std::array<math_t, kAccVecSize> m1_arr{};
158   std::array<math_t, kAccVecSize> m2_arr{};
159   m1_stk[0].store(m1_arr.data());
160   m2_stk[0].store(m2_arr.data());
161 
162   int64_t m0 = 0;
163   math_t m1 = 0;
164   math_t m2 = 0;
165   for (int64_t i = n * kVecSize; i < N; ++i) {
166     math_t x = static_cast<math_t>(X[i]);
167     const math_t delta = x - m1;
168     ++m0;
169     m1 += delta / static_cast<math_t>(m0);
170     m2 += delta * (x - m1);
171   }
172   // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result
173   int64_t m0_add = n * kVecSize / kAccVecSize;
174   for (const auto i : c10::irange(kAccVecSize)) {
175     AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
176   }
177 
178   return std::make_pair(m1, m2 / static_cast<math_t>(N - ddof));
179 }
180 
181 template <typename T>
182 std::pair<opmath_t<T>, opmath_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
183   using Vec = vec::Vectorized<T>;
184   constexpr int64_t kVecSize = Vec::size();
185   const int64_t n = N / kVecSize;
186   const int64_t m = divup(n, kChunkSize);
187   const int64_t depth = utils::CeilLog2(m);
188   if (depth <= 4) {
189     return RowwiseMomentsImpl<T, 4>(X, N, ddof);
190   } else if (depth <= 8) {
191     return RowwiseMomentsImpl<T, 8>(X, N, ddof);
192   } else if (depth <= 16) {
193     return RowwiseMomentsImpl<T, 16>(X, N, ddof);
194   } else if (depth <= 32) {
195     return RowwiseMomentsImpl<T, 32>(X, N, ddof);
196   } else {
197     return RowwiseMomentsImpl<T, 64>(X, N, ddof);
198   }
199 }
200 
201 } // namespace CPU_CAPABILITY
202 } // namespace at::native
203