xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/cpu/vec/intrinsics.h>
3 #include <ATen/cpu/vec/vec_base.h>
4 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
5 #include <c10/util/complex.h>
6 #include <c10/util/irange.h>
7 
8 namespace at {
9 namespace vec {
10 // See Note [CPU_CAPABILITY namespace]
11 inline namespace CPU_CAPABILITY {
// Element type handled by the Vectorized specialization below: a complex
// number with double-precision real and imaginary parts.
using ComplexDbl = c10::complex<double>;
13 
14 template <>
15 class Vectorized<ComplexDbl> {
16   union {
17     struct {
18       vfloat64 _vec0;
19       vfloat64 _vec1;
20     };
21     struct {
22       vbool64 _vecb0;
23       vbool64 _vecb1;
24     };
25 
26   } __attribute__((__may_alias__));
27 
28  public:
29   using value_type = ComplexDbl;
30   using vec_internal_type = vfloat64;
31   using vec_internal_mask_type = vbool64;
32   using size_type = int;
size()33   static constexpr size_type size() {
34     return 2;
35   }
Vectorized()36   Vectorized() {}
Vectorized(vfloat64 v)37   C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {}
Vectorized(vbool64 vmask)38   C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
Vectorized(vfloat64 v1,vfloat64 v2)39   C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) : _vec0{v1}, _vec1{v2} {}
Vectorized(vbool64 v1,vbool64 v2)40   C10_ALWAYS_INLINE Vectorized(vbool64 v1, vbool64 v2) : _vecb0{v1}, _vecb1{v2} {}
41 
Vectorized(ComplexDbl val)42   Vectorized(ComplexDbl val) {
43     double real_value = val.real();
44     double imag_value = val.imag();
45     _vec0 = vfloat64{real_value, imag_value};
46     _vec1 = vfloat64{real_value, imag_value};
47   }
Vectorized(ComplexDbl val1,ComplexDbl val2)48   Vectorized(ComplexDbl val1, ComplexDbl val2) {
49     _vec0 = vfloat64{val1.real(), val1.imag()};
50     _vec1 = vfloat64{val2.real(), val2.imag()};
51   }
52 
vec0()53   C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
54     return _vec0;
55   }
vec1()56   C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
57     return _vec1;
58   }
59 
60   template <int64_t mask>
61   static std::enable_if_t<blendChoiceComplexDbl(mask) == 0, Vectorized<ComplexDbl>>
62       C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)63       blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
64     return a;
65   }
66 
67   template <int64_t mask>
68   static std::enable_if_t<blendChoiceComplexDbl(mask) == 1, Vectorized<ComplexDbl>>
69       C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)70       blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
71     return b;
72   }
73 
74   template <int64_t mask>
75   static std::enable_if_t<blendChoiceComplexDbl(mask) == 2, Vectorized<ComplexDbl>>
76       C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)77       blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
78     return {b._vec0, a._vec1};
79   }
80 
81   template <int64_t mask>
82   static std::enable_if_t<blendChoiceComplexDbl(mask) == 3, Vectorized<ComplexDbl>>
83       C10_ALWAYS_INLINE
blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)84       blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
85     return {a._vec0, b._vec1};
86   }
87 
88   template <int64_t mask>
89   static Vectorized<ComplexDbl> C10_ALWAYS_INLINE
el_blend(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)90   el_blend(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
91     const vbool64 mask_1st = VsxDblMask1(mask);
92     const vbool64 mask_2nd = VsxDblMask2(mask);
93     return {
94         (vfloat64)vec_sel(a._vec0, b._vec0, mask_1st),
95         (vfloat64)vec_sel(a._vec1, b._vec1, mask_2nd)};
96   }
97 
blendv(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b,const Vectorized<ComplexDbl> & mask)98   static Vectorized<ComplexDbl> blendv(
99       const Vectorized<ComplexDbl>& a,
100       const Vectorized<ComplexDbl>& b,
101       const Vectorized<ComplexDbl>& mask) {
102     // convert std::complex<V> index mask to V index mask: xy -> xxyy
103     auto mask_complex =
104         Vectorized<ComplexDbl>(vec_splat(mask._vec0, 0), vec_splat(mask._vec1, 0));
105     return {
106         vec_sel(a._vec0, b._vec0, mask_complex._vecb0),
107         vec_sel(a._vec1, b._vec1, mask_complex._vecb1)};
108   }
109 
elwise_blendv(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b,const Vectorized<ComplexDbl> & mask)110   static Vectorized<ComplexDbl> C10_ALWAYS_INLINE elwise_blendv(
111       const Vectorized<ComplexDbl>& a,
112       const Vectorized<ComplexDbl>& b,
113       const Vectorized<ComplexDbl>& mask) {
114     return {
115         vec_sel(a._vec0, b._vec0, mask._vecb0),
116         vec_sel(a._vec1, b._vec1, mask._vecb1)};
117   }
118   template <typename step_t>
119   static Vectorized<ComplexDbl> arange(
120       ComplexDbl base = 0.,
121       step_t step = static_cast<step_t>(1)) {
122     return Vectorized<ComplexDbl>(base, base + step);
123   }
124   static Vectorized<ComplexDbl> set(
125       const Vectorized<ComplexDbl>& a,
126       const Vectorized<ComplexDbl>& b,
127       int64_t count = size()) {
128     switch (count) {
129       case 0:
130         return a;
131       case 1:
132         return blend<1>(a, b);
133     }
134     return b;
135   }
136 
137   static Vectorized<value_type> C10_ALWAYS_INLINE
138   loadu(const void* ptr, int count = size()) {
139     if (count == size()) {
140       return {
141           vec_vsx_ld(offset0, reinterpret_cast<const double*>(ptr)),
142           vec_vsx_ld(offset16, reinterpret_cast<const double*>(ptr))};
143     }
144 
145     __at_align__ value_type tmp_values[size()] = {};
146     std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
147 
148     return {
149         vec_vsx_ld(offset0, reinterpret_cast<const double*>(tmp_values)),
150         vec_vsx_ld(offset16, reinterpret_cast<const double*>(tmp_values))};
151   }
152   void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
153     if (count == size()) {
154       vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(ptr));
155       vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(ptr));
156     } else if (count > 0) {
157       __at_align__ value_type tmp_values[size()];
158       vec_vsx_st(_vec0, offset0, reinterpret_cast<double*>(tmp_values));
159       vec_vsx_st(_vec1, offset16, reinterpret_cast<double*>(tmp_values));
160       std::memcpy(
161           ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
162     }
163   }
164 
165   const ComplexDbl& operator[](int idx) const = delete;
166   ComplexDbl& operator[](int idx) = delete;
167 
map(ComplexDbl (* const f)(ComplexDbl))168   Vectorized<ComplexDbl> map(ComplexDbl (*const f)(ComplexDbl)) const {
169     __at_align__ ComplexDbl tmp[size()];
170     store(tmp);
171     for (const auto i : c10::irange(size())) {
172       tmp[i] = f(tmp[i]);
173     }
174     return loadu(tmp);
175   }
176 
map(ComplexDbl (* const f)(const ComplexDbl &))177   Vectorized<ComplexDbl> map(ComplexDbl (*const f)(const ComplexDbl&)) const {
178     __at_align__ ComplexDbl tmp[size()];
179     store(tmp);
180     for (const auto i : c10::irange(size())) {
181       tmp[i] = f(tmp[i]);
182     }
183     return loadu(tmp);
184   }
185 
el_swapped()186   Vectorized<ComplexDbl> el_swapped() const {
187     vfloat64 v0 = vec_xxpermdi(_vec0, _vec0, 2);
188     vfloat64 v1 = vec_xxpermdi(_vec1, _vec1, 2);
189     return {v0, v1};
190   }
191 
el_madd(const Vectorized<ComplexDbl> & multiplier,const Vectorized<ComplexDbl> & val)192   Vectorized<ComplexDbl> el_madd(
193       const Vectorized<ComplexDbl>& multiplier,
194       const Vectorized<ComplexDbl>& val) const {
195     return {
196         vec_madd(_vec0, multiplier._vec0, val._vec0),
197         vec_madd(_vec1, multiplier._vec1, val._vec1)};
198   }
199 
el_mergeo()200   Vectorized<ComplexDbl> el_mergeo() const {
201     vfloat64 v0 = vec_splat(_vec0, 1);
202     vfloat64 v1 = vec_splat(_vec1, 1);
203     return {v0, v1};
204   }
205 
el_mergee()206   Vectorized<ComplexDbl> el_mergee() const {
207     vfloat64 v0 = vec_splat(_vec0, 0);
208     vfloat64 v1 = vec_splat(_vec1, 0);
209     return {v0, v1};
210   }
211 
el_mergee(Vectorized<ComplexDbl> & first,Vectorized<ComplexDbl> & second)212   static Vectorized<ComplexDbl> el_mergee(
213       Vectorized<ComplexDbl>& first,
214       Vectorized<ComplexDbl>& second) {
215     return {
216         vec_mergeh(first._vec0, second._vec0),
217         vec_mergeh(first._vec1, second._vec1)};
218   }
219 
el_mergeo(Vectorized<ComplexDbl> & first,Vectorized<ComplexDbl> & second)220   static Vectorized<ComplexDbl> el_mergeo(
221       Vectorized<ComplexDbl>& first,
222       Vectorized<ComplexDbl>& second) {
223     return {
224         vec_mergel(first._vec0, second._vec0),
225         vec_mergel(first._vec1, second._vec1)};
226   }
227 
abs_2_()228   Vectorized<ComplexDbl> abs_2_() const {
229     auto a = (*this).elwise_mult(*this);
230     auto permuted = a.el_swapped();
231     a = a + permuted;
232     return a;
233   }
234 
abs_()235   Vectorized<ComplexDbl> abs_() const {
236     auto vi = el_mergeo();
237     auto vr = el_mergee();
238     return {Sleef_hypotd2_u05vsx(vr._vec0, vi._vec0), Sleef_hypotd2_u05vsx(vr._vec1, vi._vec1)};
239   }
240 
abs()241   Vectorized<ComplexDbl> abs() const {
242     return abs_() & vd_real_mask;
243   }
244 
angle_()245   Vectorized<ComplexDbl> angle_() const {
246     // angle = atan2(b/a)
247     // auto b_a = _mm256_permute_pd(values, 0x05);     // b        a
248     // return Sleef_atan2d4_u10(values, b_a);          // 90-angle angle
249     Vectorized<ComplexDbl> ret;
250     ret._vec0[0] = std::atan2(_vec0[1], _vec0[0]);
251     ret._vec1[0] = std::atan2(_vec1[1], _vec1[0]);
252     return ret;
253   }
254 
angle()255   Vectorized<ComplexDbl> angle() const {
256     return angle_() & vd_real_mask;
257   }
258 
real_()259   Vectorized<ComplexDbl> real_() const {
260     return *this & vd_real_mask;
261   }
real()262   Vectorized<ComplexDbl> real() const {
263     return *this & vd_real_mask;
264   }
imag_()265   Vectorized<ComplexDbl> imag_() const {
266     return *this & vd_imag_mask;
267   }
imag()268   Vectorized<ComplexDbl> imag() const {
269     return imag_().el_swapped();
270   }
271 
conj_()272   Vectorized<ComplexDbl> conj_() const {
273     return *this ^ vd_isign_mask;
274   }
conj()275   Vectorized<ComplexDbl> conj() const {
276     return *this ^ vd_isign_mask;
277   }
278 
log()279   Vectorized<ComplexDbl> log() const {
280     // Most trigonomic ops use the log() op to improve complex number
281     // performance.
282     return map(std::log);
283   }
284 
log2()285   Vectorized<ComplexDbl> log2() const {
286     // log2eB_inv
287     auto ret = log();
288     return ret.elwise_mult(vd_log2e_inv);
289   }
log10()290   Vectorized<ComplexDbl> log10() const {
291     auto ret = log();
292     return ret.elwise_mult(vd_log10e_inv);
293   }
294 
log1p()295   Vectorized<ComplexDbl> log1p() const {
296     return map(std::log1p);
297   }
298 
asin()299   Vectorized<ComplexDbl> asin() const {
300     // asin(x)
301     // = -i*ln(iz + sqrt(1 -z^2))
302     // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi)))
303     // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi))
304     auto conj = conj_();
305     auto b_a = conj.el_swapped();
306     auto ab = conj.elwise_mult(b_a);
307     auto im = ab + ab;
308     auto val_2 = (*this).elwise_mult(*this);
309     auto val_2_swapped = val_2.el_swapped();
310     auto re = horizontal_sub(val_2, val_2_swapped);
311     re = Vectorized<ComplexDbl>(vd_one) - re;
312     auto root = el_blend<0x0A>(re, im).sqrt();
313     auto ln = (b_a + root).log();
314     return ln.el_swapped().conj();
315   }
316 
acos()317   Vectorized<ComplexDbl> acos() const {
318     // acos(x) = pi/2 - asin(x)
319     return Vectorized(vd_pi_2) - asin();
320   }
321 
atan()322   Vectorized<ComplexDbl> atan() const {
323     // atan(x) = i/2 * ln((i + z)/(i - z))
324     auto ione = Vectorized(vd_imag_one);
325     auto sum = ione + *this;
326     auto sub = ione - *this;
327     auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
328     return ln * vd_imag_half; // i/2*ln()
329   }
atanh()330   Vectorized<ComplexDbl> atanh() const {
331     return map(std::atanh);
332   }
333 
sin()334   Vectorized<ComplexDbl> sin() const {
335     return map(std::sin);
336   }
sinh()337   Vectorized<ComplexDbl> sinh() const {
338     return map(std::sinh);
339   }
cos()340   Vectorized<ComplexDbl> cos() const {
341     return map(std::cos);
342   }
cosh()343   Vectorized<ComplexDbl> cosh() const {
344     return map(std::cosh);
345   }
346 
tan()347   Vectorized<ComplexDbl> tan() const {
348     return map(std::tan);
349   }
tanh()350   Vectorized<ComplexDbl> tanh() const {
351     return map(std::tanh);
352   }
ceil()353   Vectorized<ComplexDbl> ceil() const {
354     return {vec_ceil(_vec0), vec_ceil(_vec1)};
355   }
floor()356   Vectorized<ComplexDbl> floor() const {
357     return {vec_floor(_vec0), vec_floor(_vec1)};
358   }
neg()359   Vectorized<ComplexDbl> neg() const {
360     auto z = Vectorized<ComplexDbl>(vd_zero);
361     return z - *this;
362   }
round()363   Vectorized<ComplexDbl> round() const {
364     return {vec_rint(_vec0), vec_rint(_vec1)};
365   }
366 
trunc()367   Vectorized<ComplexDbl> trunc() const {
368     return {vec_trunc(_vec0), vec_trunc(_vec1)};
369   }
370 
elwise_sqrt()371   Vectorized<ComplexDbl> elwise_sqrt() const {
372     return {vec_sqrt(_vec0), vec_sqrt(_vec1)};
373   }
374 
sqrt()375   Vectorized<ComplexDbl> sqrt() const {
376     return map(std::sqrt);
377   }
378 
reciprocal()379   Vectorized<ComplexDbl> reciprocal() const {
380     // re + im*i = (a + bi)  / (c + di)
381     // re = (ac + bd)/abs_2() = c/abs_2()
382     // im = (bc - ad)/abs_2() = d/abs_2()
383     auto c_d = *this ^ vd_isign_mask; // c       -d
384     auto abs = abs_2_();
385     return c_d.elwise_div(abs);
386   }
387 
rsqrt()388   Vectorized<ComplexDbl> rsqrt() const {
389     return sqrt().reciprocal();
390   }
391 
horizontal_add(Vectorized<ComplexDbl> & first,Vectorized<ComplexDbl> & second)392   static Vectorized<ComplexDbl> horizontal_add(
393       Vectorized<ComplexDbl>& first,
394       Vectorized<ComplexDbl>& second) {
395     // Operates on individual floats, see _mm_hadd_ps
396     // {f0+f1, s0+s1, f2+f3, s2+s3, ...}
397     // i.e. it sums the re and im of each value and interleaves first and second:
398     // {f_re0 + f_im0, s_re0 + s_im0, f_re1 + f_im1, s_re1 + s_im1, ...}
399     return el_mergee(first, second) + el_mergeo(first, second);
400   }
401 
horizontal_sub(Vectorized<ComplexDbl> & first,Vectorized<ComplexDbl> & second)402   static Vectorized<ComplexDbl> horizontal_sub(
403       Vectorized<ComplexDbl>& first,
404       Vectorized<ComplexDbl>& second) {
405     // we will simulate it differently with 6 instructions total
406     // lets permute second so that we can add it getting horizontal sums
407     auto first_perm = first.el_swapped(); // 2perm
408     auto second_perm = second.el_swapped(); // 2perm
409     // summ
410     auto first_ret = first - first_perm; // 2sub
411     auto second_ret = second - second_perm; // 2 sub
412     // now lets choose evens
413     return el_mergee(first_ret, second_ret); // 2 mergee's
414   }
415 
416   Vectorized<ComplexDbl> inline operator*(const Vectorized<ComplexDbl>& b) const {
417     //(a + bi)  * (c + di) = (ac - bd) + (ad + bc)i
418 #if 1
419     // this is more vsx friendly than simulating horizontal from x86
420     auto vi = b.el_mergeo();
421     auto vr = b.el_mergee();
422     vi = vi ^ vd_rsign_mask;
423     auto ret = elwise_mult(vr);
424     auto vx_swapped = el_swapped();
425     ret = vx_swapped.el_madd(vi, ret);
426 #else
427     auto ac_bd = elwise_mult(b);
428     auto d_c = b.el_swapped();
429     d_c = d_c ^ vd_isign_mask;
430     auto ad_bc = elwise_mult(d_c);
431     auto ret = horizontal_sub(ac_bd, ad_bc);
432 #endif
433     return ret;
434   }
435 
436   Vectorized<ComplexDbl> inline operator/(const Vectorized<ComplexDbl>& b) const {
437     // re + im*i = (a + bi)  / (c + di)
438     // re = (ac + bd)/abs_2()
439     // im = (bc - ad)/abs_2()
440     auto fabs_cd =  Vectorized{
441       vec_andc(b._vec0, vd_sign_mask),
442       vec_andc(b._vec1, vd_sign_mask)};       // |c|            |d|
443     auto fabs_dc =  fabs_cd.el_swapped();     // |d|            |c|
444     auto scale = fabs_cd.elwise_max(fabs_dc); // sc = max(|c|, |d|)
445     auto a2 = elwise_div(scale);              // a/sc           b/sc
446     auto b2 = b.elwise_div(scale);            // c/sc           d/sc
447     auto acbd2 = a2.elwise_mult(b2);          // ac/sc^2        bd/sc^2
448     auto dc2 = b2.el_swapped();               // d/sc           c/sc
449     dc2 = dc2 ^ vd_rsign_mask;                // -d/sc          c/sc
450     auto adbc2 = a2.elwise_mult(dc2);         // -ad/sc^2       bc/sc^2
451     auto ret = horizontal_add(acbd2, adbc2);  // (ac+bd)/sc^2   (bc-ad)/sc^2
452     auto denom2 = b2.abs_2_();                // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
453     ret = ret.elwise_div(denom2);
454     return ret;
455   }
456 
exp()457   Vectorized<ComplexDbl> exp() const {
458     return map(std::exp);
459   }
exp2()460   Vectorized<ComplexDbl> exp2() const {
461     return map(exp2_impl);
462   }
expm1()463   Vectorized<ComplexDbl> expm1() const {
464     return map(std::expm1);
465   }
466 
pow(const Vectorized<ComplexDbl> & exp)467   Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const {
468     __at_align__ ComplexDbl x_tmp[size()];
469     __at_align__ ComplexDbl y_tmp[size()];
470     store(x_tmp);
471     exp.store(y_tmp);
472     for (const auto i : c10::irange(size())) {
473       x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]);
474     }
475     return loadu(x_tmp);
476   }
477 
sgn()478   Vectorized<ComplexDbl> sgn() const {
479     return map(at::native::sgn_impl);
480   }
481 
482   Vectorized<ComplexDbl> operator<(const Vectorized<ComplexDbl>& other) const {
483     TORCH_CHECK(false, "not supported for complex numbers");
484   }
485   Vectorized<ComplexDbl> operator<=(const Vectorized<ComplexDbl>& other) const {
486     TORCH_CHECK(false, "not supported for complex numbers");
487   }
488   Vectorized<ComplexDbl> operator>(const Vectorized<ComplexDbl>& other) const {
489     TORCH_CHECK(false, "not supported for complex numbers");
490   }
491   Vectorized<ComplexDbl> operator>=(const Vectorized<ComplexDbl>& other) const {
492     TORCH_CHECK(false, "not supported for complex numbers");
493   }
494 
eq(const Vectorized<ComplexDbl> & other)495   Vectorized<ComplexDbl> eq(const Vectorized<ComplexDbl>& other) const {
496     auto eq = (*this == other);  // compares real and imag individually
497     // If both real numbers and imag numbers are equal, then the complex numbers are equal
498     return (eq.real() & eq.imag()) & vd_one;
499   }
ne(const Vectorized<ComplexDbl> & other)500   Vectorized<ComplexDbl> ne(const Vectorized<ComplexDbl>& other) const {
501     auto ne = (*this != other);  // compares real and imag individually
502     // If either real numbers or imag numbers are not equal, then the complex numbers are not equal
503     return (ne.real() | ne.imag()) & vd_one;
504   }
505 
506   DEFINE_MEMBER_OP(operator==, ComplexDbl, vec_cmpeq)
507   DEFINE_MEMBER_OP(operator!=, ComplexDbl, vec_cmpne)
508 
509   DEFINE_MEMBER_OP(operator+, ComplexDbl, vec_add)
510   DEFINE_MEMBER_OP(operator-, ComplexDbl, vec_sub)
511   DEFINE_MEMBER_OP(operator&, ComplexDbl, vec_and)
512   DEFINE_MEMBER_OP(operator|, ComplexDbl, vec_or)
513   DEFINE_MEMBER_OP(operator^, ComplexDbl, vec_xor)
514   // elementwise helpers
515   DEFINE_MEMBER_OP(elwise_mult, ComplexDbl, vec_mul)
516   DEFINE_MEMBER_OP(elwise_div, ComplexDbl, vec_div)
517   DEFINE_MEMBER_OP(elwise_gt, ComplexDbl, vec_cmpgt)
518   DEFINE_MEMBER_OP(elwise_ge, ComplexDbl, vec_cmpge)
519   DEFINE_MEMBER_OP(elwise_lt, ComplexDbl, vec_cmplt)
520   DEFINE_MEMBER_OP(elwise_le, ComplexDbl, vec_cmple)
521   DEFINE_MEMBER_OP(elwise_max, ComplexDbl, vec_max)
522 };
523 
524 template <>
maximum(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)525 Vectorized<ComplexDbl> inline maximum(
526     const Vectorized<ComplexDbl>& a,
527     const Vectorized<ComplexDbl>& b) {
528   auto abs_a = a.abs_2_();
529   auto abs_b = b.abs_2_();
530   // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ);
531   // auto max = _mm256_blendv_ps(a, b, mask);
532   auto mask = abs_a.elwise_lt(abs_b);
533   auto max = Vectorized<ComplexDbl>::elwise_blendv(a, b, mask);
534 
535   return max;
536   // Exploit the fact that all-ones is a NaN.
537   // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
538   // return _mm256_or_ps(max, isnan);
539 }
540 
541 template <>
minimum(const Vectorized<ComplexDbl> & a,const Vectorized<ComplexDbl> & b)542 Vectorized<ComplexDbl> inline minimum(
543     const Vectorized<ComplexDbl>& a,
544     const Vectorized<ComplexDbl>& b) {
545   auto abs_a = a.abs_2_();
546   auto abs_b = b.abs_2_();
547   // auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ);
548   // auto min = _mm256_blendv_ps(a, b, mask);
549   auto mask = abs_a.elwise_gt(abs_b);
550   auto min = Vectorized<ComplexDbl>::elwise_blendv(a, b, mask);
551   return min;
552   // Exploit the fact that all-ones is a NaN.
553   // auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q);
554   // return _mm256_or_ps(min, isnan);
555 }
556 
557 template <>
558 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator+(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
559   return Vectorized<ComplexDbl>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
560 }
561 
562 template <>
563 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator-(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
564   return Vectorized<ComplexDbl>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
565 }
566 
567 template <>
568 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator&(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
569   return Vectorized<ComplexDbl>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
570 }
571 
572 template <>
573 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator|(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
574   return Vectorized<ComplexDbl>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
575 }
576 
577 template <>
578 Vectorized<ComplexDbl> C10_ALWAYS_INLINE operator^(const Vectorized<ComplexDbl>& a, const Vectorized<ComplexDbl>& b) {
579   return Vectorized<ComplexDbl>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
580 }
581 
582 } // namespace
583 } // namespace vec
584 } // namespace at
585