1 #pragma once
2
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec_base.h>
5 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
6 namespace at {
7 namespace vec {
8 // See Note [CPU_CAPABILITY namespace]
9 inline namespace CPU_CAPABILITY {
10
11 template <>
12 class Vectorized<int16_t> {
13 private:
14 union {
15 struct {
16 vint16 _vec0;
17 vint16 _vec1;
18 };
19 struct {
20 vbool16 _vecb0;
21 vbool16 _vecb1;
22 };
23
24 } __attribute__((__may_alias__));
25
26 public:
27 using value_type = int16_t;
28 using vec_internal_type = vint16;
29 using vec_internal_mask_type = vbool16;
30 using size_type = int;
size()31 static constexpr size_type size() {
32 return 16;
33 }
Vectorized()34 Vectorized() {}
Vectorized(vint16 v)35 C10_ALWAYS_INLINE Vectorized(vint16 v) : _vec0{v}, _vec1{v} {}
Vectorized(vbool16 vmask)36 C10_ALWAYS_INLINE Vectorized(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
Vectorized(vint16 v1,vint16 v2)37 C10_ALWAYS_INLINE Vectorized(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {}
Vectorized(vbool16 v1,vbool16 v2)38 C10_ALWAYS_INLINE Vectorized(vbool16 v1, vbool16 v2) : _vecb0{v1}, _vecb1{v2} {}
Vectorized(int16_t scalar)39 C10_ALWAYS_INLINE Vectorized(int16_t scalar)
40 : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
41
Vectorized(int16_t scalar1,int16_t scalar2,int16_t scalar3,int16_t scalar4,int16_t scalar5,int16_t scalar6,int16_t scalar7,int16_t scalar8,int16_t scalar9,int16_t scalar10,int16_t scalar11,int16_t scalar12,int16_t scalar13,int16_t scalar14,int16_t scalar15,int16_t scalar16)42 C10_ALWAYS_INLINE Vectorized(
43 int16_t scalar1,
44 int16_t scalar2,
45 int16_t scalar3,
46 int16_t scalar4,
47 int16_t scalar5,
48 int16_t scalar6,
49 int16_t scalar7,
50 int16_t scalar8,
51 int16_t scalar9,
52 int16_t scalar10,
53 int16_t scalar11,
54 int16_t scalar12,
55 int16_t scalar13,
56 int16_t scalar14,
57 int16_t scalar15,
58 int16_t scalar16)
59 : _vec0{vint16{
60 scalar1,
61 scalar2,
62 scalar3,
63 scalar4,
64 scalar5,
65 scalar6,
66 scalar7,
67 scalar8}},
68 _vec1{vint16{
69 scalar9,
70 scalar10,
71 scalar11,
72 scalar12,
73 scalar13,
74 scalar14,
75 scalar15,
76 scalar16}} {}
vec0()77 C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
78 return _vec0;
79 }
vec1()80 C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
81 return _vec1;
82 }
83
84 template <uint64_t mask>
85 static std::enable_if_t<mask == 0, Vectorized<int16_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)86 blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
87 return a;
88 }
89
90 template <uint64_t mask>
91 static std::enable_if_t<(mask & 65535) == 65535, Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)92 C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
93 return b;
94 }
95
96 template <uint64_t mask>
97 static std::enable_if_t<mask == 255, Vectorized<int16_t>> C10_ALWAYS_INLINE
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)98 blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
99 return {b._vec0, a._vec1};
100 }
101
102 template <uint64_t mask>
103 static std::enable_if_t<(mask > 0 && mask < 255), Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)104 C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
105 constexpr int16_t g0 = (mask & 1) * 0xffff;
106 constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff;
107 constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff;
108 constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff;
109 constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff;
110 constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff;
111 constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff;
112 constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff;
113 const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7};
114
115 return {(vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st), a._vec1};
116 }
117
118 template <uint64_t mask>
119 static std::enable_if_t<
120 (mask > 255 && (mask & 65535) != 65535 && ((mask & 255) == 255)),
121 Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)122 C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
123 constexpr int16_t g0_2 = (mask & 1) * 0xffff;
124 constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
125 constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
126 constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
127 constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
128 constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
129 constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
130 constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
131
132 const vint16 mask_2nd =
133 vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
134 // generated masks
135 return {b._vec0, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
136 }
137
138 template <uint64_t mask>
139 static std::enable_if_t<
140 (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) == 0)),
141 Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)142 C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
143 constexpr int16_t mask2 = (mask & 65535) >> 16;
144 constexpr int16_t g0_2 = (mask & 1) * 0xffff;
145 constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
146 constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
147 constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
148 constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
149 constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
150 constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
151 constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
152
153 const vint16 mask_2nd =
154 vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
155 // generated masks
156 return {a, (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
157 }
158
159 template <uint64_t mask>
160 static std::enable_if_t<
161 (mask > 255 && ((mask & 65535) != 65535) && ((mask & 255) != 0) &&
162 ((mask & 255) != 255)),
163 Vectorized<int16_t>>
blend(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)164 C10_ALWAYS_INLINE blend(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
165 constexpr int16_t g0 = (mask & 1) * 0xffff;
166 constexpr int16_t g1 = ((mask & 2) >> 1) * 0xffff;
167 constexpr int16_t g2 = ((mask & 4) >> 2) * 0xffff;
168 constexpr int16_t g3 = ((mask & 8) >> 3) * 0xffff;
169 constexpr int16_t g4 = ((mask & 16) >> 4) * 0xffff;
170 constexpr int16_t g5 = ((mask & 32) >> 5) * 0xffff;
171 constexpr int16_t g6 = ((mask & 64) >> 6) * 0xffff;
172 constexpr int16_t g7 = ((mask & 128) >> 7) * 0xffff;
173 constexpr int16_t mask2 = (mask & 65535) >> 16;
174 constexpr int16_t g0_2 = (mask & 1) * 0xffff;
175 constexpr int16_t g1_2 = ((mask & 2) >> 1) * 0xffff;
176 constexpr int16_t g2_2 = ((mask & 4) >> 2) * 0xffff;
177 constexpr int16_t g3_2 = ((mask & 8) >> 3) * 0xffff;
178 constexpr int16_t g4_2 = ((mask & 16) >> 4) * 0xffff;
179 constexpr int16_t g5_2 = ((mask & 32) >> 5) * 0xffff;
180 constexpr int16_t g6_2 = ((mask & 64) >> 6) * 0xffff;
181 constexpr int16_t g7_2 = ((mask & 128) >> 7) * 0xffff;
182
183 const vint16 mask_1st = vint16{g0, g1, g2, g3, g4, g5, g6, g7};
184 const vint16 mask_2nd =
185 vint16{g0_2, g1_2, g2_2, g3_2, g4_2, g5_2, g6_2, g7_2};
186 // generated masks
187 return {
188 (vint16)vec_sel(a._vec0, b._vec0, (vbool16)mask_1st),
189 (vint16)vec_sel(a._vec1, b._vec1, (vbool16)mask_2nd)};
190 }
191
blendv(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b,const Vectorized<int16_t> & mask)192 static Vectorized<int16_t> C10_ALWAYS_INLINE blendv(
193 const Vectorized<int16_t>& a,
194 const Vectorized<int16_t>& b,
195 const Vectorized<int16_t>& mask) {
196 // the mask used here returned by comparision of vec256
197 // assuming this we can use the same mask directly with vec_sel
198 // warning intel style mask will not work properly
199 return {
200 vec_sel(a._vec0, b._vec0, mask._vecb0),
201 vec_sel(a._vec1, b._vec1, mask._vecb1)};
202 }
203
204 template <typename step_t>
205 static Vectorized<int16_t> arange(int16_t base = 0, step_t step = static_cast<step_t>(1)) {
206 return Vectorized<int16_t>(
207 base,
208 base + step,
209 base + 2 * step,
210 base + 3 * step,
211 base + 4 * step,
212 base + 5 * step,
213 base + 6 * step,
214 base + 7 * step,
215 base + 8 * step,
216 base + 9 * step,
217 base + 10 * step,
218 base + 11 * step,
219 base + 12 * step,
220 base + 13 * step,
221 base + 14 * step,
222 base + 15 * step);
223 }
224 static Vectorized<int16_t> set(
225 const Vectorized<int16_t>& a,
226 const Vectorized<int16_t>& b,
227 size_t count = size()) {
228 switch (count) {
229 case 0:
230 return a;
231 case 1:
232 return blend<1>(a, b);
233 case 2:
234 return blend<3>(a, b);
235 case 3:
236 return blend<7>(a, b);
237 case 4:
238 return blend<15>(a, b);
239 case 5:
240 return blend<31>(a, b);
241 case 6:
242 return blend<63>(a, b);
243 case 7:
244 return blend<127>(a, b);
245 case 8:
246 return blend<255>(a, b);
247 case 9:
248 return blend<511>(a, b);
249 case 10:
250 return blend<1023>(a, b);
251 case 11:
252 return blend<2047>(a, b);
253 case 12:
254 return blend<4095>(a, b);
255 case 13:
256 return blend<8191>(a, b);
257 case 14:
258 return blend<16383>(a, b);
259 case 15:
260 return blend<32767>(a, b);
261 }
262 return b;
263 }
264 static Vectorized<value_type> C10_ALWAYS_INLINE
265 loadu(const void* ptr, int count = size()) {
266 if (count == size()) {
267 return {
268 vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
269 vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
270 }
271
272 __at_align__ value_type tmp_values[size()] = {};
273 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
274
275 return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
276 }
277 void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
278 if (count == size()) {
279 vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
280 vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
281 } else if (count > 0) {
282 __at_align__ value_type tmp_values[size()];
283 vec_vsx_st(_vec0, offset0, tmp_values);
284 vec_vsx_st(_vec1, offset16, tmp_values);
285 std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
286 }
287 }
288 const int16_t& operator[](int idx) const = delete;
289 int16_t& operator[](int idx) = delete;
290
angle()291 Vectorized<int16_t> angle() const {
292 return blendv(
293 Vectorized<int16_t>(0), Vectorized<int16_t>(c10::pi<int16_t>), *this < Vectorized<int16_t>(0));
294 }
real()295 Vectorized<int16_t> real() const {
296 return *this;
297 }
imag()298 Vectorized<int16_t> imag() const {
299 return Vectorized<int16_t>{0};
300 }
conj()301 Vectorized<int16_t> conj() const {
302 return *this;
303 }
304
abs()305 Vectorized<int16_t> C10_ALWAYS_INLINE abs() const {
306 return {vec_abs(_vec0), vec_abs(_vec1)};
307 }
308
neg()309 Vectorized<int16_t> C10_ALWAYS_INLINE neg() const {
310 return {vec_neg(_vec0), vec_neg(_vec1)};
311 }
312
313 DEFINE_MEMBER_UNARY_OP(operator~, int16_t, vec_not)
314 DEFINE_MEMBER_OP(operator==, int16_t, vec_cmpeq)
315 DEFINE_MEMBER_OP(operator!=, int16_t, vec_cmpne)
316 DEFINE_MEMBER_OP(operator<, int16_t, vec_cmplt)
317 DEFINE_MEMBER_OP(operator<=, int16_t, vec_cmple)
318 DEFINE_MEMBER_OP(operator>, int16_t, vec_cmpgt)
319 DEFINE_MEMBER_OP(operator>=, int16_t, vec_cmpge)
320 DEFINE_MEMBER_OP_AND_ONE(eq, int16_t, vec_cmpeq)
321 DEFINE_MEMBER_OP_AND_ONE(ne, int16_t, vec_cmpne)
322 DEFINE_MEMBER_OP_AND_ONE(lt, int16_t, vec_cmplt)
323 DEFINE_MEMBER_OP_AND_ONE(le, int16_t, vec_cmple)
324 DEFINE_MEMBER_OP_AND_ONE(gt, int16_t, vec_cmpgt)
325 DEFINE_MEMBER_OP_AND_ONE(ge, int16_t, vec_cmpge)
326 DEFINE_MEMBER_OP(operator+, int16_t, vec_add)
327 DEFINE_MEMBER_OP(operator-, int16_t, vec_sub)
328 DEFINE_MEMBER_OP(operator*, int16_t, vec_mul)
329 DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, int16_t, /)
330 DEFINE_MEMBER_OP(maximum, int16_t, vec_max)
331 DEFINE_MEMBER_OP(minimum, int16_t, vec_min)
332 DEFINE_MEMBER_OP(operator&, int16_t, vec_and)
333 DEFINE_MEMBER_OP(operator|, int16_t, vec_or)
334 DEFINE_MEMBER_OP(operator^, int16_t, vec_xor)
335 };
336
337 template <>
338 Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
339 vuint16 shift_vec0 = reinterpret_cast<vuint16>(b.vec0());
340 vuint16 shift_vec1 = reinterpret_cast<vuint16>(b.vec1());
341 return Vectorized<int16_t>{vec_sl(a.vec0(), shift_vec0), vec_sl(a.vec1(), shift_vec1)};
342 }
343
344 template <>
345 Vectorized<int16_t> inline operator>>(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
346 vuint16 shift_vec0 = reinterpret_cast<vuint16>(b.vec0());
347 vuint16 shift_vec1 = reinterpret_cast<vuint16>(b.vec1()) ;
348 return Vectorized<int16_t>{vec_sr(a.vec0(), shift_vec0), vec_sr(a.vec1(), shift_vec1)};
349 }
350
351 template <>
maximum(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)352 Vectorized<int16_t> inline maximum(
353 const Vectorized<int16_t>& a,
354 const Vectorized<int16_t>& b) {
355 return a.maximum(b);
356 }
357
358 template <>
minimum(const Vectorized<int16_t> & a,const Vectorized<int16_t> & b)359 Vectorized<int16_t> inline minimum(
360 const Vectorized<int16_t>& a,
361 const Vectorized<int16_t>& b) {
362 return a.minimum(b);
363 }
364
365 template <>
366 Vectorized<int16_t> C10_ALWAYS_INLINE operator+(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
367 return Vectorized<int16_t>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
368 }
369
370 template <>
371 Vectorized<int16_t> C10_ALWAYS_INLINE operator-(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
372 return Vectorized<int16_t>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
373 }
374
375 template <>
376 Vectorized<int16_t> C10_ALWAYS_INLINE operator*(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
377 return Vectorized<int16_t>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
378 }
379
380 template <>
381 Vectorized<int16_t> C10_ALWAYS_INLINE operator/(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
382 return Vectorized<int16_t>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
383 }
384
385 template <>
386 Vectorized<int16_t> C10_ALWAYS_INLINE operator&(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
387 return Vectorized<int16_t>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
388 }
389
390 template <>
391 Vectorized<int16_t> C10_ALWAYS_INLINE operator|(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
392 return Vectorized<int16_t>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
393 }
394
395 template <>
396 Vectorized<int16_t> C10_ALWAYS_INLINE operator^(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
397 return Vectorized<int16_t>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
398 }
399
400 } // namespace
401 } // namespace vec
402 } // namespace at
403