/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <executorch/kernels/optimized/vec/intrinsics.h>
#include <executorch/kernels/optimized/vec/vec_base.h>


#if defined(__aarch64__) && defined(ET_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#endif

// Sleef offers vectorized versions of some transcendentals
// such as sin, cos, tan, etc.
// However, for now we opt for the STL, since we are not yet
// building with Sleef for mobile.

namespace executorch {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// Right now this contains only an aarch64 implementation.
// aarch32 is currently not supported, for the following two reasons:
// 1. Due to differences in the ISA between aarch32 and aarch64, intrinsics
//    that work for aarch64 don't work for aarch32.
// 2. Android NDK r21 has problems compiling aarch32.
//    Clang segfaults.
//    https://github.com/android/ndk/issues/1248
//    https://bugs.llvm.org/show_bug.cgi?id=45824
// Most likely we will add aarch32 support with inline asm.
#if defined(__aarch64__)

#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif

#if defined(ET_BUILD_ARM_VEC256_WITH_SLEEF)
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif

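// Illustrative sketch of how USE_SLEEF is used below: both the Sleef path and
// the scalar fallback appear in a single expression, e.g.
//   USE_SLEEF(
//       Vectorized<float>(Sleef_sinf4_u10(values.val[0]),
//                         Sleef_sinf4_u10(values.val[1])),
//       map(std::sin))
// expands to the Sleef code when ET_BUILD_ARM_VEC256_WITH_SLEEF is defined,
// and to the lane-by-lane std::sin fallback otherwise.
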
template<int index, bool mask_val>
struct BlendRegs {
  static float32x4_t impl(
    const float32x4_t& a, const float32x4_t& b, float32x4_t& res);
};

template<int index>
struct BlendRegs<index, true>{
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
    return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
  }
};

template<int index>
struct BlendRegs<index, false>{
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
    return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
  }
};

template <> class Vectorized<float> {
private:
  float32x4x2_t values;
public:
  using value_type = float;
  using size_type = int;
  static constexpr size_type size() {
    return 8;
  }
  Vectorized() {}
  Vectorized(float32x4x2_t v) : values(v) {}
  Vectorized(float val) : values{vdupq_n_f32(val), vdupq_n_f32(val)} {}
  Vectorized(float val0, float val1, float val2, float val3,
         float val4, float val5, float val6, float val7) :
         values{val0, val1, val2, val3, val4, val5, val6, val7} {}
  Vectorized(float32x4_t val0, float32x4_t val1) : values{val0, val1} {}
  operator float32x4x2_t() const {
    return values;
  }
  template <int64_t mask>
  static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
    Vectorized<float> vec;
    // Low half (lanes 0-3).
    vec.values.val[0] =
      BlendRegs<0, (mask & 0x01)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<1, (mask & 0x02)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<2, (mask & 0x04)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<3, (mask & 0x08)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    // High half (lanes 4-7).
    vec.values.val[1] =
      BlendRegs<0, (mask & 0x10)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<1, (mask & 0x20)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<2, (mask & 0x40)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<3, (mask & 0x80)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    return vec;
  }
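  // Example (illustrative): bit i of `mask` selects lane i from b when set,
  // and from a when clear, so
  //   Vectorized<float>::blend<0x0F>(a, b)
  // yields {b0, b1, b2, b3, a4, a5, a6, a7}.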
  static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
                              const Vectorized<float>& mask) {
    // TODO
    // NB: This requires that each 32-bit lane of the mask be either
    // all zeros or all ones.
    // We perhaps need some kind of an assert?
    // But that will affect performance.
    Vectorized<float> vec(mask.values);
    vec.values.val[0] = vbslq_f32(
        vreinterpretq_u32_f32(vec.values.val[0]),
        b.values.val[0],
        a.values.val[0]);
    vec.values.val[1] = vbslq_f32(
        vreinterpretq_u32_f32(vec.values.val[1]),
        b.values.val[1],
        a.values.val[1]);
    return vec;
  }
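  // Usage sketch (assumed caller-side pattern): the comparison operators below
  // produce exactly such all-ones/all-zeros lanes, so a per-lane select is
  //   Vectorized<float> m = a < b;
  //   Vectorized<float> r = Vectorized<float>::blendv(a, b, m);  // b where a < b, else a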
  template<typename step_t>
  static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
    const Vectorized<float> base_vec(base);
    const Vectorized<float> step_vec(step);
    const Vectorized<float> step_sizes(0, 1, 2, 3, 4, 5, 6, 7);
    return fmadd(step_sizes, step_vec, base_vec);
  }
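  // Example: arange(1.f, 2.f) yields {1, 3, 5, 7, 9, 11, 13, 15}.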
  static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
                           int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 2:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 3:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
          vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
          vec.values.val[1] = a.values.val[1];
          vec.values.val[0] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[0]),
              b.values.val[0],
              a.values.val[0]);
          return vec;
        }
      case 4:
        return Vectorized<float>(b.values.val[0], a.values.val[1]);
      case 5:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0x0, 0x0, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
      case 6:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
      case 7:
        {
          Vectorized<float> vec;
          static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
          vec.values.val[0] = b.values.val[0];
          vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
          vec.values.val[1] = vbslq_f32(
              vreinterpretq_u32_f32(vec.values.val[1]),
              b.values.val[1],
              a.values.val[1]);
          return vec;
        }
    }
    return b;
  }
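  // Example: set(a, b, 3) yields {b0, b1, b2, a3, a4, a5, a6, a7}; the first
  // `count` lanes come from b and the remaining lanes from a.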
  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
    if (count == size()) {
      return vld1q_f32_x2(reinterpret_cast<const float*>(ptr));
    }
    else if (count == (size() >> 1)) {
      Vectorized<float> res;
      res.values.val[0] = vld1q_f32(reinterpret_cast<const float*>(ptr));
      res.values.val[1] = vdupq_n_f32(0.f);
      return res;
    }
    else {
      __at_align__ float tmp_values[size()];
      for (size_t i = 0; i < size(); ++i) {
        tmp_values[i] = 0.0;
      }
      std::memcpy(
          tmp_values,
          reinterpret_cast<const float*>(ptr),
          count * sizeof(float));
      return vld1q_f32_x2(reinterpret_cast<const float*>(tmp_values));
    }
  }
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      vst1q_f32_x2(reinterpret_cast<float*>(ptr), values);
    }
    else if (count == (size() >> 1)) {
      vst1q_f32(reinterpret_cast<float*>(ptr), values.val[0]);
    }
    else {
      float tmp_values[size()];
      vst1q_f32_x2(reinterpret_cast<float*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(float));
    }
  }
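  // Usage sketch (illustrative buffer): partial counts let a buffer tail be
  // handled without reading or writing out of bounds:
  //   float buf[5] = {1, 2, 3, 4, 5};
  //   auto v = Vectorized<float>::loadu(buf, 5);  // lanes 5-7 are zeroed
  //   v.store(buf, 5);                            // writes back exactly 5 floats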
  inline const float32x4_t& get_low() const {
    return values.val[0];
  }
  inline float32x4_t& get_low() {
    return values.val[0];
  }
  inline const float32x4_t& get_high() const {
    return values.val[1];
  }
  inline float32x4_t& get_high() {
    return values.val[1];
  }
  // Very slow implementation of indexing.
  // Only required because vec256_qint refers to this.
  // Once we specialize that implementation for ARM
  // this should be removed. TODO (kimishpatel)
  float operator[](int idx) const {
    __at_align__ float tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  float operator[](int idx) {
    __at_align__ float tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  // For a boolean vector, checks such as "any lane set" or "all lanes zero"
  // could be done faster in a different way.
  int zero_mask() const {
    __at_align__ float tmp[size()];
    store(tmp);
    int mask = 0;
    for (size_t i = 0; i < size(); ++i) {
      if (tmp[i] == 0.f) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
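  // Example: Vectorized<float>(0, 1, 0, 2, 0, 3, 0, 4).zero_mask() is
  // 0b01010101: bit i is set iff lane i equals 0.f.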
  Vectorized<float> isnan() const {
    __at_align__ float tmp[size()];
    __at_align__ float res[size()];
    store(tmp);
    for (size_t i = 0; i < size(); ++i) {
      if (std::isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(float));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(float));
      }
    }
    return loadu(res);
  }
  Vectorized<float> map(float (*const f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (size_t i = 0; i < size(); ++i) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> abs() const {
    return Vectorized<float>(vabsq_f32(values.val[0]), vabsq_f32(values.val[1]));
  }
  Vectorized<float> acos() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_acosf4_u10(values.val[0]), Sleef_acosf4_u10(values.val[1])),
      map(std::acos)
    );
  }
  Vectorized<float> asin() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_asinf4_u10(values.val[0]), Sleef_asinf4_u10(values.val[1])),
      map(std::asin)
    );
  }
  Vectorized<float> atan() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_atanf4_u10(values.val[0]), Sleef_atanf4_u10(values.val[1])),
      map(std::atan)
    );
  }
  Vectorized<float> atan2(const Vectorized<float> &exp) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_atan2f4_u10(values.val[0], exp.values.val[0]),
                                 Sleef_atan2f4_u10(values.val[1], exp.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_exp[size()];
        store(tmp);
        exp.store(tmp_exp);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::atan2(tmp[i], tmp_exp[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> copysign(const Vectorized<float> &sign) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_copysignf4(values.val[0], sign.values.val[0]),
                                 Sleef_copysignf4(values.val[1], sign.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_sign[size()];
        store(tmp);
        sign.store(tmp_sign);
        for (size_t i = 0; i < size(); i++) {
          tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> erf() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_erff4_u10(values.val[0]), Sleef_erff4_u10(values.val[1])),
      map(std::erf)
    );
  }
  Vectorized<float> erfc() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_erfcf4_u15(values.val[0]), Sleef_erfcf4_u15(values.val[1])),
      map(std::erfc)
    );
  }
  Vectorized<float> exp() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_expf4_u10(values.val[0]), Sleef_expf4_u10(values.val[1])),
      map(std::exp)
    );
  }
  Vectorized<float> exp2() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_exp2f4_u10(values.val[0]), Sleef_exp2f4_u10(values.val[1])),
      map(std::exp2)
    );
  }
  Vectorized<float> expm1() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_expm1f4_u10(values.val[0]), Sleef_expm1f4_u10(values.val[1])),
      map(std::expm1)
    );
  }
  Vectorized<float> fmod(const Vectorized<float>& q) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_fmodf4(values.val[0], q.values.val[0]),
                                 Sleef_fmodf4(values.val[1], q.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_q[size()];
        store(tmp);
        q.store(tmp_q);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::fmod(tmp[i], tmp_q[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> hypot(const Vectorized<float> &b) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_hypotf4_u05(values.val[0], b.values.val[0]),
                                 Sleef_hypotf4_u05(values.val[1], b.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::hypot(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> log() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_logf4_u10(values.val[0]), Sleef_logf4_u10(values.val[1])),
      map(std::log)
    );
  }
  Vectorized<float> log10() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log10f4_u10(values.val[0]), Sleef_log10f4_u10(values.val[1])),
      map(std::log10)
    );
  }
  Vectorized<float> log1p() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log1pf4_u10(values.val[0]), Sleef_log1pf4_u10(values.val[1])),
      map(std::log1p)
    );
  }
  Vectorized<float> log2() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log2f4_u10(values.val[0]), Sleef_log2f4_u10(values.val[1])),
      map(std::log2)
    );
  }
  Vectorized<float> nextafter(const Vectorized<float> &b) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_nextafterf4(values.val[0], b.values.val[0]),
                                 Sleef_nextafterf4(values.val[1], b.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> frac() const;
  Vectorized<float> sin() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_sinf4_u10(values.val[0]), Sleef_sinf4_u10(values.val[1])),
      map(std::sin)
    );
  }
  Vectorized<float> sinh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_sinhf4_u10(values.val[0]), Sleef_sinhf4_u10(values.val[1])),
      map(std::sinh)
    );
  }
  Vectorized<float> cos() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_cosf4_u10(values.val[0]), Sleef_cosf4_u10(values.val[1])),
      map(std::cos)
    );
  }
  Vectorized<float> cosh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_coshf4_u10(values.val[0]), Sleef_coshf4_u10(values.val[1])),
      map(std::cosh)
    );
  }
  Vectorized<float> ceil() const {
    return map(std::ceil);
  }
  Vectorized<float> floor() const {
    return map(std::floor);
  }
  Vectorized<float> neg() const {
    return Vectorized<float>(
        vnegq_f32(values.val[0]),
        vnegq_f32(values.val[1]));
  }
  Vectorized<float> round() const {
    return map(std::round);
  }
  Vectorized<float> tan() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_tanf4_u10(values.val[0]), Sleef_tanf4_u10(values.val[1])),
      map(std::tan)
    );
  }
  Vectorized<float> tanh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_tanhf4_u10(values.val[0]), Sleef_tanhf4_u10(values.val[1])),
      map(std::tanh)
    );
  }
  Vectorized<float> trunc() const {
    float32x4_t r0 = vrndq_f32(values.val[0]);
    float32x4_t r1 = vrndq_f32(values.val[1]);
    return Vectorized<float>(r0, r1);
  }
  Vectorized<float> lgamma() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_lgammaf4_u10(values.val[0]), Sleef_lgammaf4_u10(values.val[1])),
      map(std::lgamma)
    );
  }
  Vectorized<float> sqrt() const {
    return Vectorized<float>(
        vsqrtq_f32(values.val[0]),
        vsqrtq_f32(values.val[1]));
  }
  Vectorized<float> reciprocal() const {
    auto r0 = vdivq_f32(vdupq_n_f32(1.0f), values.val[0]);
    auto r1 = vdivq_f32(vdupq_n_f32(1.0f), values.val[1]);
    return Vectorized<float>(r0, r1);
  }
  Vectorized<float> rsqrt() const {
    return this->sqrt().reciprocal();
  }
  Vectorized<float> pow(const Vectorized<float> &exp) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_powf4_u10(values.val[0], exp.values.val[0]),
                                 Sleef_powf4_u10(values.val[1], exp.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_exp[size()];
        store(tmp);
        exp.store(tmp_exp);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::pow(tmp[i], tmp_exp[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> operator==(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vceqq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vceqq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator!=(const Vectorized<float>& other) const {
    float32x4_t r0 = vreinterpretq_f32_u32(
        vmvnq_u32(vceqq_f32(values.val[0], other.values.val[0])));
    float32x4_t r1 = vreinterpretq_f32_u32(
        vmvnq_u32(vceqq_f32(values.val[1], other.values.val[1])));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator<(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcltq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcltq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator<=(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcleq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcleq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator>(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcgtq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcgtq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator>=(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcgeq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcgeq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> eq(const Vectorized<float>& other) const;
  Vectorized<float> ne(const Vectorized<float>& other) const;
  Vectorized<float> gt(const Vectorized<float>& other) const;
  Vectorized<float> ge(const Vectorized<float>& other) const;
  Vectorized<float> lt(const Vectorized<float>& other) const;
  Vectorized<float> le(const Vectorized<float>& other) const;
};

template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vaddq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vaddq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vsubq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vsubq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vmulq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vmulq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vdivq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vdivq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

// frac. Implemented here so that we can use the subtraction operator defined above.
inline Vectorized<float> Vectorized<float>::frac() const {
  return *this - this->trunc();
}
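
// Example: Vectorized<float>(2.5f).frac() yields {0.5f, ...} and
// Vectorized<float>(-2.5f).frac() yields {-0.5f, ...}, since trunc()
// rounds toward zero.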

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vmaxq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vmaxq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vminq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vminq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
  return minimum(max, a);
}

template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
  return maximum(min, a);
}

template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(vandq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(vandq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(vorrq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(vorrq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(veorq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(veorq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
  return (*this == other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
  return (*this != other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
  return (*this > other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
  return (*this >= other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
  return (*this < other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
  return (*this <= other) & Vectorized<float>(1.0f);
}
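
// Example: unlike operator==, which returns all-ones/all-zeros bit masks per
// lane, eq() and friends return numeric 1.0f/0.0f per lane, e.g.
//   Vectorized<float>(1.f).eq(Vectorized<float>(1.f))  // {1.f, 1.f, ..., 1.f}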

template <>
inline void convert(const float* src, int32_t* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
    vst1q_s32(dst + i + 4, vcvtq_s32_f32(vld1q_f32(src + i + 4)));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = static_cast<int32_t>(src[i]);
  }
}

template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
    vst1q_f32(dst + i + 4, vcvtq_f32_s32(vld1q_s32(src + i + 4)));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}
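
// Usage sketch (illustrative buffers): the float -> int32_t direction
// truncates toward zero, matching the static_cast in the scalar tail:
//   float src[16] = {/* ... */};
//   int32_t dst[16];
//   convert(src, dst, 16);  // dst[i] == static_cast<int32_t>(src[i])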

template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  // vfmaq_f32(c, a, b) computes c + a * b with a single fused multiply-add.
  float32x4_t r0 = vfmaq_f32(c.get_low(), a.get_low(), b.get_low());
  float32x4_t r1 = vfmaq_f32(c.get_high(), a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  // vfmsq_f32(c, a, b) would compute c - a * b, the negation of the a * b - c
  // that the generic fmsub in vec_base.h computes, so negate c and fuse instead.
  float32x4_t r0 = vfmaq_f32(vnegq_f32(c.get_low()), a.get_low(), b.get_low());
  float32x4_t r1 = vfmaq_f32(vnegq_f32(c.get_high()), a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}
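
// Example: fmadd(a, b, c) computes a * b + c per lane (arange() above is
// implemented in terms of it), and fmsub(a, b, c) computes a * b - c.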

#endif /* defined(__aarch64__) */

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace executorch