xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec_base.h>
5 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
6 namespace at {
7 namespace vec {
8 // See Note [CPU_CAPABILITY namespace]
9 inline namespace CPU_CAPABILITY {
10 
11 template <>
12 class Vectorized<int32_t> {
13  private:
14   union {
15     struct {
16       vint32 _vec0;
17       vint32 _vec1;
18     };
19     struct {
20       vbool32 _vecb0;
21       vbool32 _vecb1;
22     };
23 
24   } __attribute__((__may_alias__));
25 
26  public:
27   using value_type = int32_t;
28   using vec_internal_type = vint32;
29   using vec_internal_mask_type = vbool32;
30   using size_type = int;
size()31   static constexpr size_type size() {
32     return 8;
33   }
Vectorized()34   Vectorized() {}
Vectorized(vint32 v)35   C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {}
Vectorized(vbool32 vmask)36   C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
Vectorized(vint32 v1,vint32 v2)37   C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {}
Vectorized(vbool32 v1,vbool32 v2)38   C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
Vectorized(int32_t scalar)39   C10_ALWAYS_INLINE Vectorized(int32_t scalar)
40       : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
Vectorized(int32_t scalar1,int32_t scalar2,int32_t scalar3,int32_t scalar4,int32_t scalar5,int32_t scalar6,int32_t scalar7,int32_t scalar8)41   C10_ALWAYS_INLINE Vectorized(
42       int32_t scalar1,
43       int32_t scalar2,
44       int32_t scalar3,
45       int32_t scalar4,
46       int32_t scalar5,
47       int32_t scalar6,
48       int32_t scalar7,
49       int32_t scalar8)
50       : _vec0{vint32{scalar1, scalar2, scalar3, scalar4}},
51         _vec1{vint32{scalar5, scalar6, scalar7, scalar8}} {}
vec0()52   C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
53     return _vec0;
54   }
vec1()55   C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
56     return _vec1;
57   }
58 
59   template <uint64_t mask>
60   static std::enable_if_t<mask == 0, Vectorized<int32_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)61   blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
62     return a;
63   }
64 
65   template <uint64_t mask>
66   static std::enable_if_t<(mask & 255) == 255, Vectorized<int32_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)67   blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
68     return b;
69   }
70 
71   template <uint64_t mask>
72   static std::enable_if_t<mask == 15, Vectorized<int32_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)73   blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
74     return {b._vec0, a._vec1};
75   }
76 
77   template <uint64_t mask>
78   static std::enable_if_t<(mask > 0 && mask < 15), Vectorized<int32_t>>
blend(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)79       C10_ALWAYS_INLINE blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
80     constexpr uint32_t g0 = (mask & 1) * 0xffffffff;
81     constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
82     constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
83     constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
84     const vbool32 mask_1st = (vbool32){g0, g1, g2, g3};
85 
86     return {(vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st), a._vec1};
87   }
88 
89   template <uint64_t mask>
90   static std::enable_if_t<
91       (mask > 15 && (mask & 255) != 255 && ((mask & 15) == 15)),
92       Vectorized<int32_t>>
blend(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)93       C10_ALWAYS_INLINE blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
94     constexpr uint32_t mask2 = (mask & 255) >> 4;
95     constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
96     constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
97     constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
98     constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
99 
100     const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
101     // generated masks
102     return {b._vec0, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
103   }
104 
105   template <uint64_t mask>
106   static std::enable_if_t<
107       (mask > 15 && ((mask & 255) != 255) && ((mask & 15) == 0)),
108       Vectorized<int32_t>>
blend(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)109       C10_ALWAYS_INLINE blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
110     constexpr uint32_t mask2 = (mask & 255) >> 4;
111     constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
112     constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
113     constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
114     constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
115 
116     const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
117     // generated masks
118     return {a, (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
119   }
120 
121   template <uint64_t mask>
122   static std::enable_if_t<
123       (mask > 15 && ((mask & 255) != 255) && ((mask & 15) != 0) &&
124        ((mask & 15) != 15)),
125       Vectorized<int32_t>>
blend(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)126       C10_ALWAYS_INLINE blend(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
127     constexpr uint32_t g0 = (mask & 1) * 0xffffffff;
128     constexpr uint32_t g1 = ((mask & 2) >> 1) * 0xffffffff;
129     constexpr uint32_t g2 = ((mask & 4) >> 2) * 0xffffffff;
130     constexpr uint32_t g3 = ((mask & 8) >> 3) * 0xffffffff;
131     constexpr uint32_t mask2 = (mask & 255) >> 4;
132     constexpr uint32_t g0_2 = (mask2 & 1) * 0xffffffff;
133     constexpr uint32_t g1_2 = ((mask2 & 2) >> 1) * 0xffffffff;
134     constexpr uint32_t g2_2 = ((mask2 & 4) >> 2) * 0xffffffff;
135     constexpr uint32_t g3_2 = ((mask2 & 8) >> 3) * 0xffffffff;
136 
137     const vbool32 mask_1st = (vbool32){g0, g1, g2, g3};
138     const vbool32 mask_2nd = (vbool32){g0_2, g1_2, g2_2, g3_2};
139     // generated masks
140     return {
141         (vint32)vec_sel(a._vec0, b._vec0, (vbool32)mask_1st),
142         (vint32)vec_sel(a._vec1, b._vec1, (vbool32)mask_2nd)};
143   }
144 
blendv(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b,const Vectorized<int32_t> & mask)145   static Vectorized<int32_t> C10_ALWAYS_INLINE blendv(
146       const Vectorized<int32_t>& a,
147       const Vectorized<int32_t>& b,
148       const Vectorized<int32_t>& mask) {
149     // the mask used here returned by comparision of vec256
150     // assuming this we can use the same mask directly with vec_sel
151     // warning intel style mask will not work properly
152     return {
153         vec_sel(a._vec0, b._vec0, mask._vecb0),
154         vec_sel(a._vec1, b._vec1, mask._vecb1)};
155   }
156 
157   template <typename step_t>
158   static Vectorized<int32_t> arange(int32_t base = 0.f, step_t step = static_cast<step_t>(1)) {
159     return Vectorized<int32_t>(
160         base,
161         base + step,
162         base + 2 * step,
163         base + 3 * step,
164         base + 4 * step,
165         base + 5 * step,
166         base + 6 * step,
167         base + 7 * step);
168   }
169   static Vectorized<int32_t> set(
170       const Vectorized<int32_t>& a,
171       const Vectorized<int32_t>& b,
172       size_t count = size()) {
173     switch (count) {
174       case 0:
175         return a;
176       case 1:
177         return blend<1>(a, b);
178       case 2:
179         return blend<3>(a, b);
180       case 3:
181         return blend<7>(a, b);
182       case 4:
183         return blend<15>(a, b);
184       case 5:
185         return blend<31>(a, b);
186       case 6:
187         return blend<63>(a, b);
188       case 7:
189         return blend<127>(a, b);
190     }
191 
192     return b;
193   }
194   static Vectorized<value_type> C10_ALWAYS_INLINE
195   loadu(const void* ptr, int count = size()) {
196     if (count == size()) {
197       return {
198           vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
199           vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
200     }
201 
202     __at_align__ value_type tmp_values[size()] = {};
203     std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
204 
205     return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
206   }
207   void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
208     if (count == size()) {
209       vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
210       vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
211     } else if (count > 0) {
212       __at_align__ value_type tmp_values[size()];
213       vec_vsx_st(_vec0, offset0, tmp_values);
214       vec_vsx_st(_vec1, offset16, tmp_values);
215       std::memcpy(
216           ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
217     }
218   }
219   const int32_t& operator[](int idx) const = delete;
220   int32_t& operator[](int idx) = delete;
221 
angle()222   Vectorized<int32_t> angle() const {
223     return blendv(
224       Vectorized<int32_t>(0), Vectorized<int32_t>(c10::pi<int32_t>), *this < Vectorized<int32_t>(0));
225   }
real()226   Vectorized<int32_t> real() const {
227     return *this;
228   }
imag()229   Vectorized<int32_t> imag() const {
230     return Vectorized<int32_t>{0};
231   }
conj()232   Vectorized<int32_t> conj() const {
233     return *this;
234   }
235 
abs()236   Vectorized<int32_t> C10_ALWAYS_INLINE abs() const {
237     return {vec_abs(_vec0), vec_abs(_vec1)};
238   }
239 
neg()240   Vectorized<int32_t> C10_ALWAYS_INLINE neg() const {
241     return {vec_neg(_vec0), vec_neg(_vec1)};
242   }
243 
244   DEFINE_MEMBER_UNARY_OP(operator~, int32_t, vec_not)
245   DEFINE_MEMBER_OP(operator==, int32_t, vec_cmpeq)
246   DEFINE_MEMBER_OP(operator!=, int32_t, vec_cmpne)
247   DEFINE_MEMBER_OP(operator<, int32_t, vec_cmplt)
248   DEFINE_MEMBER_OP(operator<=, int32_t, vec_cmple)
249   DEFINE_MEMBER_OP(operator>, int32_t, vec_cmpgt)
250   DEFINE_MEMBER_OP(operator>=, int32_t, vec_cmpge)
251   DEFINE_MEMBER_OP_AND_ONE(eq, int32_t, vec_cmpeq)
252   DEFINE_MEMBER_OP_AND_ONE(ne, int32_t, vec_cmpne)
253   DEFINE_MEMBER_OP_AND_ONE(lt, int32_t, vec_cmplt)
254   DEFINE_MEMBER_OP_AND_ONE(le, int32_t, vec_cmple)
255   DEFINE_MEMBER_OP_AND_ONE(gt, int32_t, vec_cmpgt)
256   DEFINE_MEMBER_OP_AND_ONE(ge, int32_t, vec_cmpge)
257   DEFINE_MEMBER_OP(operator+, int32_t, vec_add)
258   DEFINE_MEMBER_OP(operator-, int32_t, vec_sub)
259   DEFINE_MEMBER_OP(operator*, int32_t, vec_mul)
260   DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int32_t, /)
261   DEFINE_MEMBER_OP(maximum, int32_t, vec_max)
262   DEFINE_MEMBER_OP(minimum, int32_t, vec_min)
263   DEFINE_MEMBER_OP(operator&, int32_t, vec_and)
264   DEFINE_MEMBER_OP(operator|, int32_t, vec_or)
265   DEFINE_MEMBER_OP(operator^, int32_t, vec_xor)
266 };
267 
268 template <>
269 Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
270                 vuint32 shift_vec0 = reinterpret_cast<vuint32>(b.vec0());
271                 vuint32 shift_vec1 = reinterpret_cast<vuint32>(b.vec1()) ;
272           return Vectorized<int32_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
273 }
274 
275 template <>
276 Vectorized<int32_t> inline operator>>(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
277                 vuint32 shift_vec0 = reinterpret_cast<vuint32>(b.vec0());
278                 vuint32 shift_vec1 = reinterpret_cast<vuint32>(b.vec1()) ;
279           return Vectorized<int32_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
280 }
281 
282 template <>
maximum(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)283 Vectorized<int32_t> inline maximum(
284     const Vectorized<int32_t>& a,
285     const Vectorized<int32_t>& b) {
286   return a.maximum(b);
287 }
288 
289 template <>
minimum(const Vectorized<int32_t> & a,const Vectorized<int32_t> & b)290 Vectorized<int32_t> inline minimum(
291     const Vectorized<int32_t>& a,
292     const Vectorized<int32_t>& b) {
293   return a.minimum(b);
294 }
295 
296 template <>
297 Vectorized<int32_t> C10_ALWAYS_INLINE operator+(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
298   return Vectorized<int32_t>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
299 }
300 
301 template <>
302 Vectorized<int32_t> C10_ALWAYS_INLINE operator-(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
303   return Vectorized<int32_t>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
304 }
305 
306 template <>
307 Vectorized<int32_t> C10_ALWAYS_INLINE operator*(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
308   return Vectorized<int32_t>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
309 }
310 
311 template <>
312 Vectorized<int32_t> C10_ALWAYS_INLINE operator/(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
313   return Vectorized<int32_t>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
314 }
315 
316 template <>
317 Vectorized<int32_t> C10_ALWAYS_INLINE operator&(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
318   return Vectorized<int32_t>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
319 }
320 
321 template <>
322 Vectorized<int32_t> C10_ALWAYS_INLINE operator|(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
323   return Vectorized<int32_t>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
324 }
325 
326 template <>
327 Vectorized<int32_t> C10_ALWAYS_INLINE operator^(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
328   return Vectorized<int32_t>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
329 }
330 
331 } // namespace
332 } // namespace vec
333 } // namespace at
334