xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec_mask.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/vec_base.h>
4 #include <ATen/cpu/vec/vec_n.h>
5 namespace at::vec {
6 inline namespace CPU_CAPABILITY {
7 
8 /**
9  * The `VecMask` class provides a convenient interface for working with
10  * vectorized masks in SIMD operations. It encapsulates a `Vectorized<T, N>`
11  * mask that can be directly usable in masked vectorized operations. It provides
12  * various methods for manipulating and accessing the mask elements:
13  * 1. `from` and `to`: Conversion between a vector of boolean values and a
14  * vectorized mask.
15  * 2. `cast`: Casts the mask to a different base type.
16  * 3. `all_zero`: Checks if all mask elements are zero.
17  * 4. `is_masked`: Checks if a specific element is masked.
18  * 5. `loadu`: Loads data from memory using the mask.
19  * 6. `all_masked`: Checks if all mask elements are masked.
20  *
21  * Some helper template classes are provided to simplify the specialization of
22  * the `VecMask` for the specific CPU arch:
23  * 1. `VecMaskLoad`: Loads data from memory using the mask.
24  * 2. `VecMaskTo`: Converts the mask to boolean.
25  * 3. `VecMaskCast`: Casts the mask to a different base type.
26  *
27  */
28 template <typename T, int N>
29 class VecMask;
30 
31 template <
32     typename data_t,
33     int data_n,
34     typename mask_t,
35     int mask_n,
36     typename Enabled = void>
37 struct VecMaskLoad {
applyVecMaskLoad38   static inline VectorizedN<data_t, data_n> apply(
39       const data_t* ptr,
40       const VecMask<mask_t, mask_n>& vec_mask) {
41     constexpr typename VecMask<mask_t, mask_n>::size_type size =
42         VecMask<mask_t, mask_n>::size();
43     static_assert(VectorizedN<data_t, data_n>::size() >= size);
44     __at_align__ data_t data[size];
45     __at_align__ mask_t mask[size];
46     auto mask_ = VectorizedN<mask_t, mask_n>(vec_mask);
47     mask_.store(mask);
48     for (int i = 0; i < size; i++) {
49       data[i] = mask[i] ? ptr[i] : static_cast<data_t>(0);
50     }
51     return VectorizedN<data_t, data_n>::loadu(data, size);
52   }
53 };
54 
55 template <
56     typename dst_t,
57     int dst_n,
58     typename src_t,
59     int src_n,
60     typename Enabled = void>
61 struct VecMaskTo {
applyVecMaskTo62   static inline VecMask<dst_t, dst_n> apply(
63       const VecMask<src_t, src_n>& vec_mask) {
64     auto zeros = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(0));
65     auto ones = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(1));
66     return VectorizedN<dst_t, dst_n>::blendv(
67         zeros, ones, vec_mask.template cast<dst_t, dst_n>());
68   }
69 };
70 
71 template <typename dst_t, int dst_n, typename src_t, int src_n, typename Enabled = void>
72 struct VecMaskCast {
applyVecMaskCast73   static inline VecMask<dst_t, dst_n> apply(
74       const VecMask<src_t, src_n>& vec_mask) {
75     return VecMask<dst_t, dst_n>::from(VectorizedN<src_t, src_n>(vec_mask));
76   }
77 };
78 
79 template <typename T, int N>
80 struct VecMaskCast<T, N, T, N> {
81   static inline VecMask<T, N> apply(const VecMask<T, N>& vec_mask) {
82     return vec_mask;
83   }
84 };
85 
86 template <typename T, int N>
87 struct VecMaskCheck {
88   static inline bool all_zero(const VectorizedN<T, N>& vec_mask) {
89     __at_align__ T mask[VectorizedN<T, N>::size()];
90     vec_mask.store(mask);
91     return std::all_of(
92         mask, mask + VectorizedN<T, N>::size(), [](T m) { return m == static_cast<T>(0); });
93   }
94 
95   static inline bool all_masked(const VectorizedN<T, N>& vec_mask) {
96     __at_align__ T mask[VectorizedN<T, N>::size()];
97     vec_mask.store(mask);
98     return std::all_of(
99         mask, mask + VectorizedN<T, N>::size(), [](T m) { return m != static_cast<T>(0); });
100   }
101 
102   static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) {
103     __at_align__ T mask[VectorizedN<T, N>::size()];
104     vec_mask.store(mask);
105     return mask[i] != static_cast<T>(0);
106   }
107 };
108 
109 template <typename T, int N>
110 class VecMask {
111  public:
112   using size_type = int;
113   static constexpr size_type size() {
114     return VectorizedN<T, N>::size();
115   }
116 
117  private:
118   VectorizedN<T, N> mask_;
119 
120  public:
121   VecMask() : mask_(static_cast<T>(0)) {}
122   VecMask(const VectorizedN<T, N>& mask) : mask_(mask) {}
123 
124   template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
125   VecMask(const Vectorized<T>& mask) : mask_(mask) {}
126 
127   template <typename U, int L>
128   static VecMask<T, N> from(const VectorizedN<U, L>& b_vec) {
129     __at_align__ U b_buf[size()];
130     if constexpr (size() >= VectorizedN<U, L>::size()) {
131       b_vec.store(b_buf);
132       for (int i = VectorizedN<U, L>::size(); i < size(); i++) {
133         b_buf[i] = static_cast<U>(0);
134       }
135     } else {
136       b_vec.store(b_buf, size());
137     }
138     return from(b_buf);
139   }
140 
141   template <typename U>
142   static VecMask<T, N> from(U b) {
143     using int_t = int_same_size_t<T>;
144     T mask = b ? c10::bit_cast<T>((int_t)(~(int_t)0)) : (T)0;
145     return VectorizedN<T, N>(mask);
146   }
147 
148   template <typename U>
149   static VecMask<T, N> from(U* b) {
150     using int_t = int_same_size_t<T>;
151     __at_align__ T mask[size()];
152 #ifndef __msvc_cl__
153 #pragma unroll
154 #endif
155     for (int i = 0; i < size(); i++) {
156       *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0;
157     }
158     return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask));
159   }
160 
161   static VecMask<T, N> blendv(
162     const VecMask<T, N>& c,
163     const VecMask<T, N>& b,
164     const VecMask<T, N>& a) {
165     VectorizedN<T, N> result = VectorizedN<T, N>::blendv(
166       VectorizedN<T, N>(c),
167       VectorizedN<T, N>(b),
168       VectorizedN<T, N>(a));
169     return result;
170   }
171 
172   static VecMask<T, N> set(
173       const VecMask<T, N>& a,
174       const VecMask<T, N>& b,
175       int64_t count = size()) {
176     VectorizedN<T, N> result = VectorizedN<T, N>::set(
177       VectorizedN<T, N>(a),
178       VectorizedN<T, N>(b),
179       count);
180     return result;
181   }
182 
183   void store(bool* b, int count = size()) {
184     constexpr int L = (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1)/ Vectorized<bool>::size();
185     auto res = this->to<bool, L>();
186     res.store(b, count);
187     return;
188   }
189 
190   template <typename U, int L, std::enable_if_t<L >= 2, int> = 0>
191   inline VectorizedN<U, L> to() const {
192     return VecMaskTo<U, L, T, N>::apply(*this);
193   }
194 
195   template <typename U, int L, std::enable_if_t<L == 1, int> = 0>
196   inline Vectorized<U> to() const {
197     return VecMaskTo<U, L, T, N>::apply(*this);
198   }
199 
200   template <typename U, int L>
201   inline VecMask<U, L> cast() const {
202     return VecMaskCast<U, L, T, N>::apply(*this);
203   }
204 
205   inline bool all_zero() const {
206     return VecMaskCheck<T, N>::all_zero(mask_);
207   }
208 
209   inline bool all_masked() const {
210     return VecMaskCheck<T, N>::all_masked(mask_);
211   }
212 
213   inline bool is_masked(int i) const {
214     return VecMaskCheck<T, N>::is_masked(mask_, i);
215   }
216 
217   inline operator VectorizedN<T, N>() const {
218     return mask_;
219   }
220 
221   template <int L = N, typename std::enable_if_t<L == 1, int> = 0>
222   inline operator Vectorized<T>() const {
223     return mask_[0];
224   }
225 
226   inline Vectorized<T> operator[](int i) const {
227     return mask_[i];
228   }
229 
230   template <
231       typename U,
232       int L,
233       std::enable_if_t<L >= 2 && VectorizedN<U, L>::size() >= size(), int> = 0>
234   VectorizedN<U, L> loadu(const U* ptr) const {
235     return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
236   }
237 
238   template <
239       typename U,
240       int L,
241       std::enable_if_t<L == 1 && Vectorized<U>::size() >= size(), int> = 0>
242   Vectorized<U> loadu(const U* ptr) const {
243     return VecMaskLoad<U, L, T, N>::apply(ptr, *this);
244   }
245 };
246 
247 #define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op)         \
248   template <typename T, int N>                      \
249   inline VecMask<T, N> op(const VecMask<T, N>& a) { \
250     return op(VectorizedN<T, N>(a));                \
251   }
252 
253 #define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op)                                  \
254   template <                                                                  \
255       typename T,                                                             \
256       int N,                                                                  \
257       typename V,                                                             \
258       int M,                                                                  \
259       std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \
260           0>                                                                  \
261   inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) {   \
262     return op(                                                                \
263         VectorizedN<T, N>(a), VectorizedN<T, N>(b.template cast<T, N>()));    \
264   }
265 
266 #define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR)                  \
267   template <                                                                  \
268       typename T,                                                             \
269       int N,                                                                  \
270       typename V,                                                             \
271       int M,                                                                  \
272       std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \
273           0>                                                                  \
274   inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) {   \
275     return EXPR;                                                              \
276   }
277 
278 VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~)
279 VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&)
280 VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|)
281 VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^)
282 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b)
283 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b)
284 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b))
285 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b))
286 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b))
287 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b))
288 
289 #undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL
290 #undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL
291 #undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL
292 
293 } // namespace CPU_CAPABILITY
294 } // namespace at::vec
295