1 #pragma once
2 #include <ATen/cpu/vec/intrinsics.h>
3 #include <ATen/cpu/vec/vec_base.h>
4 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
5 #include <c10/util/complex.h>
6 #include <c10/util/irange.h>
7
8 namespace at {
9 namespace vec {
10 // See Note [CPU_CAPABILITY namespace]
11 inline namespace CPU_CAPABILITY {
// Element type handled by the Vectorized specialization below.
using ComplexDbl = c10::complex<double>;
13
14 template <>
15 class Vectorized<ComplexDbl> {
16 union {
17 struct {
18 vfloat64 _vec0;
19 vfloat64 _vec1;
20 };
21 struct {
22 vbool64 _vecb0;
23 vbool64 _vecb1;
24 };
25
26 } __attribute__((__may_alias__));
27
28 public:
29 using value_type = ComplexDbl;
30 using vec_internal_type = vfloat64;
31 using vec_internal_mask_type = vbool64;
32 using size_type = int;
size()33 static constexpr size_type size() {
34 return 2;
35 }
Vectorized()36 Vectorized() {}
Vectorized(vfloat64 v)37 C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {}
Vectorized(vbool64 vmask)38 C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
Vectorized(vfloat64 v1,vfloat64 v2)39 C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {}
Vectorized(vbool64 v1,vbool64 v2)40 C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
41
Vectorized(ComplexDbl val)42 Vectorized(ComplexDbl val) {
43 double real_value = val.real();
44 double imag_value = val.imag();
45 _vec0 = vfloat64{real_value, imag_value};
46 _vec1 = vfloat64{real_value, imag_value};
47 }
Vectorized(ComplexDbl val1,ComplexDbl val2)48 Vectorized(ComplexDbl val1, ComplexDbl val2) {
49 _vec0 = vfloat64{val1.real(), val1.imag()};
50 _vec1 = vfloat64{val2.real(), val2.imag()};
51 }
52
vec0()53 C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
54 return _vec0;
55 }
vec1()56 C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
57 return _vec1;
58 }
59
60 template <int64_t mask>
61 static std::enable_if_t<blendChoiceComplexDbl(mask) == 0, Vectorized<ComplexDbl>>
62 C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)63 blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
64 return a;
65 }
66
67 template <int64_t mask>
68 static std::enable_if_t<blendChoiceComplexDbl(mask) == 1, Vectorized<ComplexDbl>>
69 C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)70 blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
71 return b;
72 }
73
74 template <int64_t mask>
75 static std::enable_if_t<blendChoiceComplexDbl(mask) == 2, Vectorized<ComplexDbl>>
76 C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)77 blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
78 return {b._vec0, a._vec1};
79 }
80
81 template <int64_t mask>
82 static std::enable_if_t<blendChoiceComplexDbl(mask) == 3, Vectorized<ComplexDbl>>
83 C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)84 blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
85 return {a._vec0, b._vec1};
86 }
87
88 template <int64_t mask>
89 static Vectorized<ComplexDbl> C10_ALWAYS_INLINE
el_blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)90 el_blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
91 const vbool64 mask_1st = VsxDblMask1(mask);
92 const vbool64 mask_2nd = VsxDblMask2(mask);
93 return {
94 (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st),
95 (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)};
96 }
97
blendv(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b,const Vectorized<ComplexDbl> & mask)98 static Vectorized<ComplexDbl> blendv(
99 const Vectorized<ComplexDbl>& a,
100 const Vectorized<ComplexDbl>& b,
101 const Vectorized<ComplexDbl>& mask) {
102 // convert std::complex<V> index mask to V index mask: xy -> xxyy
103 auto mask_complex =
104 Vectorized<ComplexDbl>(vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0));
105 return {
106 vec_sel(a._vec0, b._vec0, mask_complex._vecb0),
107 vec_sel(a._vec1, b._vec1, mask_complex._vecb1)};
108 }
109
elwise_blendv(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b,const Vectorized<ComplexDbl> & mask)110 static Vectorized<ComplexDbl> C10_ALWAYS_INLINE elwise_blendv(
111 const Vectorized<ComplexDbl>& a,
112 const Vectorized<ComplexDbl>& b,
113 const Vectorized<ComplexDbl>& mask) {
114 return {
115 vec_sel(a._vec0, b._vec0, mask._vecb0),
116 vec_sel(a._vec1, b._vec1, mask._vecb1)};
117 }
118 template <typename step_t>
119 static Vectorized<ComplexDbl> arange(
120 ComplexDbl base = 0.,
121 step_t step = static_cast<step_t>(1)) {
122 return Vectorized<ComplexDbl>(base, base + step);
123 }
124 static Vectorized<ComplexDbl> set(
125 const Vectorized<ComplexDbl>& a,
126 const Vectorized<ComplexDbl>& b,
127 int64_t count = size()) {
128 switch (count) {
129 case 0:
130 return a;
131 case 1:
132 return blend<1>(a, b);
133 }
134 return b;
135 }
136
137 static Vectorized<value_type> C10_ALWAYS_INLINE
138 loadu(const void* ptr, int count = size()) {
139 if (count == size()) {
140 return {
141 vec_vsx_ld(offset0, reinterpret_cast<const double*>(ptr)),
142 vec_vsx_ld(offset16, reinterpret_cast<const double*>(ptr))};
143 }
144
145 __at_align__ value_type tmp_values[size()] = {};
146 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
147
148 return {
149 vec_vsx_ld(offset0, reinterpret_cast<const double*>(tmp_values)),
150 vec_vsx_ld(offset16, reinterpret_cast<const double*>(tmp_values))};
151 }
152 void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
153 if (count == size()) {
154 vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(ptr));
155 vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(ptr));
156 } else if (count > 0) {
157 __at_align__ value_type tmp_values[size()];
158 vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(tmp_values));
159 vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(tmp_values));
160 std::memcpy(
161 ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
162 }
163 }
164
165 const ComplexDbl& operator[](int idx) const = delete;
166 ComplexDbl& operator[](int idx) = delete;
167
map(ComplexDbl (* const f)(ComplexDbl))168 Vectorized<ComplexDbl> map(ComplexDbl (*const f)(ComplexDbl)) const {
169 __at_align__ ComplexDbl tmp[size()];
170 store(tmp);
171 for (const auto i : c10::irange(size())) {
172 tmp[i] = f(tmp[i]);
173 }
174 return loadu(tmp);
175 }
176
map(ComplexDbl (* const f)(const ComplexDbl &))177 Vectorized<ComplexDbl> map(ComplexDbl (*const f)(const ComplexDbl&)) const {
178 __at_align__ ComplexDbl tmp[size()];
179 store(tmp);
180 for (const auto i : c10::irange(size())) {
181 tmp[i] = f(tmp[i]);
182 }
183 return loadu(tmp);
184 }
185
el_swapped()186 Vectorized<ComplexDbl> el_swapped() const {
187 vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2);
188 vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2);
189 return {v0, v1};
190 }
191
el_madd(const Vectorized<ComplexDbl> & multiplier,const Vectorized<ComplexDbl> & val)192 Vectorized<ComplexDbl> el_madd(
193 const Vectorized<ComplexDbl>& multiplier,
194 const Vectorized<ComplexDbl>& val) const {
195 return {
196 vec_madd(_vec0, multiplier._vec0, val._vec0),
197 vec_madd(_vec1, multiplier._vec1, val._vec1)};
198 }
199
el_mergeo()200 Vectorized<ComplexDbl> el_mergeo() const {
201 vfloat64 v0 = vec_splat(_vec0, 1);
202 vfloat64 v1 = vec_splat(_vec1, 1);
203 return {v0, v1};
204 }
205
el_mergee()206 Vectorized<ComplexDbl> el_mergee() const {
207 vfloat64 v0 = vec_splat(_vec0, 0);
208 vfloat64 v1 = vec_splat(_vec1, 0);
209 return {v0, v1};
210 }
211
el_mergee(Vectorized<ComplexDbl> & first,Vectorized<ComplexDbl> & second)212 static Vectorized<ComplexDbl> el_mergee(
213 Vectorized<ComplexDbl>& first,
214 Vectorized<ComplexDbl>& second) {
215 return {
216 vec_mergeh(first._vec0, second._vec0),
217 vec_mergeh(first._vec1, second._vec1)};
218 }
219
el_mergeo(Vectorized<ComplexDbl> & first,Vectorized<ComplexDbl> & second)220 static Vectorized<ComplexDbl> el_mergeo(
221 Vectorized<ComplexDbl>& first,
222 Vectorized<ComplexDbl>& second) {
223 return {
224 vec_mergel(first._vec0, second._vec0),
225 vec_mergel(first._vec1, second._vec1)};
226 }
227
abs_2_()228 Vectorized<ComplexDbl> abs_2_() const {
229 auto a = (*this).elwise_mult(*this);
230 auto permuted = a.el_swapped();
231 a = a + permuted;
232 return a;
233 }
234
abs_()235 Vectorized<ComplexDbl> abs_() const {
236 auto vi = el_mergeo();
237 auto vr = el_mergee();
238 return {Sleef_hypotd2_u05vsx(vr._vec0, vi._vec0), Sleef_hypotd2_u05vsx(vr._vec1, vi._vec1)};
239 }
240
abs()241 Vectorized<ComplexDbl> abs() const {
242 return abs_() & vd_real_mask;
243 }
244
angle_()245 Vectorized<ComplexDbl> angle_() const {
246 // angle = atan2(b/a)
247 // auto b_a = _mm256_permute_pd(values, 0x05); // b a
248 // return Sleef_atan2d4_u10(values, b_a); // 90-angle angle
249 Vectorized<ComplexDbl> ret;
250 ret._vec0[0] = std::atan2(_vec0[1], _vec0[0]);
251 ret._vec1[0] = std::atan2(_vec1[1], _vec1[0]);
252 return ret;
253 }
254
angle()255 Vectorized<ComplexDbl> angle() const {
256 return angle_() & vd_real_mask;
257 }
258
real_()259 Vectorized<ComplexDbl> real_() const {
260 return *this & vd_real_mask;
261 }
real()262 Vectorized<ComplexDbl> real() const {
263 return *this & vd_real_mask;
264 }
imag_()265 Vectorized<ComplexDbl> imag_() const {
266 return *this & vd_imag_mask;
267 }
imag()268 Vectorized<ComplexDbl> imag() const {
269 return imag_().el_swapped();
270 }
271
conj_()272 Vectorized<ComplexDbl> conj_() const {
273 return *this ^ vd_isign_mask;
274 }
conj()275 Vectorized<ComplexDbl> conj() const {
276 return *this ^ vd_isign_mask;
277 }
278
log()279 Vectorized<ComplexDbl> log() const {
280 // Most trigonomic ops use the log() op to improve complex number
281 // performance.
282 return map(std::log);
283 }
284
log2()285 Vectorized<ComplexDbl> log2() const {
286 // log2eB_inv
287 auto ret = log();
288 return ret.elwise_mult(vd_log2e_inv);
289 }
log10()290 Vectorized<ComplexDbl> log10() const {
291 auto ret = log();
292 return ret.elwise_mult(vd_log10e_inv);
293 }
294
log1p()295 Vectorized<ComplexDbl> log1p() const {
296 return map(std::log1p);
297 }
298
asin()299 Vectorized<ComplexDbl> asin() const {
300 // asin(x)
301 // = -i*ln(iz + sqrt(1 -z^2))
302 // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
303 // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
304 auto conj = conj_();
305 auto b_a = conj.el_swapped();
306 auto ab = conj.elwise_mult(b_a);
307 auto im = ab + ab;
308 auto val_2 = (*this).elwise_mult(*this);
309 auto val_2_swapped = val_2.el_swapped();
310 auto re = horizontal_sub(val_2, val_2_swapped);
311 re = Vectorized<ComplexDbl>(vd_one) - re;
312 auto root = el_blend<0x0A>(re, im).sqrt();
313 auto ln = (b_a + root).log();
314 return ln.el_swapped().conj();
315 }
316
acos()317 Vectorized<ComplexDbl> acos() const {
318 // acos(x) = pi/2 - asin(x)
319 return Vectorized(vd_pi_2) - asin();
320 }
321
atan()322 Vectorized<ComplexDbl> atan() const {
323 // atan(x) = i/2 * ln((i + z)/(i - z))
324 auto ione = Vectorized(vd_imag_one);
325 auto sum = ione + *this;
326 auto sub = ione - *this;
327 auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
328 return ln * vd_imag_half; // i/2*ln()
329 }
atanh()330 Vectorized<ComplexDbl> atanh() const {
331 return map(std::atanh);
332 }
333
sin()334 Vectorized<ComplexDbl> sin() const {
335 return map(std::sin);
336 }
sinh()337 Vectorized<ComplexDbl> sinh() const {
338 return map(std::sinh);
339 }
cos()340 Vectorized<ComplexDbl> cos() const {
341 return map(std::cos);
342 }
cosh()343 Vectorized<ComplexDbl> cosh() const {
344 return map(std::cosh);
345 }
346
tan()347 Vectorized<ComplexDbl> tan() const {
348 return map(std::tan);
349 }
tanh()350 Vectorized<ComplexDbl> tanh() const {
351 return map(std::tanh);
352 }
ceil()353 Vectorized<ComplexDbl> ceil() const {
354 return {vec_ceil(_vec0), vec_ceil(_vec1)};
355 }
floor()356 Vectorized<ComplexDbl> floor() const {
357 return {vec_floor(_vec0), vec_floor(_vec1)};
358 }
neg()359 Vectorized<ComplexDbl> neg() const {
360 auto z = Vectorized<ComplexDbl>(vd_zero);
361 return z - *this;
362 }
round()363 Vectorized<ComplexDbl> round() const {
364 return {vec_rint(_vec0), vec_rint(_vec1)};
365 }
366
trunc()367 Vectorized<ComplexDbl> trunc() const {
368 return {vec_trunc(_vec0), vec_trunc(_vec1)};
369 }
370
elwise_sqrt()371 Vectorized<ComplexDbl> elwise_sqrt() const {
372 return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
373 }
374
sqrt()375 Vectorized<ComplexDbl> sqrt() const {
376 return map(std::sqrt);
377 }
378
reciprocal()379 Vectorized<ComplexDbl> reciprocal() const {
380 // re + im*i = (a + bi) / (c + di)
381 // re = (ac + bd)/abs_2() = c/abs_2()
382 // im = (bc - ad)/abs_2() = d/abs_2()
383 auto c_d = *this ^ vd_isign_mask; // c -d
384 auto abs = abs_2_();
385 return c_d.elwise_div(abs);
386 }
387
rsqrt()388 Vectorized<ComplexDbl> rsqrt() const {
389 return sqrt().reciprocal();
390 }
391
horizontal_add(Vectorized<ComplexDbl> & first,Vectorized<ComplexDbl> & second)392 static Vectorized<ComplexDbl> horizontal_add(
393 Vectorized<ComplexDbl>& first,
394 Vectorized<ComplexDbl>& second) {
395 // Operates on individual floats, see _mm_hadd_ps
396 // {f0+f1, s0+s1, f2+f3, s2+s3, ...}
397 // i.e. it sums the re and im of each value and interleaves first and second:
398 // {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
399 return el_mergee(first, second) + el_mergeo(first, second);
400 }
401
horizontal_sub(Vectorized<ComplexDbl> & first,Vectorized<ComplexDbl> & second)402 static Vectorized<ComplexDbl> horizontal_sub(
403 Vectorized<ComplexDbl>& first,
404 Vectorized<ComplexDbl>& second) {
405 // we will simulate it differently with 6 instructions total
406 // lets permute second so that we can add it getting horizontal sums
407 auto first_perm = first.el_swapped(); // 2perm
408 auto second_perm = second.el_swapped(); // 2perm
409 // summ
410 auto first_ret = first - first_perm; // 2sub
411 auto second_ret = second - second_perm; // 2 sub
412 // now lets choose evens
413 return el_mergee(first_ret, second_ret); // 2 mergee's
414 }
415
416 Vectorized<ComplexDbl> inline operator*(const Vectorized<ComplexDbl>& b) const {
417 //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
418 #if 1
419 // this is more vsx friendly than simulating horizontal from x86
420 auto vi = b.el_mergeo();
421 auto vr = b.el_mergee();
422 vi = vi ^ vd_rsign_mask;
423 auto ret = elwise_mult(vr);
424 auto vx_swapped = el_swapped();
425 ret = vx_swapped.el_madd(vi, ret);
426 #else
427 auto ac_bd = elwise_mult(b);
428 auto d_c = b.el_swapped();
429 d_c = d_c ^ vd_isign_mask;
430 auto ad_bc = elwise_mult(d_c);
431 auto ret = horizontal_sub(ac_bd, ad_bc);
432 #endif
433 return ret;
434 }
435
436 Vectorized<ComplexDbl> inline operator/(const Vectorized<ComplexDbl>& b) const {
437 // re + im*i = (a + bi) / (c + di)
438 // re = (ac + bd)/abs_2()
439 // im = (bc - ad)/abs_2()
440 auto fabs_cd = Vectorized{
441 vec_andc(b._vec0, vd_sign_mask),
442 vec_andc(b._vec1, vd_sign_mask)}; // |c| |d|
443 auto fabs_dc = fabs_cd.el_swapped(); // |d| |c|
444 auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
445 auto a2 = elwise_div(scale); // a/sc b/sc
446 auto b2 = b.elwise_div(scale); // c/sc d/sc
447 auto acbd2 = a2.elwise_mult(b2); // ac/sc^2 bd/sc^2
448 auto dc2 = b2.el_swapped(); // d/sc c/sc
449 dc2 = dc2 ^ vd_rsign_mask; // -d/sc c/sc
450 auto adbc2 = a2.elwise_mult(dc2); // -ad/sc^2 bc/sc^2
451 auto ret = horizontal_add(acbd2, adbc2); // (ac+bd)/sc^2 (bc-ad)/sc^2
452 auto denom2 = b2.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
453 ret = ret.elwise_div(denom2);
454 return ret;
455 }
456
exp()457 Vectorized<ComplexDbl> exp() const {
458 return map(std::exp);
459 }
exp2()460 Vectorized<ComplexDbl> exp2() const {
461 return map(exp2_impl);
462 }
expm1()463 Vectorized<ComplexDbl> expm1() const {
464 return map(std::expm1);
465 }
466
pow(const Vectorized<ComplexDbl> & exp)467 Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const {
468 __at_align__ ComplexDbl x_tmp[size()];
469 __at_align__ ComplexDbl y_tmp[size()];
470 store(x_tmp);
471 exp.store(y_tmp);
472 for (const auto i : c10::irange(size())) {
473 x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
474 }
475 return loadu(x_tmp);
476 }
477
sgn()478 Vectorized<ComplexDbl> sgn() const {
479 return map(at::native::sgn_impl);
480 }
481
482 Vectorized<ComplexDbl> operator<(const Vectorized<ComplexDbl>& other) const {
483 TORCH_CHECK(false, "not supported for complex numbers");
484 }
485 Vectorized<ComplexDbl> operator<=(const Vectorized<ComplexDbl>& other) const {
486 TORCH_CHECK(false, "not supported for complex numbers");
487 }
488 Vectorized<ComplexDbl> operator>(const Vectorized<ComplexDbl>& other) const {
489 TORCH_CHECK(false, "not supported for complex numbers");
490 }
491 Vectorized<ComplexDbl> operator>=(const Vectorized<ComplexDbl>& other) const {
492 TORCH_CHECK(false, "not supported for complex numbers");
493 }
494
eq(const Vectorized<ComplexDbl> & other)495 Vectorized<ComplexDbl> eq(const Vectorized<ComplexDbl>& other) const {
496 auto eq = (*this == other); // compares real and imag individually
497 // If both real numbers and imag numbers are equal, then the complex numbers are equal
498 return (eq.real() & eq.imag()) & vd_one;
499 }
ne(const Vectorized<ComplexDbl> & other)500 Vectorized<ComplexDbl> ne(const Vectorized<ComplexDbl>& other) const {
501 auto ne = (*this != other); // compares real and imag individually
502 // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
503 return (ne.real() | ne.imag()) & vd_one;
504 }
505
506 DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq)
507 DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne)
508
509 DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add)
510 DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub)
511 DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and)
512 DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or)
513 DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor)
514 // elementwise helpers
515 DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul)
516 DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div)
517 DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt)
518 DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge)
519 DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt)
520 DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple)
521 DEFINE_MEMBER_OP(elwise_max, ComplexDbl, vec_max)
522 };
523
524 template <>
maximum(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)525 Vectorized<ComplexDbl> inline maximum(
526 const Vectorized<ComplexDbl>& a,
527 const Vectorized<ComplexDbl>& b) {
528 auto abs_a = a.abs_2_();
529 auto abs_b = b.abs_2_();
530 // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
531 // auto max = _mm256_blendv_ps(a, b, mask);
532 auto mask = abs_a.elwise_lt(abs_b);
533 auto max = Vectorized<ComplexDbl>::elwise_blendv(a, b, mask);
534
535 return max;
536 // Exploit the fact that all-ones is a NaN.
537 // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
538 // return _mm256_or_ps(max, isnan);
539 }
540
541 template <>
minimum(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)542 Vectorized<ComplexDbl> inline minimum(
543 const Vectorized<ComplexDbl>& a,
544 const Vectorized<ComplexDbl>& b) {
545 auto abs_a = a.abs_2_();
546 auto abs_b = b.abs_2_();
547 // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
548 // auto min = _mm256_blendv_ps(a, b, mask);
549 auto mask = abs_a.elwise_gt(abs_b);
550 auto min = Vectorized<ComplexDbl>::elwise_blendv(a, b, mask);
551 return min;
552 // Exploit the fact that all-ones is a NaN.
553 // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
554 // return _mm256_or_ps(min, isnan);
555 }
556
557 template <>
558 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator+(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
559 return Vectorized<ComplexDbl>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
560 }
561
562 template <>
563 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator-(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
564 return Vectorized<ComplexDbl>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
565 }
566
567 template <>
568 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator&(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
569 return Vectorized<ComplexDbl>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
570 }
571
572 template <>
573 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator|(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
574 return Vectorized<ComplexDbl>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
575 }
576
577 template <>
578 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator^(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
579 return Vectorized<ComplexDbl>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
580 }
581
582 } // namespace
583 } // namespace vec
584 } // namespace at
585