/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <executorch/kernels/optimized/vec/intrinsics.h>
#include <executorch/kernels/optimized/vec/vec_base.h>

#if defined(__aarch64__) && defined(ET_BUILD_ARM_VEC256_WITH_SLEEF)
#include <sleef.h>
#endif

// Sleef offers vectorized versions of some transcendentals
// such as sin, cos, tan, etc.
// However, for now we opt for the STL, since we are not yet
// building with Sleef for mobile.

namespace executorch {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// Right now this contains only the aarch64 implementation.
// aarch32 is not currently supported for the following two reasons:
// 1. Due to differences between the aarch32 and aarch64 ISAs, intrinsics
//    that work for aarch64 don't work for aarch32.
// 2. Android NDK r21 has problems compiling aarch32;
//    Clang segfaults.
//    https://github.com/android/ndk/issues/1248
//    https://bugs.llvm.org/show_bug.cgi?id=45824
// Most likely we will do aarch32 support with inline asm.
#if defined(__aarch64__)

#ifdef __BIG_ENDIAN__
#error "Big endian is not supported."
#endif

#if defined(ET_BUILD_ARM_VEC256_WITH_SLEEF)
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
#else
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
#endif

template<int index, bool mask_val>
struct BlendRegs {
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res);
};

template<int index>
struct BlendRegs<index, true>{
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
    return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index);
  }
};

template<int index>
struct BlendRegs<index, false>{
  static float32x4_t impl(
      const float32x4_t& a, const float32x4_t& b, float32x4_t& res) {
    return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index);
  }
};

template <> class Vectorized<float> {
private:
  float32x4x2_t values;
public:
  using value_type = float;
  using size_type = int;
  static constexpr size_type size() {
    return 8;
  }
  Vectorized() {}
  Vectorized(float32x4x2_t v) : values(v) {}
  Vectorized(float val) : values{vdupq_n_f32(val), vdupq_n_f32(val)} {}
  Vectorized(float val0, float val1, float val2, float val3,
             float val4, float val5, float val6, float val7) :
             values{val0, val1, val2, val3, val4, val5, val6, val7} {}
  Vectorized(float32x4_t val0, float32x4_t val1) : values{val0, val1} {}
  operator float32x4x2_t() const {
    return values;
  }
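  // blend() selects lanes from `a` and `b` according to the compile-time
  // `mask` parameter below: bit i of `mask` picks lane i from `b` when set
  // and from `a` when clear (bits 0..3 cover the low half, bits 4..7 the
  // high half). A minimal usage sketch; the variable names are illustrative
  // only:
  //
  //   Vectorized<float> a(1.0f);   // {1, 1, 1, 1, 1, 1, 1, 1}
  //   Vectorized<float> b(2.0f);   // {2, 2, 2, 2, 2, 2, 2, 2}
  //   auto r = Vectorized<float>::blend<0x0F>(a, b);
  //   // r == {2, 2, 2, 2, 1, 1, 1, 1}: lanes 0..3 come from b.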
  template <int64_t mask>
  static Vectorized<float> blend(const Vectorized<float>& a, const Vectorized<float>& b) {
    Vectorized<float> vec;
    // Low half: bits 0..3 of mask select lanes 0..3.
    vec.values.val[0] =
      BlendRegs<0, (mask & 0x01)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<1, (mask & 0x02)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<2, (mask & 0x04)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    vec.values.val[0] =
      BlendRegs<3, (mask & 0x08)!=0>::impl(
          a.values.val[0], b.values.val[0], vec.values.val[0]);
    // High half: bits 4..7 of mask select lanes 4..7.
    vec.values.val[1] =
      BlendRegs<0, (mask & 0x10)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<1, (mask & 0x20)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<2, (mask & 0x40)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    vec.values.val[1] =
      BlendRegs<3, (mask & 0x80)!=0>::impl(
          a.values.val[1], b.values.val[1], vec.values.val[1]);
    return vec;
  }
  static Vectorized<float> blendv(const Vectorized<float>& a, const Vectorized<float>& b,
      const Vectorized<float>& mask) {
    // TODO
    // NB: This requires that each value of the mask, viewed as a uint,
    // has either all bits set or all bits clear.
    // We perhaps need some kind of an assert?
    // But that will affect performance.
    Vectorized<float> vec(mask.values);
    vec.values.val[0] = vbslq_f32(
        vreinterpretq_u32_f32(vec.values.val[0]),
        b.values.val[0],
        a.values.val[0]);
    vec.values.val[1] = vbslq_f32(
        vreinterpretq_u32_f32(vec.values.val[1]),
        b.values.val[1],
        a.values.val[1]);
    return vec;
  }
  template<typename step_t>
  static Vectorized<float> arange(float base = 0.f, step_t step = static_cast<step_t>(1)) {
    const Vectorized<float> base_vec(base);
    const Vectorized<float> step_vec(step);
    const Vectorized<float> step_sizes(0, 1, 2, 3, 4, 5, 6, 7);
    return fmadd(step_sizes, step_vec, base_vec);
  }
  static Vectorized<float> set(const Vectorized<float>& a, const Vectorized<float>& b,
      int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
      {
        Vectorized<float> vec;
        static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0};
        vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
        vec.values.val[1] = a.values.val[1];
        vec.values.val[0] = vbslq_f32(
            vreinterpretq_u32_f32(vec.values.val[0]),
            b.values.val[0],
            a.values.val[0]);
        return vec;
      }
      case 2:
      {
        Vectorized<float> vec;
        static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
        vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
        vec.values.val[1] = a.values.val[1];
        vec.values.val[0] = vbslq_f32(
            vreinterpretq_u32_f32(vec.values.val[0]),
            b.values.val[0],
            a.values.val[0]);
        return vec;
      }
      case 3:
      {
        Vectorized<float> vec;
        static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
        vec.values.val[0] = vreinterpretq_f32_u32(mask_low);
        vec.values.val[1] = a.values.val[1];
        vec.values.val[0] = vbslq_f32(
            vreinterpretq_u32_f32(vec.values.val[0]),
            b.values.val[0],
            a.values.val[0]);
        return vec;
      }
      case 4:
        return Vectorized<float>(b.values.val[0], a.values.val[1]);
      case 5:
      {
        Vectorized<float> vec;
        static uint32x4_t mask_high = {0xFFFFFFFF, 0x0, 0x0, 0x0};
        vec.values.val[0] = b.values.val[0];
        vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
        vec.values.val[1] = vbslq_f32(
            vreinterpretq_u32_f32(vec.values.val[1]),
            b.values.val[1],
            a.values.val[1]);
        return vec;
      }
      case 6:
      {
        Vectorized<float> vec;
        static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0};
        vec.values.val[0] = b.values.val[0];
        vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
        vec.values.val[1] = vbslq_f32(
            vreinterpretq_u32_f32(vec.values.val[1]),
            b.values.val[1],
            a.values.val[1]);
        return vec;
      }
      case 7:
      {
        Vectorized<float> vec;
        static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0};
        vec.values.val[0] = b.values.val[0];
        vec.values.val[1] = vreinterpretq_f32_u32(mask_high);
        vec.values.val[1] = vbslq_f32(
            vreinterpretq_u32_f32(vec.values.val[1]),
            b.values.val[1],
            a.values.val[1]);
        return vec;
      }
    }
    return b;
  }
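  // loadu()/store() take an optional `count` so callers can handle tails
  // shorter than a full vector: the remaining lanes are zero-filled on load
  // and simply not written on store. A small sketch; the buffer names are
  // illustrative and not part of this header:
  //
  //   float src[3] = {1.f, 2.f, 3.f};
  //   float dst[3] = {};
  //   auto v = Vectorized<float>::loadu(src, 3);  // {1, 2, 3, 0, 0, 0, 0, 0}
  //   v.store(dst, 3);                            // writes only dst[0..2]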
  static Vectorized<float> loadu(const void* ptr, int64_t count = size()) {
    if (count == size()) {
      return vld1q_f32_x2(reinterpret_cast<const float*>(ptr));
    }
    else if (count == (size() >> 1)) {
      Vectorized<float> res;
      res.values.val[0] = vld1q_f32(reinterpret_cast<const float*>(ptr));
      res.values.val[1] = vdupq_n_f32(0.f);
      return res;
    }
    else {
      __at_align__ float tmp_values[size()];
      for (size_t i = 0; i < size(); ++i) {
        tmp_values[i] = 0.0;
      }
      std::memcpy(
          tmp_values,
          reinterpret_cast<const float*>(ptr),
          count * sizeof(float));
      return vld1q_f32_x2(reinterpret_cast<const float*>(tmp_values));
    }
  }
  void store(void* ptr, int64_t count = size()) const {
    if (count == size()) {
      vst1q_f32_x2(reinterpret_cast<float*>(ptr), values);
    }
    else if (count == (size() >> 1)) {
      vst1q_f32(reinterpret_cast<float*>(ptr), values.val[0]);
    }
    else {
      float tmp_values[size()];
      vst1q_f32_x2(reinterpret_cast<float*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(float));
    }
  }
  inline const float32x4_t& get_low() const {
    return values.val[0];
  }
  inline float32x4_t& get_low() {
    return values.val[0];
  }
  inline const float32x4_t& get_high() const {
    return values.val[1];
  }
  inline float32x4_t& get_high() {
    return values.val[1];
  }
  // Very slow implementation of indexing.
  // Only required because vec256_qint refers to this.
  // Once we specialize that implementation for ARM
  // this should be removed. TODO (kimishpatel)
  float operator[](int idx) const {
    __at_align__ float tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  float operator[](int idx) {
    __at_align__ float tmp[size()];
    store(tmp);
    return tmp[idx];
  }
  // For a boolean version, checks such as "is any lane set" or
  // "are all lanes zero" can be done faster in a different way.
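  // zero_mask() returns an integer bitmask with bit i set when lane i equals
  // 0.f. Combined with the comparison operators (which produce all-ones /
  // all-zeros lanes), it can answer questions like "did any lane fail the
  // test?". Illustrative sketch; the names are not part of this header:
  //
  //   Vectorized<float> x(0.f, 1.f, 2.f, 0.f, 5.f, 6.f, 7.f, 8.f);
  //   int m = x.zero_mask();     // 0b00001001: lanes 0 and 3 are zero
  //   bool any_zero = (m != 0);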
  int zero_mask() const {
    __at_align__ float tmp[size()];
    store(tmp);
    int mask = 0;
    for (size_t i = 0; i < size(); ++i) {
      if (tmp[i] == 0.f) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
  Vectorized<float> isnan() const {
    __at_align__ float tmp[size()];
    __at_align__ float res[size()];
    store(tmp);
    for (size_t i = 0; i < size(); ++i) {
      if (std::isnan(tmp[i])) {
        std::memset(static_cast<void*>(&res[i]), 0xFF, sizeof(float));
      } else {
        std::memset(static_cast<void*>(&res[i]), 0, sizeof(float));
      }
    }
    return loadu(res);
  }
  Vectorized<float> map(float (*const f)(float)) const {
    __at_align__ float tmp[size()];
    store(tmp);
    for (size_t i = 0; i < size(); ++i) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  Vectorized<float> abs() const {
    return Vectorized<float>(vabsq_f32(values.val[0]), vabsq_f32(values.val[1]));
  }
  Vectorized<float> acos() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_acosf4_u10(values.val[0]), Sleef_acosf4_u10(values.val[1])),
      map(std::acos)
    );
  }
  Vectorized<float> asin() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_asinf4_u10(values.val[0]), Sleef_asinf4_u10(values.val[1])),
      map(std::asin)
    );
  }
  Vectorized<float> atan() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_atanf4_u10(values.val[0]), Sleef_atanf4_u10(values.val[1])),
      map(std::atan)
    );
  }
  Vectorized<float> atan2(const Vectorized<float> &exp) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_atan2f4_u10(values.val[0], exp.values.val[0]),
                                 Sleef_atan2f4_u10(values.val[1], exp.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_exp[size()];
        store(tmp);
        exp.store(tmp_exp);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::atan2(tmp[i], tmp_exp[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> copysign(const Vectorized<float> &sign) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_copysignf4(values.val[0], sign.values.val[0]),
                                 Sleef_copysignf4(values.val[1], sign.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_sign[size()];
        store(tmp);
        sign.store(tmp_sign);
        for (size_t i = 0; i < size(); i++) {
          tmp[i] = std::copysign(tmp[i], tmp_sign[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> erf() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_erff4_u10(values.val[0]), Sleef_erff4_u10(values.val[1])),
      map(std::erf)
    );
  }
  Vectorized<float> erfc() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_erfcf4_u15(values.val[0]), Sleef_erfcf4_u15(values.val[1])),
      map(std::erfc)
    );
  }
  Vectorized<float> exp() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_expf4_u10(values.val[0]), Sleef_expf4_u10(values.val[1])),
      map(std::exp)
    );
  }
  Vectorized<float> exp2() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_exp2f4_u10(values.val[0]), Sleef_exp2f4_u10(values.val[1])),
      map(std::exp2)
    );
  }
  Vectorized<float> expm1() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_expm1f4_u10(values.val[0]), Sleef_expm1f4_u10(values.val[1])),
      map(std::expm1)
    );
  }
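  // The unary math functions above either call the corresponding 4-lane
  // Sleef routine on each half (when ET_BUILD_ARM_VEC256_WITH_SLEEF is
  // defined) or fall back to map(), which round-trips through a scalar
  // buffer. A usage sketch under illustrative names:
  //
  //   __at_align__ float buf[Vectorized<float>::size()] =
  //       {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
  //   auto y = Vectorized<float>::loadu(buf).exp();  // elementwise expf
  //   y.store(buf);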
  Vectorized<float> fmod(const Vectorized<float>& q) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_fmodf4(values.val[0], q.values.val[0]),
                                 Sleef_fmodf4(values.val[1], q.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_q[size()];
        store(tmp);
        q.store(tmp_q);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::fmod(tmp[i], tmp_q[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> hypot(const Vectorized<float> &b) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_hypotf4_u05(values.val[0], b.values.val[0]),
                                 Sleef_hypotf4_u05(values.val[1], b.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::hypot(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> log() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_logf4_u10(values.val[0]), Sleef_logf4_u10(values.val[1])),
      map(std::log)
    );
  }
  Vectorized<float> log10() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log10f4_u10(values.val[0]), Sleef_log10f4_u10(values.val[1])),
      map(std::log10)
    );
  }
  Vectorized<float> log1p() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log1pf4_u10(values.val[0]), Sleef_log1pf4_u10(values.val[1])),
      map(std::log1p)
    );
  }
  Vectorized<float> log2() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_log2f4_u10(values.val[0]), Sleef_log2f4_u10(values.val[1])),
      map(std::log2)
    );
  }
  Vectorized<float> nextafter(const Vectorized<float> &b) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_nextafterf4(values.val[0], b.values.val[0]),
                                 Sleef_nextafterf4(values.val[1], b.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_b[size()];
        store(tmp);
        b.store(tmp_b);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::nextafter(tmp[i], tmp_b[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> frac() const;
  Vectorized<float> sin() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_sinf4_u10(values.val[0]), Sleef_sinf4_u10(values.val[1])),
      map(std::sin)
    );
  }
  Vectorized<float> sinh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_sinhf4_u10(values.val[0]), Sleef_sinhf4_u10(values.val[1])),
      map(std::sinh)
    );
  }
  Vectorized<float> cos() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_cosf4_u10(values.val[0]), Sleef_cosf4_u10(values.val[1])),
      map(std::cos)
    );
  }
  Vectorized<float> cosh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_coshf4_u10(values.val[0]), Sleef_coshf4_u10(values.val[1])),
      map(std::cosh)
    );
  }
  Vectorized<float> ceil() const {
    return map(std::ceil);
  }
  Vectorized<float> floor() const {
    return map(std::floor);
  }
  Vectorized<float> neg() const {
    return Vectorized<float>(
        vnegq_f32(values.val[0]),
        vnegq_f32(values.val[1]));
  }
  Vectorized<float> round() const {
    return map(std::round);
  }
  Vectorized<float> tan() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_tanf4_u10(values.val[0]), Sleef_tanf4_u10(values.val[1])),
      map(std::tan)
    );
  }
  Vectorized<float> tanh() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_tanhf4_u10(values.val[0]), Sleef_tanhf4_u10(values.val[1])),
      map(std::tanh)
    );
  }
  Vectorized<float> trunc() const {
    float32x4_t r0 = vrndq_f32(values.val[0]);
    float32x4_t r1 = vrndq_f32(values.val[1]);
    return Vectorized<float>(r0, r1);
  }
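  // trunc() uses vrndq_f32 (FRINTZ), i.e. rounding toward zero, while frac()
  // (defined after the class) is the remainder x - trunc(x). For example:
  //
  //   Vectorized<float> x(-1.5f);
  //   x.trunc();  // every lane is -1.0f
  //   x.frac();   // every lane is -0.5f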
  Vectorized<float> lgamma() const {
    return USE_SLEEF(
      Vectorized<float>(Sleef_lgammaf4_u10(values.val[0]), Sleef_lgammaf4_u10(values.val[1])),
      map(std::lgamma)
    );
  }
  Vectorized<float> sqrt() const {
    return Vectorized<float>(
        vsqrtq_f32(values.val[0]),
        vsqrtq_f32(values.val[1]));
  }
  Vectorized<float> reciprocal() const {
    auto r0 = vdivq_f32(vdupq_n_f32(1.0f), values.val[0]);
    auto r1 = vdivq_f32(vdupq_n_f32(1.0f), values.val[1]);
    return Vectorized<float>(r0, r1);
  }
  Vectorized<float> rsqrt() const {
    return this->sqrt().reciprocal();
  }
  Vectorized<float> pow(const Vectorized<float> &exp) const {
    USE_SLEEF(
      {
        return Vectorized<float>(Sleef_powf4_u10(values.val[0], exp.values.val[0]),
                                 Sleef_powf4_u10(values.val[1], exp.values.val[1]));
      },
      {
        __at_align__ float tmp[size()];
        __at_align__ float tmp_exp[size()];
        store(tmp);
        exp.store(tmp_exp);
        for (size_t i = 0; i < size(); ++i) {
          tmp[i] = std::pow(tmp[i], tmp_exp[i]);
        }
        return loadu(tmp);
      }
    )
  }
  Vectorized<float> operator==(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vceqq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vceqq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator!=(const Vectorized<float>& other) const {
    float32x4_t r0 = vreinterpretq_f32_u32(
        vmvnq_u32(vceqq_f32(values.val[0], other.values.val[0])));
    float32x4_t r1 = vreinterpretq_f32_u32(
        vmvnq_u32(vceqq_f32(values.val[1], other.values.val[1])));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator<(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcltq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcltq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator<=(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcleq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcleq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator>(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcgtq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcgtq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> operator>=(const Vectorized<float>& other) const {
    float32x4_t r0 =
      vreinterpretq_f32_u32(vcgeq_f32(values.val[0], other.values.val[0]));
    float32x4_t r1 =
      vreinterpretq_f32_u32(vcgeq_f32(values.val[1], other.values.val[1]));
    return Vectorized<float>(r0, r1);
  }

  Vectorized<float> eq(const Vectorized<float>& other) const;
  Vectorized<float> ne(const Vectorized<float>& other) const;
  Vectorized<float> gt(const Vectorized<float>& other) const;
  Vectorized<float> ge(const Vectorized<float>& other) const;
  Vectorized<float> lt(const Vectorized<float>& other) const;
  Vectorized<float> le(const Vectorized<float>& other) const;
};

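// The comparison operators above return lane-wise masks (all bits set for
// true, all bits clear for false) reinterpreted as floats, which is the form
// blendv() expects. A sketch with illustrative names:
//
//   Vectorized<float> x(0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f);
//   Vectorized<float> limit(3.f);
//   auto mask = x > limit;                               // lanes 4..7 true
//   auto y = Vectorized<float>::blendv(x, limit, mask);  // {0,1,2,3,3,3,3,3}
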
template <>
Vectorized<float> inline operator+(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vaddq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vaddq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator-(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vsubq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vsubq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator*(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vmulq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vmulq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator/(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vdivq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vdivq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

// frac. Implement this here so we can use subtraction
inline Vectorized<float> Vectorized<float>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline maximum(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vmaxq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vmaxq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<float> inline minimum(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vminq_f32(a.get_low(), b.get_low());
  float32x4_t r1 = vminq_f32(a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline clamp(const Vectorized<float>& a, const Vectorized<float>& min, const Vectorized<float>& max) {
  return minimum(max, maximum(min, a));
}

template <>
Vectorized<float> inline clamp_max(const Vectorized<float>& a, const Vectorized<float>& max) {
  return minimum(max, a);
}

template <>
Vectorized<float> inline clamp_min(const Vectorized<float>& a, const Vectorized<float>& min) {
  return maximum(min, a);
}

template <>
Vectorized<float> inline operator&(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(vandq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(vandq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator|(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(vorrq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(vorrq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline operator^(const Vectorized<float>& a, const Vectorized<float>& b) {
  float32x4_t r0 = vreinterpretq_f32_u32(veorq_u32(
      vreinterpretq_u32_f32(a.get_low()),
      vreinterpretq_u32_f32(b.get_low())));
  float32x4_t r1 = vreinterpretq_f32_u32(veorq_u32(
      vreinterpretq_u32_f32(a.get_high()),
      vreinterpretq_u32_f32(b.get_high())));
  return Vectorized<float>(r0, r1);
}

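// eq()/ne()/gt()/ge()/lt()/le() below return a numeric 0.0f / 1.0f per lane
// (the comparison mask is bitwise AND-ed with 1.0f), unlike operator== etc.,
// which return all-bits-set masks. This is convenient when the result feeds
// further arithmetic. For example, under illustrative names:
//
//   Vectorized<float> a(1.f), b(1.f);
//   auto bits = (a == b);  // every lane has all bits set (a NaN pattern)
//   auto ones = a.eq(b);   // every lane is exactly 1.0f
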
inline Vectorized<float> Vectorized<float>::eq(const Vectorized<float>& other) const {
  return (*this == other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ne(const Vectorized<float>& other) const {
  return (*this != other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::gt(const Vectorized<float>& other) const {
  return (*this > other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::ge(const Vectorized<float>& other) const {
  return (*this >= other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::lt(const Vectorized<float>& other) const {
  return (*this < other) & Vectorized<float>(1.0f);
}

inline Vectorized<float> Vectorized<float>::le(const Vectorized<float>& other) const {
  return (*this <= other) & Vectorized<float>(1.0f);
}

template <>
inline void convert(const float* src, int32_t* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i)));
    vst1q_s32(dst + i + 4, vcvtq_s32_f32(vld1q_f32(src + i + 4)));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = static_cast<int32_t>(src[i]);
  }
}

template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i)));
    vst1q_f32(dst + i + 4, vcvtq_f32_s32(vld1q_s32(src + i + 4)));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}

template <>
Vectorized<float> inline fmadd(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  float32x4_t r0 = vfmaq_f32(c.get_low(), a.get_low(), b.get_low());
  float32x4_t r1 = vfmaq_f32(c.get_high(), a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

template <>
Vectorized<float> inline fmsub(const Vectorized<float>& a, const Vectorized<float>& b, const Vectorized<float>& c) {
  // fmsub computes a * b - c. NEON's vfmsq_f32(x, y, z) computes x - y * z,
  // which would yield c - a * b here, so instead negate c and use a fused
  // multiply-add: (-c) + a * b.
  float32x4_t r0 = vfmaq_f32(vnegq_f32(c.get_low()), a.get_low(), b.get_low());
  float32x4_t r1 = vfmaq_f32(vnegq_f32(c.get_high()), a.get_high(), b.get_high());
  return Vectorized<float>(r0, r1);
}

#endif /* defined(__aarch64__) */

}}}
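
// Typical usage of this header is a strip-mined loop over float buffers, with
// Vectorized<float>::size() elements handled per iteration and loadu()/store()
// given an explicit count for the tail. A sketch under illustrative names
// (scale_add, x, out, etc. are not part of this header):
//
//   using executorch::vec::Vectorized;
//   void scale_add(const float* x, float s, float b, float* out, int64_t n) {
//     const Vectorized<float> vs(s), vb(b);
//     constexpr int64_t kStep = Vectorized<float>::size();
//     int64_t i = 0;
//     for (; i + kStep <= n; i += kStep) {
//       // out[i..i+7] = x[i..i+7] * s + b
//       executorch::vec::fmadd(Vectorized<float>::loadu(x + i), vs, vb)
//           .store(out + i);
//     }
//     if (i < n) {  // tail of fewer than kStep elements
//       executorch::vec::fmadd(Vectorized<float>::loadu(x + i, n - i), vs, vb)
//           .store(out + i, n - i);
//     }
//   }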