// aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h
#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

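// Vectorized<int64_t> for PPC64 VSX: four 64-bit lanes held as two 128-bit
// vectors (_vec0 carries lanes 0-1, _vec1 carries lanes 2-3).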
template <>
class Vectorized<int64_t> {
 private:
  union {
    struct {
      vint64 _vec0;
      vint64 _vec1;
    };
    struct {
      vbool64 _vecb0;
      vbool64 _vecb1;
    };

  } __attribute__((__may_alias__));

 public:
  using value_type = int64_t;
  using vec_internal_type = vint64;
  using vec_internal_mask_type = vbool64;
  using size_type = int;
  using ElementType = signed long long;
  static constexpr size_type size() {
    return 4;
  }
  Vectorized() {}
  C10_ALWAYS_INLINE Vectorized(vint64 v) : _vec0{v}, _vec1{v} {}
  C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
  C10_ALWAYS_INLINE Vectorized(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {}
  C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
  C10_ALWAYS_INLINE Vectorized(int64_t scalar)
      : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
  C10_ALWAYS_INLINE Vectorized(
      int64_t scalar1,
      int64_t scalar2,
      int64_t scalar3,
      int64_t scalar4)
      : _vec0{vint64{scalar1, scalar2}}, _vec1{vint64{scalar3, scalar4}} {}

  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
    return _vec0;
  }
  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
    return _vec1;
  }

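  // blend<mask>: bit i of the compile-time mask selects lane i from b;
  // a cleared bit keeps the lane from a. The specializations below peel off
  // the trivial and single-half cases so vec_sel is emitted only where it
  // is actually needed.
  // Example: blend<0b0101>(a, b) == {b[0], a[1], b[2], a[3]}.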
  template <uint64_t mask>
  static std::enable_if_t<mask == 0, Vectorized<int64_t>> C10_ALWAYS_INLINE
  blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
    return a;
  }

  template <uint64_t mask>
  static std::enable_if_t<mask == 3, Vectorized<int64_t>> C10_ALWAYS_INLINE
  blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
    return {b._vec0, a._vec1};
  }

  template <uint64_t mask>
  static std::enable_if_t<(mask & 15) == 15, Vectorized<int64_t>> C10_ALWAYS_INLINE
  blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
    return b;
  }

  template <uint64_t mask>
  static std::enable_if_t<(mask > 0 && mask < 3), Vectorized<int64_t>> C10_ALWAYS_INLINE
  blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
    constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
    constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
    const vbool64 mask_1st = (vbool64){g0, g1};
    return {(vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st), a._vec1};
  }

  template <uint64_t mask>
  static std::enable_if_t<(mask > 3) && (mask & 3) == 0, Vectorized<int64_t>>
      C10_ALWAYS_INLINE blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
    constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff;
    constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff;

    const vbool64 mask_2nd = (vbool64){g0_2, g1_2};
    return {a._vec0, (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)};
  }

  template <uint64_t mask>
  static std::enable_if_t<
      (mask > 3) && (mask & 3) != 0 && (mask & 15) != 15,
      Vectorized<int64_t>>
      C10_ALWAYS_INLINE blend(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
    constexpr uint64_t g0 = (mask & 1) * 0xffffffffffffffff;
    constexpr uint64_t g1 = ((mask & 2) >> 1) * 0xffffffffffffffff;
    constexpr uint64_t g0_2 = ((mask & 4) >> 2) * 0xffffffffffffffff;
    constexpr uint64_t g1_2 = ((mask & 8) >> 3) * 0xffffffffffffffff;

    const vbool64 mask_1st = (vbool64){g0, g1};
    const vbool64 mask_2nd = (vbool64){g0_2, g1_2};
    return {
        (vint64)vec_sel(a._vec0, b._vec0, (vbool64)mask_1st),
        (vint64)vec_sel(a._vec1, b._vec1, (vbool64)mask_2nd)};
  }

  static Vectorized<int64_t> C10_ALWAYS_INLINE blendv(
      const Vectorized<int64_t>& a,
      const Vectorized<int64_t>& b,
      const Vectorized<int64_t>& mask) {
    // The mask is expected to be the all-ones/all-zeros lane mask produced
    // by a Vectorized<int64_t> comparison (e.g. operator<).
    return {
        vec_sel(a._vec0, b._vec0, mask._vecb0),
        vec_sel(a._vec1, b._vec1, mask._vecb1)};
  }
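  // arange: {base, base + step, base + 2 * step, base + 3 * step};
  // e.g. arange(0, 2) == {0, 2, 4, 6}.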
  template <typename step_t>
  static Vectorized<int64_t> arange(int64_t base = 0, step_t step = static_cast<step_t>(1)) {
    return Vectorized<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
  }

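  // set: the first `count` lanes come from b, the remaining lanes from a;
  // any count >= size() returns b unchanged.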
  static Vectorized<int64_t> C10_ALWAYS_INLINE
  set(const Vectorized<int64_t>& a,
      const Vectorized<int64_t>& b,
      size_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }

    return b;
  }
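  // loadu: unaligned load of `count` int64_t values. A full load treats the
  // memory as two double-precision vectors (only the bit pattern matters);
  // a partial load stages through a zeroed, aligned temporary so no
  // out-of-bounds bytes are read.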
  static Vectorized<value_type> C10_ALWAYS_INLINE
  loadu(const void* ptr, int count = size()) {
    if (count == size()) {
      static_assert(sizeof(double) == sizeof(value_type));
      const double* dptr = reinterpret_cast<const double*>(ptr);
      return {// treat it as double load
              (vint64)vec_vsx_ld(offset0, dptr),
              (vint64)vec_vsx_ld(offset16, dptr)};
    }

    __at_align__ double tmp_values[size()] = {};
    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));

    return {
        (vint64)vec_vsx_ld(offset0, tmp_values),
        (vint64)vec_vsx_ld(offset16, tmp_values)};
  }
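  // store: mirror of loadu; a partial store is staged through an aligned
  // temporary and copied out with memcpy so only `count` elements are written.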
  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
    if (count == size()) {
      double* dptr = reinterpret_cast<double*>(ptr);
      vec_vsx_st((vfloat64)_vec0, offset0, dptr);
      vec_vsx_st((vfloat64)_vec1, offset16, dptr);
    } else if (count > 0) {
      __at_align__ double tmp_values[size()];
      vec_vsx_st((vfloat64)_vec0, offset0, tmp_values);
      vec_vsx_st((vfloat64)_vec1, offset16, tmp_values);
      std::memcpy(
          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
    }
  }
  const int64_t& operator[](int idx) const = delete;
  int64_t& operator[](int idx) = delete;

  Vectorized<int64_t> angle() const {
    return blendv(
      Vectorized<int64_t>(0), Vectorized<int64_t>(c10::pi<int64_t>), *this < Vectorized<int64_t>(0));
  }
  Vectorized<int64_t> real() const {
    return *this;
  }
  Vectorized<int64_t> imag() const {
    return Vectorized<int64_t>{0};
  }
  Vectorized<int64_t> conj() const {
    return *this;
  }

  Vectorized<int64_t> C10_ALWAYS_INLINE abs() const {
    return {vec_abs(_vec0), vec_abs(_vec1)};
  }

  Vectorized<int64_t> C10_ALWAYS_INLINE neg() const {
    return {vec_neg(_vec0), vec_neg(_vec1)};
  }

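  // Lane-wise comparison, arithmetic and bitwise ops are stamped out by the
  // helper macros from vsx_helpers.h. The *_AND_ONE comparison variants
  // normalize the all-ones lane mask down to 0/1 per lane.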
  DEFINE_MEMBER_UNARY_OP(operator~, int64_t, vec_not)
  DEFINE_MEMBER_OP(operator==, int64_t, vec_cmpeq)
  DEFINE_MEMBER_OP(operator!=, int64_t, vec_cmpne)
  DEFINE_MEMBER_OP(operator<, int64_t, vec_cmplt)
  DEFINE_MEMBER_OP(operator<=, int64_t, vec_cmple)
  DEFINE_MEMBER_OP(operator>, int64_t, vec_cmpgt)
  DEFINE_MEMBER_OP(operator>=, int64_t, vec_cmpge)
  DEFINE_MEMBER_OP_AND_ONE(eq, int64_t, vec_cmpeq)
  DEFINE_MEMBER_OP_AND_ONE(ne, int64_t, vec_cmpne)
  DEFINE_MEMBER_OP_AND_ONE(lt, int64_t, vec_cmplt)
  DEFINE_MEMBER_OP_AND_ONE(le, int64_t, vec_cmple)
  DEFINE_MEMBER_OP_AND_ONE(gt, int64_t, vec_cmpgt)
  DEFINE_MEMBER_OP_AND_ONE(ge, int64_t, vec_cmpge)
  DEFINE_MEMBER_OP(operator+, int64_t, vec_add)
  DEFINE_MEMBER_OP(operator-, int64_t, vec_sub)
  DEFINE_MEMBER_OP(operator*, int64_t, vec_mul)
  DEFINE_MEMBER_OP(operator/, int64_t, vec_div)
  DEFINE_MEMBER_OP(maximum, int64_t, vec_max)
  DEFINE_MEMBER_OP(minimum, int64_t, vec_min)
  DEFINE_MEMBER_OP(operator&, int64_t, vec_and)
  DEFINE_MEMBER_OP(operator|, int64_t, vec_or)
  DEFINE_MEMBER_OP(operator^, int64_t, vec_xor)
};

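// Per-lane shifts take their counts from b reinterpreted as unsigned. Per the
// PowerPC vec_sl/vec_sr semantics, each count is taken modulo the element
// width (64), so counts >= 64 wrap rather than yielding 0 as scalar C++ would.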
template <>
Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  vuint64 shift_vec0 = reinterpret_cast<vuint64>(b.vec0());
  vuint64 shift_vec1 = reinterpret_cast<vuint64>(b.vec1());
  return Vectorized<int64_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
}

template <>
Vectorized<int64_t> inline operator>>(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  vuint64 shift_vec0 = reinterpret_cast<vuint64>(b.vec0());
  vuint64 shift_vec1 = reinterpret_cast<vuint64>(b.vec1());
  return Vectorized<int64_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
}

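// maximum/minimum dispatch to the vec_max/vec_min member ops defined above.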
template <>
Vectorized<int64_t> inline maximum(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  return a.maximum(b);
}

template <>
Vectorized<int64_t> inline minimum(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b) {
  return a.minimum(b);
}

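// Explicit specializations of the free binary operator templates declared in
// vec_base.h, each applied independently to the two 128-bit halves.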
template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator+(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return Vectorized<int64_t>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator-(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return Vectorized<int64_t>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator*(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return Vectorized<int64_t>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator/(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return Vectorized<int64_t>{vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())};
}

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator&(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return Vectorized<int64_t>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator|(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return Vectorized<int64_t>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE operator^(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
  return Vectorized<int64_t>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}

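// A minimal usage sketch (illustrative only; assumes a VSX-enabled build and
// that this header is reached via ATen's vec dispatch):
//
//   int64_t src[4] = {1, -2, 3, -4};
//   int64_t dst[4];
//   auto v = Vectorized<int64_t>::loadu(src);
//   auto r = v.abs() + Vectorized<int64_t>(10); // {11, 12, 13, 14}
//   r.store(dst);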
} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at