xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec_base.h>
5 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
6 namespace at {
7 namespace vec {
8 // See Note [CPU_CAPABILITY namespace]
9 inline namespace CPU_CAPABILITY {
10 
11 template <>
12 class Vectorized<int16_t> {
13  private:
14   union {
15     struct {
16       vint16 _vec0;
17       vint16 _vec1;
18     };
19     struct {
20       vbool16 _vecb0;
21       vbool16 _vecb1;
22     };
23 
24   } __attribute__((__may_alias__));
25 
26  public:
27   using value_type = int16_t;
28   using vec_internal_type = vint16;
29   using vec_internal_mask_type = vbool16;
30   using size_type = int;
size()31   static constexpr size_type size() {
32     return 16;
33   }
Vectorized()34   Vectorized() {}
Vectorized(vint16 v)35   C10_ALWAYS_INLINE Vectorized(vint16 v) : _vec0{v}, _vec1{v} {}
Vectorized(vbool16 vmask)36   C10_ALWAYS_INLINE Vectorized(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
Vectorized(vint16 v1,vint16 v2)37   C10_ALWAYS_INLINE Vectorized(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {}
Vectorized(vbool16 v1,vbool16 v2)38   C10_ALWAYS_INLINE Vectorized(vbool16 v1, vbool16 v2) : _vecb0{v1}, _vecb1{v2} {}
Vectorized(int16_t scalar)39   C10_ALWAYS_INLINE Vectorized(int16_t scalar)
40       : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
41 
Vectorized(int16_t scalar1,int16_t scalar2,int16_t scalar3,int16_t scalar4,int16_t scalar5,int16_t scalar6,int16_t scalar7,int16_t scalar8,int16_t scalar9,int16_t scalar10,int16_t scalar11,int16_t scalar12,int16_t scalar13,int16_t scalar14,int16_t scalar15,int16_t scalar16)42   C10_ALWAYS_INLINE Vectorized(
43       int16_t scalar1,
44       int16_t scalar2,
45       int16_t scalar3,
46       int16_t scalar4,
47       int16_t scalar5,
48       int16_t scalar6,
49       int16_t scalar7,
50       int16_t scalar8,
51       int16_t scalar9,
52       int16_t scalar10,
53       int16_t scalar11,
54       int16_t scalar12,
55       int16_t scalar13,
56       int16_t scalar14,
57       int16_t scalar15,
58       int16_t scalar16)
59       : _vec0{vint16{
60             scalar1,
61             scalar2,
62             scalar3,
63             scalar4,
64             scalar5,
65             scalar6,
66             scalar7,
67             scalar8}},
68         _vec1{vint16{
69             scalar9,
70             scalar10,
71             scalar11,
72             scalar12,
73             scalar13,
74             scalar14,
75             scalar15,
76             scalar16}} {}
vec0()77   C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
78     return _vec0;
79   }
vec1()80   C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
81     return _vec1;
82   }
83 
84   template <uint64_t mask>
85   static std::enable_if_t<mask == 0, Vectorized<int16_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)86   blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
87     return a;
88   }
89 
90   template <uint64_t mask>
91   static std::enable_if_t<(mask & 65535) == 65535, Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)92       C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
93     return b;
94   }
95 
96   template <uint64_t mask>
97   static std::enable_if_t<mask == 255, Vectorized<int16_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)98   blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
99     return {b._vec0, a._vec1};
100   }
101 
102   template <uint64_t mask>
103   static std::enable_if_t<(mask > 0 && mask < 255), Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)104       C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
105     constexpr int16_t g0 = (mask & 1) * 0xffff;
106     constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff;
107     constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff;
108     constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff;
109     constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff;
110     constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff;
111     constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff;
112     constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff;
113     const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7};
114 
115     return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1};
116   }
117 
118   template <uint64_t mask>
119   static std::enable_if_t<
120       (mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)),
121       Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)122       C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
123     constexpr int16_t g0_2 = (mask & 1) * 0xffff;
124     constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
125     constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
126     constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
127     constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
128     constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
129     constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
130     constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
131 
132     const vint16 mask_2nd =
133         vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
134     // generated masks
135     return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
136   }
137 
138   template <uint64_t mask>
139   static std::enable_if_t<
140       (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)),
141       Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)142       C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
143     constexpr int16_t mask2 = (mask & 65535) >> 16;
144     constexpr int16_t g0_2 = (mask & 1) * 0xffff;
145     constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
146     constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
147     constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
148     constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
149     constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
150     constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
151     constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
152 
153     const vint16 mask_2nd =
154         vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
155     // generated masks
156     return {a, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
157   }
158 
159   template <uint64_t mask>
160   static std::enable_if_t<
161       (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) &&
162        ((mask & 255) != 255)),
163       Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)164       C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
165     constexpr int16_t g0 = (mask & 1) * 0xffff;
166     constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff;
167     constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff;
168     constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff;
169     constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff;
170     constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff;
171     constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff;
172     constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff;
173     constexpr int16_t mask2 = (mask & 65535) >> 16;
174     constexpr int16_t g0_2 = (mask & 1) * 0xffff;
175     constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
176     constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
177     constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
178     constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
179     constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
180     constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
181     constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
182 
183     const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7};
184     const vint16 mask_2nd =
185         vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
186     // generated masks
187     return {
188         (vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st),
189         (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
190   }
191 
blendv(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b,const Vectorized<int16_t> & mask)192   static Vectorized<int16_t> C10_ALWAYS_INLINE blendv(
193       const Vectorized<int16_t>& a,
194       const Vectorized<int16_t>& b,
195       const Vectorized<int16_t>& mask) {
196     // the mask used here returned by comparision of vec256
197     // assuming this we can use the same mask directly with vec_sel
198     // warning intel style mask will not work properly
199     return {
200         vec_sel(a._vec0, b._vec0, mask._vecb0),
201         vec_sel(a._vec1, b._vec1, mask._vecb1)};
202   }
203 
204   template <typename step_t>
205   static Vectorized<int16_t> arange(int16_t base = 0, step_t step = static_cast<step_t>(1)) {
206     return Vectorized<int16_t>(
207         base,
208         base + step,
209         base + 2 * step,
210         base + 3 * step,
211         base + 4 * step,
212         base + 5 * step,
213         base + 6 * step,
214         base + 7 * step,
215         base + 8 * step,
216         base + 9 * step,
217         base + 10 * step,
218         base + 11 * step,
219         base + 12 * step,
220         base + 13 * step,
221         base + 14 * step,
222         base + 15 * step);
223   }
224   static Vectorized<int16_t> set(
225       const Vectorized<int16_t>& a,
226       const Vectorized<int16_t>& b,
227       size_t count = size()) {
228     switch (count) {
229       case 0:
230         return a;
231       case 1:
232         return blend<1>(a, b);
233       case 2:
234         return blend<3>(a, b);
235       case 3:
236         return blend<7>(a, b);
237       case 4:
238         return blend<15>(a, b);
239       case 5:
240         return blend<31>(a, b);
241       case 6:
242         return blend<63>(a, b);
243       case 7:
244         return blend<127>(a, b);
245       case 8:
246         return blend<255>(a, b);
247       case 9:
248         return blend<511>(a, b);
249       case 10:
250         return blend<1023>(a, b);
251       case 11:
252         return blend<2047>(a, b);
253       case 12:
254         return blend<4095>(a, b);
255       case 13:
256         return blend<8191>(a, b);
257       case 14:
258         return blend<16383>(a, b);
259       case 15:
260         return blend<32767>(a, b);
261     }
262     return b;
263   }
264   static Vectorized<value_type> C10_ALWAYS_INLINE
265   loadu(const void* ptr, int count = size()) {
266     if (count == size()) {
267       return {
268           vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
269           vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
270     }
271 
272     __at_align__ value_type tmp_values[size()] = {};
273     std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
274 
275     return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
276   }
277   void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
278     if (count == size()) {
279       vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
280       vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
281     } else if (count > 0) {
282       __at_align__ value_type tmp_values[size()];
283       vec_vsx_st(_vec0, offset0, tmp_values);
284       vec_vsx_st(_vec1, offset16, tmp_values);
285       std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
286     }
287   }
288   const int16_t& operator[](int idx) const = delete;
289   int16_t& operator[](int idx) = delete;
290 
angle()291   Vectorized<int16_t> angle() const {
292     return blendv(
293       Vectorized<int16_t>(0), Vectorized<int16_t>(c10::pi<int16_t>), *this < Vectorized<int16_t>(0));
294   }
real()295   Vectorized<int16_t> real() const {
296     return *this;
297   }
imag()298   Vectorized<int16_t> imag() const {
299     return Vectorized<int16_t>{0};
300   }
conj()301   Vectorized<int16_t> conj() const {
302     return *this;
303   }
304 
abs()305   Vectorized<int16_t> C10_ALWAYS_INLINE abs() const {
306     return {vec_abs(_vec0), vec_abs(_vec1)};
307   }
308 
neg()309   Vectorized<int16_t> C10_ALWAYS_INLINE neg() const {
310     return {vec_neg(_vec0), vec_neg(_vec1)};
311   }
312 
313   DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not)
314   DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq)
315   DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne)
316   DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt)
317   DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple)
318   DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt)
319   DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge)
320   DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq)
321   DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne)
322   DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt)
323   DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple)
324   DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt)
325   DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge)
326   DEFINE_MEMBER_OP(operator+, int16_t, vec_add)
327   DEFINE_MEMBER_OP(operator-, int16_t, vec_sub)
328   DEFINE_MEMBER_OP(operator*, int16_t, vec_mul)
329   DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /)
330   DEFINE_MEMBER_OP(maximum, int16_t, vec_max)
331   DEFINE_MEMBER_OP(minimum, int16_t, vec_min)
332   DEFINE_MEMBER_OP(operator&, int16_t, vec_and)
333   DEFINE_MEMBER_OP(operator|, int16_t, vec_or)
334   DEFINE_MEMBER_OP(operator^, int16_t, vec_xor)
335 };
336 
337 template <>
338 Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
339                vuint16 shift_vec0 = reinterpret_cast<vuint16>(b.vec0());
340                vuint16 shift_vec1 = reinterpret_cast<vuint16>(b.vec1());
341          return Vectorized<int16_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
342 }
343 
344 template <>
345 Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
346                vuint16 shift_vec0 = reinterpret_cast<vuint16>(b.vec0());
347                vuint16 shift_vec1 = reinterpret_cast<vuint16>(b.vec1()) ;
348          return Vectorized<int16_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
349 }
350 
351 template <>
maximum(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)352 Vectorized<int16_t> inline maximum(
353     const Vectorized<int16_t>& a,
354     const Vectorized<int16_t>& b) {
355   return a.maximum(b);
356 }
357 
358 template <>
minimum(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)359 Vectorized<int16_t> inline minimum(
360     const Vectorized<int16_t>& a,
361     const Vectorized<int16_t>& b) {
362   return a.minimum(b);
363 }
364 
365 template <>
366 Vectorized<int16_t> C10_ALWAYS_INLINE operator+(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
367   return Vectorized<int16_t>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
368 }
369 
370 template <>
371 Vectorized<int16_t> C10_ALWAYS_INLINE operator-(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
372   return Vectorized<int16_t>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
373 }
374 
375 template <>
376 Vectorized<int16_t> C10_ALWAYS_INLINE operator*(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
377   return Vectorized<int16_t>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
378 }
379 
380 template <>
381 Vectorized<int16_t> C10_ALWAYS_INLINE operator/(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
382   return Vectorized<int16_t>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
383 }
384 
385 template <>
386 Vectorized<int16_t> C10_ALWAYS_INLINE operator&(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
387   return Vectorized<int16_t>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
388 }
389 
390 template <>
391 Vectorized<int16_t> C10_ALWAYS_INLINE operator|(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
392   return Vectorized<int16_t>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
393 }
394 
395 template <>
396 Vectorized<int16_t> C10_ALWAYS_INLINE operator^(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
397   return Vectorized<int16_t>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
398 }
399 
400 } // namespace
401 } // namespace vec
402 } // namespace at
403