1 #pragma once
2
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec_base.h>
5 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
6 #include <sleef.h>
7 namespace at {
8 namespace vec {
9 // See Note [CPU_CAPABILITY namespace]
10
11 inline namespace CPU_CAPABILITY {
12
13 template <>
14 class Vectorized<float> {
15 private:
16 union {
17 struct {
18 vfloat32 _vec0;
19 vfloat32 _vec1;
20 };
21 struct {
22 vbool32 _vecb0;
23 vbool32 _vecb1;
24 };
25
26 } __attribute__((__may_alias__));
27
28 public:
29 using value_type = float;
30 using vec_internal_type = vfloat32;
31 using vec_internal_mask_type = vbool32;
32 using size_type = int;
33
size()34 static constexpr size_type size() {
35 return 8;
36 }
Vectorized()37 Vectorized() {}
38
Vectorized(vfloat32 v)39 C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {}
Vectorized(vbool32 vmask)40 C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
Vectorized(vfloat32 v1,vfloat32 v2)41 C10_ALWAYS_INLINE Vectorized(vfloat32 v1, vfloat32 v2) : _vec0{v1}, _vec1{v2} {}
Vectorized(vbool32 v1,vbool32 v2)42 C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
Vectorized(float scalar)43 C10_ALWAYS_INLINE Vectorized(float scalar)
44 : _vec0{vec_splats(scalar)}, _vec1{vec_splats(scalar)} {}
Vectorized(float scalar1,float scalar2,float scalar3,float scalar4,float scalar5,float scalar6,float scalar7,float scalar8)45 C10_ALWAYS_INLINE Vectorized(
46 float scalar1,
47 float scalar2,
48 float scalar3,
49 float scalar4,
50 float scalar5,
51 float scalar6,
52 float scalar7,
53 float scalar8)
54 : _vec0{vfloat32{scalar1, scalar2, scalar3, scalar4}},
55 _vec1{vfloat32{scalar5, scalar6, scalar7, scalar8}} {}
vec0()56 C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
57 return _vec0;
58 }
vec1()59 C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
60 return _vec1;
61 }
62
63 template <int64_t mask>
64 static std::enable_if_t<blendChoice(mask) == 0, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)65 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
66 return a;
67 }
68
69 template <int64_t mask>
70 static std::enable_if_t<blendChoice(mask) == 1, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)71 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
72 return b;
73 }
74
75 template <int64_t mask>
76 static std::enable_if_t<blendChoice(mask) == 2, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)77 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
78 return {b._vec0, a._vec1};
79 }
80
81 template <int64_t mask>
82 static std::enable_if_t<blendChoice(mask) == 3, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)83 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
84 return {a._vec0, b._vec1};
85 }
86
87 template <int64_t mask>
88 static std::enable_if_t<blendChoice(mask) == 4, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)89 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
90 const vbool32 mask_1st = VsxMask1(mask);
91 return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), a._vec1};
92 }
93
94 template <int64_t mask>
95 static std::enable_if_t<blendChoice(mask) == 5, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)96 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
97 const vbool32 mask_1st = VsxMask1(mask);
98 return {(vfloat32)vec_sel(a._vec0, b._vec0, mask_1st), b._vec1};
99 }
100
101 template <int64_t mask>
102 static std::enable_if_t<blendChoice(mask) == 6, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)103 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
104 const vbool32 mask_2nd = VsxMask2(mask);
105 // generated masks
106 return {a._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
107 }
108
109 template <int64_t mask>
110 static std::enable_if_t<blendChoice(mask) == 7, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)111 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
112 const vbool32 mask_2nd = VsxMask2(mask);
113 // generated masks
114 return {b._vec0, (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
115 }
116
117 template <int64_t mask>
118 static std::enable_if_t<blendChoice(mask) == 8, Vectorized<float>> C10_ALWAYS_INLINE
blend(const Vectorized<float> & a,const Vectorized<float> & b)119 blend(const Vectorized<float>& a, const Vectorized<float>& b) {
120 const vbool32 mask_1st = VsxMask1(mask);
121 const vbool32 mask_2nd = VsxMask2(mask);
122 return {
123 (vfloat32)vec_sel(a._vec0, b._vec0, mask_1st),
124 (vfloat32)vec_sel(a._vec1, b._vec1, mask_2nd)};
125 }
126
blendv(const Vectorized<float> & a,const Vectorized<float> & b,const Vectorized<float> & mask)127 static Vectorized<float> C10_ALWAYS_INLINE blendv(
128 const Vectorized<float>& a,
129 const Vectorized<float>& b,
130 const Vectorized<float>& mask) {
131 // the mask used here returned by comparision of vec256
132 // assuming this we can use the same mask directly with vec_sel
133 return {
134 vec_sel(a._vec0, b._vec0, mask._vecb0),
135 vec_sel(a._vec1, b._vec1, mask._vecb1)};
136 }
137
138 template <typename step_t>
139 static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
140 return Vectorized<float>(
141 base,
142 base + step,
143 base + 2 * step,
144 base + 3 * step,
145 base + 4 * step,
146 base + 5 * step,
147 base + 6 * step,
148 base + 7 * step);
149 }
150 static Vectorized<float> set(
151 const Vectorized<float>& a,
152 const Vectorized<float>& b,
153 size_t count = size()) {
154 switch (count) {
155 case 0:
156 return a;
157 case 1:
158 return blend<1>(a, b);
159 case 2:
160 return blend<3>(a, b);
161 case 3:
162 return blend<7>(a, b);
163 case 4:
164 return blend<15>(a, b);
165 case 5:
166 return blend<31>(a, b);
167 case 6:
168 return blend<63>(a, b);
169 case 7:
170 return blend<127>(a, b);
171 }
172
173 return b;
174 }
175 static Vectorized<value_type> C10_ALWAYS_INLINE
176 loadu(const void* ptr, int count = size()) {
177 if (count == size()) {
178 return {
179 vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
180 vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
181 }
182
183 __at_align__ value_type tmp_values[size()] = {};
184 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
185
186 return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
187 }
188 void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
189 if (count == size()) {
190 vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
191 vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
192 } else if (count > 0) {
193 __at_align__ value_type tmp_values[size()];
194 vec_vsx_st(_vec0, offset0, tmp_values);
195 vec_vsx_st(_vec1, offset16, tmp_values);
196 std::memcpy(
197 ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
198 }
199 }
200
201 const float& operator[](int idx) const = delete;
202 float& operator[](int idx) = delete;
203
map(float (* const f)(float))204 Vectorized<float> map(float (*const f)(float)) const {
205 Vectorized<float> ret;
206 for (int i = 0; i < size() / 2; i++) {
207 ret._vec0[i] = f(_vec0[i]);
208 }
209 for (int i = 0; i < size() / 2; i++) {
210 ret._vec1[i] = f(_vec1[i]);
211 }
212 return ret;
213 }
214
mapbi(float (* const f)(float,float),const Vectorized<float> & other)215 Vectorized<float> mapbi(float (*const f)(float, float), const Vectorized<float>& other)
216 const {
217 Vectorized<float> ret;
218 for (int i = 0; i < size() / 2; i++) {
219 ret._vec0[i] = f(_vec0[i], other._vec0[i]);
220 }
221 for (int i = 0; i < size() / 2; i++) {
222 ret._vec1[i] = f(_vec1[i], other._vec1[i]);
223 }
224 return ret;
225 }
226
_nor()227 Vectorized<float> _nor() const {
228 return {vec_nor(_vec0, _vec0), vec_nor(_vec1, _vec1)};
229 }
230
isnan()231 Vectorized<float> isnan() const {
232 auto x = *this;
233 auto ret = (x == x);
234 return ret._nor();
235 }
236
has_inf_nan()237 bool has_inf_nan() const {
238 for (const auto i : c10::irange(size()/2)) {
239 if(_isnan(_vec0[i]) || _isinf(_vec0[i])) {
240 return true;
241 }
242 }
243 for (const auto i : c10::irange(size()/2)) {
244 if(_isnan(_vec1[i]) || _isinf(_vec1[i])) {
245 return true;
246 }
247 }
248 return false;
249 }
250
zero_mask()251 int zero_mask() const {
252 // returns an integer mask where all zero elements are translated to 1-bit
253 // and others are translated to 0-bit
254 //__m256 cmp = _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_EQ_OQ);
255 auto cmp = (*this == zero);
256 // return _mm256_movemask_ps(cmp);
257 // possible simulation //mask= lvsl ( 0 ) vbpermq( vec, mask <<5)
258 vuint64 result0 = vec_vbpermq((vuint8)cmp._vecb0, mask_zero_bits);
259 vuint64 result1 = vec_vbpermq((vuint8)cmp._vecb1, mask_zero_bits);
260 return (result0[1] >> 12 | (result1[1] >> 8));
261 }
262
abs()263 Vectorized<float> C10_ALWAYS_INLINE abs() const {
264 return {vec_abs(_vec0), vec_abs(_vec1)};
265 }
266
acos()267 Vectorized<float> C10_ALWAYS_INLINE acos() const {
268 return {Sleef_acosf4_u10(_vec0), Sleef_acosf4_u10(_vec1)};
269 }
acosh()270 Vectorized<float> C10_ALWAYS_INLINE acosh() const {
271 return {Sleef_acoshf4_u10(_vec0), Sleef_acoshf4_u10(_vec1)};
272 }
asin()273 Vectorized<float> C10_ALWAYS_INLINE asin() const {
274 return {Sleef_asinf4_u10(_vec0), Sleef_asinf4_u10(_vec1)};
275 }
atan()276 Vectorized<float> atan() const {
277 return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)};
278 }
atanh()279 Vectorized<float> atanh() const {
280 return {Sleef_atanhf4_u10(_vec0), Sleef_atanhf4_u10(_vec1)};
281 }
atan2(const Vectorized<float> & b)282 Vectorized<float> atan2(const Vectorized<float>& b) const {
283 return {Sleef_atan2f4_u10(_vec0, b._vec0), Sleef_atan2f4_u10(_vec1, b._vec1)};
284 }
copysign(const Vectorized<float> & sign)285 Vectorized<float> copysign(const Vectorized<float> &sign) const {
286 return {Sleef_copysignf4(_vec0, sign._vec0), Sleef_copysignf4(_vec1, sign._vec1)};
287 }
lgamma()288 Vectorized<float> lgamma() const {
289 return {Sleef_lgammaf4_u10(_vec0), Sleef_lgammaf4_u10(_vec1)};
290 }
erf()291 Vectorized<float> erf() const {
292 return {Sleef_erff4_u10(_vec0), Sleef_erff4_u10(_vec1)};
293 }
294
erfc()295 Vectorized<float> erfc() const {
296 return {Sleef_erfcf4_u15(_vec0), Sleef_erfcf4_u15(_vec1)};
297 }
298
erfinv()299 Vectorized<float> erfinv() const {
300 return map(calc_erfinv);
301 }
302
angle()303 Vectorized<float> angle() const {
304 auto tmp = blendv(
305 Vectorized<float>(0), Vectorized<float>(c10::pi<float>), *this < Vectorized<float>(0));
306 return blendv(tmp, *this, isnan());
307 }
real()308 Vectorized<float> real() const {
309 return *this;
310 }
imag()311 Vectorized<float> imag() const {
312 return Vectorized<float>{0};
313 }
conj()314 Vectorized<float> conj() const {
315 return *this;
316 }
317
exp()318 Vectorized<float> C10_ALWAYS_INLINE exp() const {
319 return {Sleef_expf4_u10(_vec0), Sleef_expf4_u10(_vec1)};
320 }
exp2()321 Vectorized<float> C10_ALWAYS_INLINE exp2() const {
322 return {Sleef_exp2f4_u10(_vec0), Sleef_exp2f4_u10(_vec1)};
323 }
expm1()324 Vectorized<float> expm1() const {
325 return {Sleef_expm1f4_u10(_vec0), Sleef_expm1f4_u10(_vec1)};
326 }
exp_u20()327 Vectorized<float> C10_ALWAYS_INLINE exp_u20() const {
328 return exp();
329 }
330
log()331 Vectorized<float> C10_ALWAYS_INLINE log() const {
332 return {Sleef_logf4_u10(_vec0), Sleef_logf4_u10(_vec1)};
333 }
log10()334 Vectorized<float> C10_ALWAYS_INLINE log10() const {
335 return {Sleef_log10f4_u10(_vec0), Sleef_log10f4_u10(_vec1)};
336 }
log1p()337 Vectorized<float> C10_ALWAYS_INLINE log1p() const {
338 return {Sleef_log1pf4_u10(_vec0), Sleef_log1pf4_u10(_vec1)};
339 }
log2()340 Vectorized<float> C10_ALWAYS_INLINE log2() const {
341 return {Sleef_log2f4_u10(_vec0), Sleef_log2f4_u10(_vec1)};
342 }
ceil()343 Vectorized<float> C10_ALWAYS_INLINE ceil() const {
344 return {vec_ceil(_vec0), vec_ceil(_vec1)};
345 }
cos()346 Vectorized<float> C10_ALWAYS_INLINE cos() const {
347 return {Sleef_cosf4_u10(_vec0), Sleef_cosf4_u10(_vec1)};
348 }
cosh()349 Vectorized<float> C10_ALWAYS_INLINE cosh() const {
350 return {Sleef_coshf4_u10(_vec0), Sleef_coshf4_u10(_vec1)};
351 }
floor()352 Vectorized<float> C10_ALWAYS_INLINE floor() const {
353 return {vec_floor(_vec0), vec_floor(_vec1)};
354 }
neg()355 Vectorized<float> C10_ALWAYS_INLINE neg() const {
356 return {vec_neg(_vec0), vec_neg(_vec1)};
357 }
358
round()359 Vectorized<float> C10_ALWAYS_INLINE round() const {
360 return {vec_round(_vec0), vec_round(_vec1)};
361 }
sin()362 Vectorized<float> C10_ALWAYS_INLINE sin() const {
363 return {Sleef_sinf4_u10(_vec0), Sleef_sinf4_u10(_vec1)};
364 }
sinh()365 Vectorized<float> C10_ALWAYS_INLINE sinh() const {
366 return {Sleef_sinhf4_u10(_vec0), Sleef_sinhf4_u10(_vec1)};
367 }
tan()368 Vectorized<float> C10_ALWAYS_INLINE tan() const {
369 return {Sleef_tanf4_u10(_vec0), Sleef_tanf4_u10(_vec1)};
370 }
tanh()371 Vectorized<float> C10_ALWAYS_INLINE tanh() const {
372 return {Sleef_tanhf4_u10(_vec0), Sleef_tanhf4_u10(_vec1)};
373 }
trunc()374 Vectorized<float> C10_ALWAYS_INLINE trunc() const {
375 return {vec_trunc(_vec0), vec_trunc(_vec1)};
376 }
377
frac()378 Vectorized<float> C10_ALWAYS_INLINE frac() const {
379 return *this - trunc();
380 }
381
sqrt()382 Vectorized<float> C10_ALWAYS_INLINE sqrt() const {
383 return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
384 }
reciprocal()385 Vectorized<float> C10_ALWAYS_INLINE reciprocal() const {
386 return Vectorized<float>(one) / (*this);
387 }
rsqrt()388 Vectorized<float> C10_ALWAYS_INLINE rsqrt() const {
389 return sqrt().reciprocal();
390 }
391
pow(const Vectorized<float> & exp)392 Vectorized<float> C10_ALWAYS_INLINE pow(const Vectorized<float>& exp) const {
393 return {Sleef_powf4_u10(_vec0, exp._vec0), Sleef_powf4_u10(_vec1, exp._vec1)};
394 }
395
fmod(const Vectorized<float> & b)396 Vectorized<float> fmod(const Vectorized<float>& b) const {
397 return {Sleef_fmodf4(_vec0, b._vec0),Sleef_fmodf4(_vec1, b._vec1)};
398 }
399
hypot(const Vectorized<float> & b)400 Vectorized<float> hypot(const Vectorized<float>& b) const {
401 return {Sleef_hypotf4_u05(_vec0, b._vec0), Sleef_hypotf4_u05(_vec1, b._vec1)};
402 }
403
nextafter(const Vectorized<float> & b)404 Vectorized<float> nextafter(const Vectorized<float>& b) const {
405 return {Sleef_nextafterf4(_vec0, b._vec0), Sleef_nextafterf4(_vec1, b._vec1)};
406 }
407
igamma(const Vectorized<float> & x)408 Vectorized<float> igamma(const Vectorized<float>& x) const {
409 return mapbi(calc_igamma, x);
410 }
411
igammac(const Vectorized<float> & x)412 Vectorized<float> igammac(const Vectorized<float>& x) const {
413 return mapbi(calc_igammac, x);
414 }
415
i0()416 Vectorized<float> i0() const {
417 return map(calc_i0);
418 }
419
i0e()420 Vectorized<float> i0e() const {
421 return map(calc_i0e);
422 }
423
digamma()424 Vectorized<float> digamma() const {
425 return map(calc_digamma);
426 }
427
428 DEFINE_MEMBER_OP(operator==, float, vec_cmpeq)
429 DEFINE_MEMBER_OP(operator!=, float, vec_cmpne)
430 DEFINE_MEMBER_OP(operator<, float, vec_cmplt)
431 DEFINE_MEMBER_OP(operator<=, float, vec_cmple)
432 DEFINE_MEMBER_OP(operator>, float, vec_cmpgt)
433 DEFINE_MEMBER_OP(operator>=, float, vec_cmpge)
434 DEFINE_MEMBER_OP_AND_ONE(eq, float, vec_cmpeq)
435 DEFINE_MEMBER_OP_AND_ONE(ne, float, vec_cmpne)
436 DEFINE_MEMBER_OP_AND_ONE(lt, float, vec_cmplt)
437 DEFINE_MEMBER_OP_AND_ONE(le, float, vec_cmple)
438 DEFINE_MEMBER_OP_AND_ONE(gt, float, vec_cmpgt)
439 DEFINE_MEMBER_OP_AND_ONE(ge, float, vec_cmpge)
440 DEFINE_MEMBER_OP(operator+, float, vec_add)
441 DEFINE_MEMBER_OP(operator-, float, vec_sub)
442 DEFINE_MEMBER_OP(operator*, float, vec_mul)
443 DEFINE_MEMBER_OP(operator/, float, vec_div)
444 DEFINE_MEMBER_OP(maximum, float, vec_max_nan2)
445 DEFINE_MEMBER_OP(minimum, float, vec_min_nan2)
446 DEFINE_MEMBER_OP(operator&, float, vec_and)
447 DEFINE_MEMBER_OP(operator|, float, vec_or)
448 DEFINE_MEMBER_OP(operator^, float, vec_xor)
449 DEFINE_MEMBER_TERNARY_OP(madd, float, vec_madd)
450 };
451
452 template <>
maximum(const Vectorized<float> & a,const Vectorized<float> & b)453 Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
454 return a.maximum(b);
455 }
456
457 template <>
minimum(const Vectorized<float> & a,const Vectorized<float> & b)458 Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
459 return a.minimum(b);
460 }
461
462 template <>
463 Vectorized<float> C10_ALWAYS_INLINE operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
464 return Vectorized<float>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
465 }
466
467 template <>
468 Vectorized<float> C10_ALWAYS_INLINE operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
469 return Vectorized<float>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
470 }
471
472 template <>
473 Vectorized<float> C10_ALWAYS_INLINE operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
474 return Vectorized<float>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
475 }
476
477 template <>
478 Vectorized<float> C10_ALWAYS_INLINE operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
479 return Vectorized<float>{vec_div(a.vec0(), b.vec0()), vec_div(a.vec1(), b.vec1())};
480 }
481
482 template <>
483 Vectorized<float> C10_ALWAYS_INLINE operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
484 return Vectorized<float>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
485 }
486
487 template <>
488 Vectorized<float> C10_ALWAYS_INLINE operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
489 return Vectorized<float>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
490 }
491
492 template <>
493 Vectorized<float> C10_ALWAYS_INLINE operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
494 return Vectorized<float>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
495 }
496
497 } // namespace
498 } // namespace vec
499 } // namespace at
500