1 #pragma once 2 3 #include <ATen/cpu/vec/vec_base.h> 4 #include <ATen/cpu/vec/vec_n.h> 5 namespace at::vec { 6 inline namespace CPU_CAPABILITY { 7 8 /** 9 * The `VecMask` class provides a convenient interface for working with 10 * vectorized masks in SIMD operations. It encapsulates a `Vectorized<T, N>` 11 * mask that can be directly usable in masked vectorized operations. It provides 12 * various methods for manipulating and accessing the mask elements: 13 * 1. `from` and `to`: Conversion between a vector of boolean values and a 14 * vectorized mask. 15 * 2. `cast`: Casts the mask to a different base type. 16 * 3. `all_zero`: Checks if all mask elements are zero. 17 * 4. `is_masked`: Checks if a specific element is masked. 18 * 5. `loadu`: Loads data from memory using the mask. 19 * 6. `all_masked`: Checks if all mask elements are masked. 20 * 21 * Some helper template classes are provided to simplify the specialization of 22 * the `VecMask` for the specific CPU arch: 23 * 1. `VecMaskLoad`: Loads data from memory using the mask. 24 * 2. `VecMaskTo`: Converts the mask to boolean. 25 * 3. `VecMaskCast`: Casts the mask to a different base type. 26 * 27 */ 28 template <typename T, int N> 29 class VecMask; 30 31 template < 32 typename data_t, 33 int data_n, 34 typename mask_t, 35 int mask_n, 36 typename Enabled = void> 37 struct VecMaskLoad { applyVecMaskLoad38 static inline VectorizedN<data_t, data_n> apply( 39 const data_t* ptr, 40 const VecMask<mask_t, mask_n>& vec_mask) { 41 constexpr typename VecMask<mask_t, mask_n>::size_type size = 42 VecMask<mask_t, mask_n>::size(); 43 static_assert(VectorizedN<data_t, data_n>::size() >= size); 44 __at_align__ data_t data[size]; 45 __at_align__ mask_t mask[size]; 46 auto mask_ = VectorizedN<mask_t, mask_n>(vec_mask); 47 mask_.store(mask); 48 for (int i = 0; i < size; i++) { 49 data[i] = mask[i] ? ptr[i] : static_cast<data_t>(0); 50 } 51 return VectorizedN<data_t, data_n>::loadu(data, size); 52 } 53 }; 54 55 template < 56 typename dst_t, 57 int dst_n, 58 typename src_t, 59 int src_n, 60 typename Enabled = void> 61 struct VecMaskTo { applyVecMaskTo62 static inline VecMask<dst_t, dst_n> apply( 63 const VecMask<src_t, src_n>& vec_mask) { 64 auto zeros = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(0)); 65 auto ones = VectorizedN<dst_t, dst_n>(static_cast<dst_t>(1)); 66 return VectorizedN<dst_t, dst_n>::blendv( 67 zeros, ones, vec_mask.template cast<dst_t, dst_n>()); 68 } 69 }; 70 71 template <typename dst_t, int dst_n, typename src_t, int src_n, typename Enabled = void> 72 struct VecMaskCast { applyVecMaskCast73 static inline VecMask<dst_t, dst_n> apply( 74 const VecMask<src_t, src_n>& vec_mask) { 75 return VecMask<dst_t, dst_n>::from(VectorizedN<src_t, src_n>(vec_mask)); 76 } 77 }; 78 79 template <typename T, int N> 80 struct VecMaskCast<T, N, T, N> { 81 static inline VecMask<T, N> apply(const VecMask<T, N>& vec_mask) { 82 return vec_mask; 83 } 84 }; 85 86 template <typename T, int N> 87 struct VecMaskCheck { 88 static inline bool all_zero(const VectorizedN<T, N>& vec_mask) { 89 __at_align__ T mask[VectorizedN<T, N>::size()]; 90 vec_mask.store(mask); 91 return std::all_of( 92 mask, mask + VectorizedN<T, N>::size(), [](T m) { return m == static_cast<T>(0); }); 93 } 94 95 static inline bool all_masked(const VectorizedN<T, N>& vec_mask) { 96 __at_align__ T mask[VectorizedN<T, N>::size()]; 97 vec_mask.store(mask); 98 return std::all_of( 99 mask, mask + VectorizedN<T, N>::size(), [](T m) { return m != static_cast<T>(0); }); 100 } 101 102 static inline bool is_masked(const VectorizedN<T, N>& vec_mask, int i) { 103 __at_align__ T mask[VectorizedN<T, N>::size()]; 104 vec_mask.store(mask); 105 return mask[i] != static_cast<T>(0); 106 } 107 }; 108 109 template <typename T, int N> 110 class VecMask { 111 public: 112 using size_type = int; 113 static constexpr size_type size() { 114 return VectorizedN<T, N>::size(); 115 } 116 117 private: 118 VectorizedN<T, N> mask_; 119 120 public: 121 VecMask() : mask_(static_cast<T>(0)) {} 122 VecMask(const VectorizedN<T, N>& mask) : mask_(mask) {} 123 124 template <int L = N, typename std::enable_if_t<L == 1, int> = 0> 125 VecMask(const Vectorized<T>& mask) : mask_(mask) {} 126 127 template <typename U, int L> 128 static VecMask<T, N> from(const VectorizedN<U, L>& b_vec) { 129 __at_align__ U b_buf[size()]; 130 if constexpr (size() >= VectorizedN<U, L>::size()) { 131 b_vec.store(b_buf); 132 for (int i = VectorizedN<U, L>::size(); i < size(); i++) { 133 b_buf[i] = static_cast<U>(0); 134 } 135 } else { 136 b_vec.store(b_buf, size()); 137 } 138 return from(b_buf); 139 } 140 141 template <typename U> 142 static VecMask<T, N> from(U b) { 143 using int_t = int_same_size_t<T>; 144 T mask = b ? c10::bit_cast<T>((int_t)(~(int_t)0)) : (T)0; 145 return VectorizedN<T, N>(mask); 146 } 147 148 template <typename U> 149 static VecMask<T, N> from(U* b) { 150 using int_t = int_same_size_t<T>; 151 __at_align__ T mask[size()]; 152 #ifndef __msvc_cl__ 153 #pragma unroll 154 #endif 155 for (int i = 0; i < size(); i++) { 156 *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0; 157 } 158 return VectorizedN<T, N>(VectorizedN<T, N>::loadu(mask)); 159 } 160 161 static VecMask<T, N> blendv( 162 const VecMask<T, N>& c, 163 const VecMask<T, N>& b, 164 const VecMask<T, N>& a) { 165 VectorizedN<T, N> result = VectorizedN<T, N>::blendv( 166 VectorizedN<T, N>(c), 167 VectorizedN<T, N>(b), 168 VectorizedN<T, N>(a)); 169 return result; 170 } 171 172 static VecMask<T, N> set( 173 const VecMask<T, N>& a, 174 const VecMask<T, N>& b, 175 int64_t count = size()) { 176 VectorizedN<T, N> result = VectorizedN<T, N>::set( 177 VectorizedN<T, N>(a), 178 VectorizedN<T, N>(b), 179 count); 180 return result; 181 } 182 183 void store(bool* b, int count = size()) { 184 constexpr int L = (VectorizedN<T, N>::size() + Vectorized<bool>::size() - 1)/ Vectorized<bool>::size(); 185 auto res = this->to<bool, L>(); 186 res.store(b, count); 187 return; 188 } 189 190 template <typename U, int L, std::enable_if_t<L >= 2, int> = 0> 191 inline VectorizedN<U, L> to() const { 192 return VecMaskTo<U, L, T, N>::apply(*this); 193 } 194 195 template <typename U, int L, std::enable_if_t<L == 1, int> = 0> 196 inline Vectorized<U> to() const { 197 return VecMaskTo<U, L, T, N>::apply(*this); 198 } 199 200 template <typename U, int L> 201 inline VecMask<U, L> cast() const { 202 return VecMaskCast<U, L, T, N>::apply(*this); 203 } 204 205 inline bool all_zero() const { 206 return VecMaskCheck<T, N>::all_zero(mask_); 207 } 208 209 inline bool all_masked() const { 210 return VecMaskCheck<T, N>::all_masked(mask_); 211 } 212 213 inline bool is_masked(int i) const { 214 return VecMaskCheck<T, N>::is_masked(mask_, i); 215 } 216 217 inline operator VectorizedN<T, N>() const { 218 return mask_; 219 } 220 221 template <int L = N, typename std::enable_if_t<L == 1, int> = 0> 222 inline operator Vectorized<T>() const { 223 return mask_[0]; 224 } 225 226 inline Vectorized<T> operator[](int i) const { 227 return mask_[i]; 228 } 229 230 template < 231 typename U, 232 int L, 233 std::enable_if_t<L >= 2 && VectorizedN<U, L>::size() >= size(), int> = 0> 234 VectorizedN<U, L> loadu(const U* ptr) const { 235 return VecMaskLoad<U, L, T, N>::apply(ptr, *this); 236 } 237 238 template < 239 typename U, 240 int L, 241 std::enable_if_t<L == 1 && Vectorized<U>::size() >= size(), int> = 0> 242 Vectorized<U> loadu(const U* ptr) const { 243 return VecMaskLoad<U, L, T, N>::apply(ptr, *this); 244 } 245 }; 246 247 #define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op) \ 248 template <typename T, int N> \ 249 inline VecMask<T, N> op(const VecMask<T, N>& a) { \ 250 return op(VectorizedN<T, N>(a)); \ 251 } 252 253 #define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op) \ 254 template < \ 255 typename T, \ 256 int N, \ 257 typename V, \ 258 int M, \ 259 std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \ 260 0> \ 261 inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) { \ 262 return op( \ 263 VectorizedN<T, N>(a), VectorizedN<T, N>(b.template cast<T, N>())); \ 264 } 265 266 #define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR) \ 267 template < \ 268 typename T, \ 269 int N, \ 270 typename V, \ 271 int M, \ 272 std::enable_if_t<VecMask<T, N>::size() == VecMask<V, M>::size(), int> = \ 273 0> \ 274 inline VecMask<T, N> op(const VecMask<T, N>& a, const VecMask<V, M>& b) { \ 275 return EXPR; \ 276 } 277 278 VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~) 279 VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&) 280 VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|) 281 VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^) 282 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b) 283 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b) 284 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b)) 285 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b)) 286 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b)) 287 VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b)) 288 289 #undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL 290 #undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL 291 #undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL 292 293 } // namespace CPU_CAPABILITY 294 } // namespace at::vec 295