1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner ([email protected]) 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_PACKET_MATH_AVX_H 11 #define EIGEN_PACKET_MATH_AVX_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 #ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 18 #define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8 19 #endif 20 21 #if !defined(EIGEN_VECTORIZE_AVX512) && !defined(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS) 22 #define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16 23 #endif 24 25 #ifdef EIGEN_VECTORIZE_FMA 26 #ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD 27 #define EIGEN_HAS_SINGLE_INSTRUCTION_MADD 28 #endif 29 #endif 30 31 typedef __m256 Packet8f; 32 typedef __m256i Packet8i; 33 typedef __m256d Packet4d; 34 typedef eigen_packet_wrapper<__m128i, 2> Packet8h; 35 typedef eigen_packet_wrapper<__m128i, 3> Packet8bf; 36 37 template<> struct is_arithmetic<__m256> { enum { value = true }; }; 38 template<> struct is_arithmetic<__m256i> { enum { value = true }; }; 39 template<> struct is_arithmetic<__m256d> { enum { value = true }; }; 40 template<> struct is_arithmetic<Packet8h> { enum { value = true }; }; 41 template<> struct is_arithmetic<Packet8bf> { enum { value = true }; }; 42 43 #define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \ 44 const Packet8f p8f_##NAME = pset1<Packet8f>(X) 45 46 #define _EIGEN_DECLARE_CONST_Packet4d(NAME,X) \ 47 const Packet4d p4d_##NAME = pset1<Packet4d>(X) 48 49 #define _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(NAME,X) \ 50 const Packet8f p8f_##NAME = _mm256_castsi256_ps(pset1<Packet8i>(X)) 51 52 #define _EIGEN_DECLARE_CONST_Packet8i(NAME,X) \ 53 const Packet8i p8i_##NAME = pset1<Packet8i>(X) 54 55 // Use the packet_traits defined in AVX512/PacketMath.h instead if we're going 56 // to leverage AVX512 instructions. 57 #ifndef EIGEN_VECTORIZE_AVX512 58 template<> struct packet_traits<float> : default_packet_traits 59 { 60 typedef Packet8f type; 61 typedef Packet4f half; 62 enum { 63 Vectorizable = 1, 64 AlignedOnScalar = 1, 65 size = 8, 66 HasHalfPacket = 1, 67 68 HasCmp = 1, 69 HasDiv = 1, 70 HasSin = EIGEN_FAST_MATH, 71 HasCos = EIGEN_FAST_MATH, 72 HasLog = 1, 73 HasLog1p = 1, 74 HasExpm1 = 1, 75 HasExp = 1, 76 HasNdtri = 1, 77 HasBessel = 1, 78 HasSqrt = 1, 79 HasRsqrt = 1, 80 HasTanh = EIGEN_FAST_MATH, 81 HasErf = EIGEN_FAST_MATH, 82 HasBlend = 1, 83 HasRound = 1, 84 HasFloor = 1, 85 HasCeil = 1, 86 HasRint = 1 87 }; 88 }; 89 template<> struct packet_traits<double> : default_packet_traits 90 { 91 typedef Packet4d type; 92 typedef Packet2d half; 93 enum { 94 Vectorizable = 1, 95 AlignedOnScalar = 1, 96 size=4, 97 HasHalfPacket = 1, 98 99 HasCmp = 1, 100 HasDiv = 1, 101 HasLog = 1, 102 HasExp = 1, 103 HasSqrt = 1, 104 HasRsqrt = 1, 105 HasBlend = 1, 106 HasRound = 1, 107 HasFloor = 1, 108 HasCeil = 1, 109 HasRint = 1 110 }; 111 }; 112 113 template <> 114 struct packet_traits<Eigen::half> : default_packet_traits { 115 typedef Packet8h type; 116 // There is no half-size packet for Packet8h. 117 typedef Packet8h half; 118 enum { 119 Vectorizable = 1, 120 AlignedOnScalar = 1, 121 size = 8, 122 HasHalfPacket = 0, 123 124 HasCmp = 1, 125 HasAdd = 1, 126 HasSub = 1, 127 HasMul = 1, 128 HasDiv = 1, 129 HasSin = EIGEN_FAST_MATH, 130 HasCos = EIGEN_FAST_MATH, 131 HasNegate = 1, 132 HasAbs = 1, 133 HasAbs2 = 0, 134 HasMin = 1, 135 HasMax = 1, 136 HasConj = 1, 137 HasSetLinear = 0, 138 HasLog = 1, 139 HasLog1p = 1, 140 HasExpm1 = 1, 141 HasExp = 1, 142 HasSqrt = 1, 143 HasRsqrt = 1, 144 HasTanh = EIGEN_FAST_MATH, 145 HasErf = EIGEN_FAST_MATH, 146 HasBlend = 0, 147 HasRound = 1, 148 HasFloor = 1, 149 HasCeil = 1, 150 HasRint = 1, 151 HasBessel = 1, 152 HasNdtri = 1 153 }; 154 }; 155 156 template <> 157 struct packet_traits<bfloat16> : default_packet_traits { 158 typedef Packet8bf type; 159 // There is no half-size packet for current Packet8bf. 160 // TODO: support as SSE path. 161 typedef Packet8bf half; 162 enum { 163 Vectorizable = 1, 164 AlignedOnScalar = 1, 165 size = 8, 166 HasHalfPacket = 0, 167 168 HasCmp = 1, 169 HasAdd = 1, 170 HasSub = 1, 171 HasMul = 1, 172 HasDiv = 1, 173 HasSin = EIGEN_FAST_MATH, 174 HasCos = EIGEN_FAST_MATH, 175 HasNegate = 1, 176 HasAbs = 1, 177 HasAbs2 = 0, 178 HasMin = 1, 179 HasMax = 1, 180 HasConj = 1, 181 HasSetLinear = 0, 182 HasLog = 1, 183 HasLog1p = 1, 184 HasExpm1 = 1, 185 HasExp = 1, 186 HasSqrt = 1, 187 HasRsqrt = 1, 188 HasTanh = EIGEN_FAST_MATH, 189 HasErf = EIGEN_FAST_MATH, 190 HasBlend = 0, 191 HasRound = 1, 192 HasFloor = 1, 193 HasCeil = 1, 194 HasRint = 1, 195 HasBessel = 1, 196 HasNdtri = 1 197 }; 198 }; 199 #endif 200 201 template<> struct scalar_div_cost<float,true> { enum { value = 14 }; }; 202 template<> struct scalar_div_cost<double,true> { enum { value = 16 }; }; 203 204 /* Proper support for integers is only provided by AVX2. In the meantime, we'll 205 use SSE instructions and packets to deal with integers. 206 template<> struct packet_traits<int> : default_packet_traits 207 { 208 typedef Packet8i type; 209 enum { 210 Vectorizable = 1, 211 AlignedOnScalar = 1, 212 size=8 213 }; 214 }; 215 */ 216 217 template<> struct unpacket_traits<Packet8f> { 218 typedef float type; 219 typedef Packet4f half; 220 typedef Packet8i integer_packet; 221 typedef uint8_t mask_t; 222 enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true}; 223 }; 224 template<> struct unpacket_traits<Packet4d> { 225 typedef double type; 226 typedef Packet2d half; 227 enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; 228 }; 229 template<> struct unpacket_traits<Packet8i> { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; }; 230 template<> struct unpacket_traits<Packet8bf> { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; 231 232 // Helper function for bit packing snippet of low precision comparison. 233 // It packs the flags from 16x16 to 8x16. 234 EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) { 235 return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0), 236 _mm256_extractf128_si256(_mm256_castps_si256(rf), 1)); 237 } 238 239 240 template<> EIGEN_STRONG_INLINE Packet8f pset1<Packet8f>(const float& from) { return _mm256_set1_ps(from); } 241 template<> EIGEN_STRONG_INLINE Packet4d pset1<Packet4d>(const double& from) { return _mm256_set1_pd(from); } 242 template<> EIGEN_STRONG_INLINE Packet8i pset1<Packet8i>(const int& from) { return _mm256_set1_epi32(from); } 243 244 template<> EIGEN_STRONG_INLINE Packet8f pset1frombits<Packet8f>(unsigned int from) { return _mm256_castsi256_ps(pset1<Packet8i>(from)); } 245 template<> EIGEN_STRONG_INLINE Packet4d pset1frombits<Packet4d>(uint64_t from) { return _mm256_castsi256_pd(_mm256_set1_epi64x(from)); } 246 247 template<> EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f& /*a*/) { return _mm256_setzero_ps(); } 248 template<> EIGEN_STRONG_INLINE Packet4d pzero(const Packet4d& /*a*/) { return _mm256_setzero_pd(); } 249 template<> EIGEN_STRONG_INLINE Packet8i pzero(const Packet8i& /*a*/) { return _mm256_setzero_si256(); } 250 251 252 template<> EIGEN_STRONG_INLINE Packet8f peven_mask(const Packet8f& /*a*/) { return _mm256_castsi256_ps(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); } 253 template<> EIGEN_STRONG_INLINE Packet8i peven_mask(const Packet8i& /*a*/) { return _mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1); } 254 template<> EIGEN_STRONG_INLINE Packet4d peven_mask(const Packet4d& /*a*/) { return _mm256_castsi256_pd(_mm256_set_epi32(0, 0, -1, -1, 0, 0, -1, -1)); } 255 256 template<> EIGEN_STRONG_INLINE Packet8f pload1<Packet8f>(const float* from) { return _mm256_broadcast_ss(from); } 257 template<> EIGEN_STRONG_INLINE Packet4d pload1<Packet4d>(const double* from) { return _mm256_broadcast_sd(from); } 258 259 template<> EIGEN_STRONG_INLINE Packet8f plset<Packet8f>(const float& a) { return _mm256_add_ps(_mm256_set1_ps(a), _mm256_set_ps(7.0,6.0,5.0,4.0,3.0,2.0,1.0,0.0)); } 260 template<> EIGEN_STRONG_INLINE Packet4d plset<Packet4d>(const double& a) { return _mm256_add_pd(_mm256_set1_pd(a), _mm256_set_pd(3.0,2.0,1.0,0.0)); } 261 262 template<> EIGEN_STRONG_INLINE Packet8f padd<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); } 263 template<> EIGEN_STRONG_INLINE Packet4d padd<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); } 264 template<> EIGEN_STRONG_INLINE Packet8i padd<Packet8i>(const Packet8i& a, const Packet8i& b) { 265 #ifdef EIGEN_VECTORIZE_AVX2 266 return _mm256_add_epi32(a,b); 267 #else 268 __m128i lo = _mm_add_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); 269 __m128i hi = _mm_add_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); 270 return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); 271 #endif 272 } 273 274 template<> EIGEN_STRONG_INLINE Packet8f psub<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_sub_ps(a,b); } 275 template<> EIGEN_STRONG_INLINE Packet4d psub<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_sub_pd(a,b); } 276 template<> EIGEN_STRONG_INLINE Packet8i psub<Packet8i>(const Packet8i& a, const Packet8i& b) { 277 #ifdef EIGEN_VECTORIZE_AVX2 278 return _mm256_sub_epi32(a,b); 279 #else 280 __m128i lo = _mm_sub_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); 281 __m128i hi = _mm_sub_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); 282 return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); 283 #endif 284 } 285 286 template<> EIGEN_STRONG_INLINE Packet8f pnegate(const Packet8f& a) 287 { 288 return _mm256_sub_ps(_mm256_set1_ps(0.0),a); 289 } 290 template<> EIGEN_STRONG_INLINE Packet4d pnegate(const Packet4d& a) 291 { 292 return _mm256_sub_pd(_mm256_set1_pd(0.0),a); 293 } 294 295 template<> EIGEN_STRONG_INLINE Packet8f pconj(const Packet8f& a) { return a; } 296 template<> EIGEN_STRONG_INLINE Packet4d pconj(const Packet4d& a) { return a; } 297 template<> EIGEN_STRONG_INLINE Packet8i pconj(const Packet8i& a) { return a; } 298 299 template<> EIGEN_STRONG_INLINE Packet8f pmul<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_mul_ps(a,b); } 300 template<> EIGEN_STRONG_INLINE Packet4d pmul<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_mul_pd(a,b); } 301 template<> EIGEN_STRONG_INLINE Packet8i pmul<Packet8i>(const Packet8i& a, const Packet8i& b) { 302 #ifdef EIGEN_VECTORIZE_AVX2 303 return _mm256_mullo_epi32(a,b); 304 #else 305 const __m128i lo = _mm_mullo_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); 306 const __m128i hi = _mm_mullo_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); 307 return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); 308 #endif 309 } 310 311 template<> EIGEN_STRONG_INLINE Packet8f pdiv<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_div_ps(a,b); } 312 template<> EIGEN_STRONG_INLINE Packet4d pdiv<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_div_pd(a,b); } 313 template<> EIGEN_STRONG_INLINE Packet8i pdiv<Packet8i>(const Packet8i& /*a*/, const Packet8i& /*b*/) 314 { eigen_assert(false && "packet integer division are not supported by AVX"); 315 return pset1<Packet8i>(0); 316 } 317 318 #ifdef EIGEN_VECTORIZE_FMA 319 template<> EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) { 320 #if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) ) 321 // Clang stupidly generates a vfmadd213ps instruction plus some vmovaps on registers, 322 // and even register spilling with clang>=6.0 (bug 1637). 323 // Gcc stupidly generates a vfmadd132ps instruction. 324 // So let's enforce it to generate a vfmadd231ps instruction since the most common use 325 // case is to accumulate the result of the product. 326 Packet8f res = c; 327 __asm__("vfmadd231ps %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b)); 328 return res; 329 #else 330 return _mm256_fmadd_ps(a,b,c); 331 #endif 332 } 333 template<> EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) { 334 #if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) ) 335 // see above 336 Packet4d res = c; 337 __asm__("vfmadd231pd %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b)); 338 return res; 339 #else 340 return _mm256_fmadd_pd(a,b,c); 341 #endif 342 } 343 #endif 344 345 template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); } 346 template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); } 347 template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); } 348 template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); } 349 350 template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); } 351 template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); } 352 template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt_or_nan(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a, b, _CMP_NGE_UQ); } 353 template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); } 354 355 356 template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8i& b) { 357 #ifdef EIGEN_VECTORIZE_AVX2 358 return _mm256_cmpeq_epi32(a,b); 359 #else 360 __m128i lo = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); 361 __m128i hi = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); 362 return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); 363 #endif 364 } 365 366 template<> EIGEN_STRONG_INLINE Packet8f pmin<Packet8f>(const Packet8f& a, const Packet8f& b) { 367 #if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 368 // There appears to be a bug in GCC, by which the optimizer may flip 369 // the argument order in calls to _mm_min_ps/_mm_max_ps, so we have to 370 // resort to inline ASM here. This is supposed to be fixed in gcc6.3, 371 // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867 372 Packet8f res; 373 asm("vminps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); 374 return res; 375 #else 376 // Arguments are swapped to match NaN propagation behavior of std::min. 377 return _mm256_min_ps(b,a); 378 #endif 379 } 380 template<> EIGEN_STRONG_INLINE Packet4d pmin<Packet4d>(const Packet4d& a, const Packet4d& b) { 381 #if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 382 // See pmin above 383 Packet4d res; 384 asm("vminpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); 385 return res; 386 #else 387 // Arguments are swapped to match NaN propagation behavior of std::min. 388 return _mm256_min_pd(b,a); 389 #endif 390 } 391 392 template<> EIGEN_STRONG_INLINE Packet8f pmax<Packet8f>(const Packet8f& a, const Packet8f& b) { 393 #if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 394 // See pmin above 395 Packet8f res; 396 asm("vmaxps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); 397 return res; 398 #else 399 // Arguments are swapped to match NaN propagation behavior of std::max. 400 return _mm256_max_ps(b,a); 401 #endif 402 } 403 template<> EIGEN_STRONG_INLINE Packet4d pmax<Packet4d>(const Packet4d& a, const Packet4d& b) { 404 #if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 405 // See pmin above 406 Packet4d res; 407 asm("vmaxpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); 408 return res; 409 #else 410 // Arguments are swapped to match NaN propagation behavior of std::max. 411 return _mm256_max_pd(b,a); 412 #endif 413 } 414 415 // Add specializations for min/max with prescribed NaN progation. 416 template<> 417 EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) { 418 return pminmax_propagate_numbers(a, b, pmin<Packet8f>); 419 } 420 template<> 421 EIGEN_STRONG_INLINE Packet4d pmin<PropagateNumbers, Packet4d>(const Packet4d& a, const Packet4d& b) { 422 return pminmax_propagate_numbers(a, b, pmin<Packet4d>); 423 } 424 template<> 425 EIGEN_STRONG_INLINE Packet8f pmax<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) { 426 return pminmax_propagate_numbers(a, b, pmax<Packet8f>); 427 } 428 template<> 429 EIGEN_STRONG_INLINE Packet4d pmax<PropagateNumbers, Packet4d>(const Packet4d& a, const Packet4d& b) { 430 return pminmax_propagate_numbers(a, b, pmax<Packet4d>); 431 } 432 template<> 433 EIGEN_STRONG_INLINE Packet8f pmin<PropagateNaN, Packet8f>(const Packet8f& a, const Packet8f& b) { 434 return pminmax_propagate_nan(a, b, pmin<Packet8f>); 435 } 436 template<> 437 EIGEN_STRONG_INLINE Packet4d pmin<PropagateNaN, Packet4d>(const Packet4d& a, const Packet4d& b) { 438 return pminmax_propagate_nan(a, b, pmin<Packet4d>); 439 } 440 template<> 441 EIGEN_STRONG_INLINE Packet8f pmax<PropagateNaN, Packet8f>(const Packet8f& a, const Packet8f& b) { 442 return pminmax_propagate_nan(a, b, pmax<Packet8f>); 443 } 444 template<> 445 EIGEN_STRONG_INLINE Packet4d pmax<PropagateNaN, Packet4d>(const Packet4d& a, const Packet4d& b) { 446 return pminmax_propagate_nan(a, b, pmax<Packet4d>); 447 } 448 449 template<> EIGEN_STRONG_INLINE Packet8f print<Packet8f>(const Packet8f& a) { return _mm256_round_ps(a, _MM_FROUND_CUR_DIRECTION); } 450 template<> EIGEN_STRONG_INLINE Packet4d print<Packet4d>(const Packet4d& a) { return _mm256_round_pd(a, _MM_FROUND_CUR_DIRECTION); } 451 452 template<> EIGEN_STRONG_INLINE Packet8f pceil<Packet8f>(const Packet8f& a) { return _mm256_ceil_ps(a); } 453 template<> EIGEN_STRONG_INLINE Packet4d pceil<Packet4d>(const Packet4d& a) { return _mm256_ceil_pd(a); } 454 455 template<> EIGEN_STRONG_INLINE Packet8f pfloor<Packet8f>(const Packet8f& a) { return _mm256_floor_ps(a); } 456 template<> EIGEN_STRONG_INLINE Packet4d pfloor<Packet4d>(const Packet4d& a) { return _mm256_floor_pd(a); } 457 458 459 template<> EIGEN_STRONG_INLINE Packet8i ptrue<Packet8i>(const Packet8i& a) { 460 #ifdef EIGEN_VECTORIZE_AVX2 461 // vpcmpeqd has lower latency than the more general vcmpps 462 return _mm256_cmpeq_epi32(a,a); 463 #else 464 const __m256 b = _mm256_castsi256_ps(a); 465 return _mm256_castps_si256(_mm256_cmp_ps(b,b,_CMP_TRUE_UQ)); 466 #endif 467 } 468 469 template<> EIGEN_STRONG_INLINE Packet8f ptrue<Packet8f>(const Packet8f& a) { 470 #ifdef EIGEN_VECTORIZE_AVX2 471 // vpcmpeqd has lower latency than the more general vcmpps 472 const __m256i b = _mm256_castps_si256(a); 473 return _mm256_castsi256_ps(_mm256_cmpeq_epi32(b,b)); 474 #else 475 return _mm256_cmp_ps(a,a,_CMP_TRUE_UQ); 476 #endif 477 } 478 479 template<> EIGEN_STRONG_INLINE Packet4d ptrue<Packet4d>(const Packet4d& a) { 480 #ifdef EIGEN_VECTORIZE_AVX2 481 // vpcmpeqq has lower latency than the more general vcmppd 482 const __m256i b = _mm256_castpd_si256(a); 483 return _mm256_castsi256_pd(_mm256_cmpeq_epi64(b,b)); 484 #else 485 return _mm256_cmp_pd(a,a,_CMP_TRUE_UQ); 486 #endif 487 } 488 489 template<> EIGEN_STRONG_INLINE Packet8f pand<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_and_ps(a,b); } 490 template<> EIGEN_STRONG_INLINE Packet4d pand<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_and_pd(a,b); } 491 template<> EIGEN_STRONG_INLINE Packet8i pand<Packet8i>(const Packet8i& a, const Packet8i& b) { 492 #ifdef EIGEN_VECTORIZE_AVX2 493 return _mm256_and_si256(a,b); 494 #else 495 return _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b))); 496 #endif 497 } 498 499 template<> EIGEN_STRONG_INLINE Packet8f por<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_or_ps(a,b); } 500 template<> EIGEN_STRONG_INLINE Packet4d por<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_or_pd(a,b); } 501 template<> EIGEN_STRONG_INLINE Packet8i por<Packet8i>(const Packet8i& a, const Packet8i& b) { 502 #ifdef EIGEN_VECTORIZE_AVX2 503 return _mm256_or_si256(a,b); 504 #else 505 return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b))); 506 #endif 507 } 508 509 template<> EIGEN_STRONG_INLINE Packet8f pxor<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_xor_ps(a,b); } 510 template<> EIGEN_STRONG_INLINE Packet4d pxor<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_xor_pd(a,b); } 511 template<> EIGEN_STRONG_INLINE Packet8i pxor<Packet8i>(const Packet8i& a, const Packet8i& b) { 512 #ifdef EIGEN_VECTORIZE_AVX2 513 return _mm256_xor_si256(a,b); 514 #else 515 return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b))); 516 #endif 517 } 518 519 template<> EIGEN_STRONG_INLINE Packet8f pandnot<Packet8f>(const Packet8f& a, const Packet8f& b) { return _mm256_andnot_ps(b,a); } 520 template<> EIGEN_STRONG_INLINE Packet4d pandnot<Packet4d>(const Packet4d& a, const Packet4d& b) { return _mm256_andnot_pd(b,a); } 521 template<> EIGEN_STRONG_INLINE Packet8i pandnot<Packet8i>(const Packet8i& a, const Packet8i& b) { 522 #ifdef EIGEN_VECTORIZE_AVX2 523 return _mm256_andnot_si256(b,a); 524 #else 525 return _mm256_castps_si256(_mm256_andnot_ps(_mm256_castsi256_ps(b),_mm256_castsi256_ps(a))); 526 #endif 527 } 528 529 template<> EIGEN_STRONG_INLINE Packet8f pround<Packet8f>(const Packet8f& a) 530 { 531 const Packet8f mask = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x80000000u)); 532 const Packet8f prev0dot5 = pset1frombits<Packet8f>(static_cast<numext::uint32_t>(0x3EFFFFFFu)); 533 return _mm256_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); 534 } 535 template<> EIGEN_STRONG_INLINE Packet4d pround<Packet4d>(const Packet4d& a) 536 { 537 const Packet4d mask = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x8000000000000000ull)); 538 const Packet4d prev0dot5 = pset1frombits<Packet4d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull)); 539 return _mm256_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); 540 } 541 542 template<> EIGEN_STRONG_INLINE Packet8f pselect<Packet8f>(const Packet8f& mask, const Packet8f& a, const Packet8f& b) 543 { return _mm256_blendv_ps(b,a,mask); } 544 template<> EIGEN_STRONG_INLINE Packet4d pselect<Packet4d>(const Packet4d& mask, const Packet4d& a, const Packet4d& b) 545 { return _mm256_blendv_pd(b,a,mask); } 546 547 template<int N> EIGEN_STRONG_INLINE Packet8i parithmetic_shift_right(Packet8i a) { 548 #ifdef EIGEN_VECTORIZE_AVX2 549 return _mm256_srai_epi32(a, N); 550 #else 551 __m128i lo = _mm_srai_epi32(_mm256_extractf128_si256(a, 0), N); 552 __m128i hi = _mm_srai_epi32(_mm256_extractf128_si256(a, 1), N); 553 return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); 554 #endif 555 } 556 557 template<int N> EIGEN_STRONG_INLINE Packet8i plogical_shift_right(Packet8i a) { 558 #ifdef EIGEN_VECTORIZE_AVX2 559 return _mm256_srli_epi32(a, N); 560 #else 561 __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(a, 0), N); 562 __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(a, 1), N); 563 return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); 564 #endif 565 } 566 567 template<int N> EIGEN_STRONG_INLINE Packet8i plogical_shift_left(Packet8i a) { 568 #ifdef EIGEN_VECTORIZE_AVX2 569 return _mm256_slli_epi32(a, N); 570 #else 571 __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(a, 0), N); 572 __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(a, 1), N); 573 return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); 574 #endif 575 } 576 577 template<> EIGEN_STRONG_INLINE Packet8f pload<Packet8f>(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ps(from); } 578 template<> EIGEN_STRONG_INLINE Packet4d pload<Packet4d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_pd(from); } 579 template<> EIGEN_STRONG_INLINE Packet8i pload<Packet8i>(const int* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast<const __m256i*>(from)); } 580 581 template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_ps(from); } 582 template<> EIGEN_STRONG_INLINE Packet4d ploadu<Packet4d>(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_pd(from); } 583 template<> EIGEN_STRONG_INLINE Packet8i ploadu<Packet8i>(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from)); } 584 585 template<> EIGEN_STRONG_INLINE Packet8f ploadu<Packet8f>(const float* from, uint8_t umask) { 586 Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask)); 587 const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe); 588 mask = por<Packet8i>(mask, bit_mask); 589 mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff)); 590 EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask); 591 } 592 593 // Loads 4 floats from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, a3} 594 template<> EIGEN_STRONG_INLINE Packet8f ploaddup<Packet8f>(const float* from) 595 { 596 // TODO try to find a way to avoid the need of a temporary register 597 // Packet8f tmp = _mm256_castps128_ps256(_mm_loadu_ps(from)); 598 // tmp = _mm256_insertf128_ps(tmp, _mm_movehl_ps(_mm256_castps256_ps128(tmp),_mm256_castps256_ps128(tmp)), 1); 599 // return _mm256_unpacklo_ps(tmp,tmp); 600 601 // _mm256_insertf128_ps is very slow on Haswell, thus: 602 Packet8f tmp = _mm256_broadcast_ps((const __m128*)(const void*)from); 603 // mimic an "inplace" permutation of the lower 128bits using a blend 604 tmp = _mm256_blend_ps(tmp,_mm256_castps128_ps256(_mm_permute_ps( _mm256_castps256_ps128(tmp), _MM_SHUFFLE(1,0,1,0))), 15); 605 // then we can perform a consistent permutation on the global register to get everything in shape: 606 return _mm256_permute_ps(tmp, _MM_SHUFFLE(3,3,2,2)); 607 } 608 // Loads 2 doubles from memory a returns the packet {a0, a0 a1, a1} 609 template<> EIGEN_STRONG_INLINE Packet4d ploaddup<Packet4d>(const double* from) 610 { 611 Packet4d tmp = _mm256_broadcast_pd((const __m128d*)(const void*)from); 612 return _mm256_permute_pd(tmp, 3<<2); 613 } 614 615 // Loads 2 floats from memory a returns the packet {a0, a0 a0, a0, a1, a1, a1, a1} 616 template<> EIGEN_STRONG_INLINE Packet8f ploadquad<Packet8f>(const float* from) 617 { 618 Packet8f tmp = _mm256_castps128_ps256(_mm_broadcast_ss(from)); 619 return _mm256_insertf128_ps(tmp, _mm_broadcast_ss(from+1), 1); 620 } 621 622 template<> EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet8f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ps(to, from); } 623 template<> EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet4d& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_store_pd(to, from); } 624 template<> EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); } 625 626 template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_ps(to, from); } 627 template<> EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet4d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_pd(to, from); } 628 template<> EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); } 629 630 template<> EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet8f& from, uint8_t umask) { 631 Packet8i mask = _mm256_set1_epi8(static_cast<char>(umask)); 632 const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe); 633 mask = por<Packet8i>(mask, bit_mask); 634 mask = pcmp_eq<Packet8i>(mask, _mm256_set1_epi32(0xffffffff)); 635 EIGEN_DEBUG_UNALIGNED_STORE return _mm256_maskstore_ps(to, mask, from); 636 } 637 638 // NOTE: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available 639 // NOTE: for the record the following seems to be slower: return _mm256_i32gather_ps(from, _mm256_set1_epi32(stride), 4); 640 template<> EIGEN_DEVICE_FUNC inline Packet8f pgather<float, Packet8f>(const float* from, Index stride) 641 { 642 return _mm256_set_ps(from[7*stride], from[6*stride], from[5*stride], from[4*stride], 643 from[3*stride], from[2*stride], from[1*stride], from[0*stride]); 644 } 645 template<> EIGEN_DEVICE_FUNC inline Packet4d pgather<double, Packet4d>(const double* from, Index stride) 646 { 647 return _mm256_set_pd(from[3*stride], from[2*stride], from[1*stride], from[0*stride]); 648 } 649 650 template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet8f>(float* to, const Packet8f& from, Index stride) 651 { 652 __m128 low = _mm256_extractf128_ps(from, 0); 653 to[stride*0] = _mm_cvtss_f32(low); 654 to[stride*1] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 1)); 655 to[stride*2] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 2)); 656 to[stride*3] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 3)); 657 658 __m128 high = _mm256_extractf128_ps(from, 1); 659 to[stride*4] = _mm_cvtss_f32(high); 660 to[stride*5] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 1)); 661 to[stride*6] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 2)); 662 to[stride*7] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 3)); 663 } 664 template<> EIGEN_DEVICE_FUNC inline void pscatter<double, Packet4d>(double* to, const Packet4d& from, Index stride) 665 { 666 __m128d low = _mm256_extractf128_pd(from, 0); 667 to[stride*0] = _mm_cvtsd_f64(low); 668 to[stride*1] = _mm_cvtsd_f64(_mm_shuffle_pd(low, low, 1)); 669 __m128d high = _mm256_extractf128_pd(from, 1); 670 to[stride*2] = _mm_cvtsd_f64(high); 671 to[stride*3] = _mm_cvtsd_f64(_mm_shuffle_pd(high, high, 1)); 672 } 673 674 template<> EIGEN_STRONG_INLINE void pstore1<Packet8f>(float* to, const float& a) 675 { 676 Packet8f pa = pset1<Packet8f>(a); 677 pstore(to, pa); 678 } 679 template<> EIGEN_STRONG_INLINE void pstore1<Packet4d>(double* to, const double& a) 680 { 681 Packet4d pa = pset1<Packet4d>(a); 682 pstore(to, pa); 683 } 684 template<> EIGEN_STRONG_INLINE void pstore1<Packet8i>(int* to, const int& a) 685 { 686 Packet8i pa = pset1<Packet8i>(a); 687 pstore(to, pa); 688 } 689 690 #ifndef EIGEN_VECTORIZE_AVX512 691 template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } 692 template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } 693 template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } 694 #endif 695 696 template<> EIGEN_STRONG_INLINE float pfirst<Packet8f>(const Packet8f& a) { 697 return _mm_cvtss_f32(_mm256_castps256_ps128(a)); 698 } 699 template<> EIGEN_STRONG_INLINE double pfirst<Packet4d>(const Packet4d& a) { 700 return _mm_cvtsd_f64(_mm256_castpd256_pd128(a)); 701 } 702 template<> EIGEN_STRONG_INLINE int pfirst<Packet8i>(const Packet8i& a) { 703 return _mm_cvtsi128_si32(_mm256_castsi256_si128(a)); 704 } 705 706 707 template<> EIGEN_STRONG_INLINE Packet8f preverse(const Packet8f& a) 708 { 709 __m256 tmp = _mm256_shuffle_ps(a,a,0x1b); 710 return _mm256_permute2f128_ps(tmp, tmp, 1); 711 } 712 template<> EIGEN_STRONG_INLINE Packet4d preverse(const Packet4d& a) 713 { 714 __m256d tmp = _mm256_shuffle_pd(a,a,5); 715 return _mm256_permute2f128_pd(tmp, tmp, 1); 716 #if 0 717 // This version is unlikely to be faster as _mm256_shuffle_ps and _mm256_permute_pd 718 // exhibit the same latency/throughput, but it is here for future reference/benchmarking... 719 __m256d swap_halves = _mm256_permute2f128_pd(a,a,1); 720 return _mm256_permute_pd(swap_halves,5); 721 #endif 722 } 723 724 // pabs should be ok 725 template<> EIGEN_STRONG_INLINE Packet8f pabs(const Packet8f& a) 726 { 727 const Packet8f mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF)); 728 return _mm256_and_ps(a,mask); 729 } 730 template<> EIGEN_STRONG_INLINE Packet4d pabs(const Packet4d& a) 731 { 732 const Packet4d mask = _mm256_castsi256_pd(_mm256_setr_epi32(0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF)); 733 return _mm256_and_pd(a,mask); 734 } 735 736 template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) { 737 return pfrexp_generic(a,exponent); 738 } 739 740 // Extract exponent without existence of Packet4l. 741 template<> 742 EIGEN_STRONG_INLINE 743 Packet4d pfrexp_generic_get_biased_exponent(const Packet4d& a) { 744 const Packet4d cst_exp_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(0x7ff0000000000000ull)); 745 __m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask)); 746 #ifdef EIGEN_VECTORIZE_AVX2 747 a_expo = _mm256_srli_epi64(a_expo, 52); 748 __m128i lo = _mm256_extractf128_si256(a_expo, 0); 749 __m128i hi = _mm256_extractf128_si256(a_expo, 1); 750 #else 751 __m128i lo = _mm256_extractf128_si256(a_expo, 0); 752 __m128i hi = _mm256_extractf128_si256(a_expo, 1); 753 lo = _mm_srli_epi64(lo, 52); 754 hi = _mm_srli_epi64(hi, 52); 755 #endif 756 Packet2d exponent_lo = _mm_cvtepi32_pd(vec4i_swizzle1(lo, 0, 2, 1, 3)); 757 Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3)); 758 Packet4d exponent = _mm256_insertf128_pd(_mm256_setzero_pd(), exponent_lo, 0); 759 exponent = _mm256_insertf128_pd(exponent, exponent_hi, 1); 760 return exponent; 761 } 762 763 764 template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) { 765 return pfrexp_generic(a, exponent); 766 } 767 768 template<> EIGEN_STRONG_INLINE Packet8f pldexp<Packet8f>(const Packet8f& a, const Packet8f& exponent) { 769 return pldexp_generic(a, exponent); 770 } 771 772 template<> EIGEN_STRONG_INLINE Packet4d pldexp<Packet4d>(const Packet4d& a, const Packet4d& exponent) { 773 // Clamp exponent to [-2099, 2099] 774 const Packet4d max_exponent = pset1<Packet4d>(2099.0); 775 const Packet4i e = _mm256_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); 776 777 // Split 2^e into four factors and multiply. 778 const Packet4i bias = pset1<Packet4i>(1023); 779 Packet4i b = parithmetic_shift_right<2>(e); // floor(e/4) 780 781 // 2^b 782 Packet4i hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3); 783 Packet4i lo = _mm_slli_epi64(hi, 52); 784 hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52); 785 Packet4d c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1)); 786 Packet4d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) 787 788 // 2^(e - 3b) 789 b = psub(psub(psub(e, b), b), b); // e - 3b 790 hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3); 791 lo = _mm_slli_epi64(hi, 52); 792 hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52); 793 c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1)); 794 out = pmul(out, c); // a * 2^e 795 return out; 796 } 797 798 template<> EIGEN_STRONG_INLINE float predux<Packet8f>(const Packet8f& a) 799 { 800 return predux(Packet4f(_mm_add_ps(_mm256_castps256_ps128(a),_mm256_extractf128_ps(a,1)))); 801 } 802 template<> EIGEN_STRONG_INLINE double predux<Packet4d>(const Packet4d& a) 803 { 804 return predux(Packet2d(_mm_add_pd(_mm256_castpd256_pd128(a),_mm256_extractf128_pd(a,1)))); 805 } 806 807 template<> EIGEN_STRONG_INLINE Packet4f predux_half_dowto4<Packet8f>(const Packet8f& a) 808 { 809 return _mm_add_ps(_mm256_castps256_ps128(a),_mm256_extractf128_ps(a,1)); 810 } 811 812 template<> EIGEN_STRONG_INLINE float predux_mul<Packet8f>(const Packet8f& a) 813 { 814 Packet8f tmp; 815 tmp = _mm256_mul_ps(a, _mm256_permute2f128_ps(a,a,1)); 816 tmp = _mm256_mul_ps(tmp, _mm256_shuffle_ps(tmp,tmp,_MM_SHUFFLE(1,0,3,2))); 817 return pfirst(_mm256_mul_ps(tmp, _mm256_shuffle_ps(tmp,tmp,1))); 818 } 819 template<> EIGEN_STRONG_INLINE double predux_mul<Packet4d>(const Packet4d& a) 820 { 821 Packet4d tmp; 822 tmp = _mm256_mul_pd(a, _mm256_permute2f128_pd(a,a,1)); 823 return pfirst(_mm256_mul_pd(tmp, _mm256_shuffle_pd(tmp,tmp,1))); 824 } 825 826 template<> EIGEN_STRONG_INLINE float predux_min<Packet8f>(const Packet8f& a) 827 { 828 Packet8f tmp = _mm256_min_ps(a, _mm256_permute2f128_ps(a,a,1)); 829 tmp = _mm256_min_ps(tmp, _mm256_shuffle_ps(tmp,tmp,_MM_SHUFFLE(1,0,3,2))); 830 return pfirst(_mm256_min_ps(tmp, _mm256_shuffle_ps(tmp,tmp,1))); 831 } 832 template<> EIGEN_STRONG_INLINE double predux_min<Packet4d>(const Packet4d& a) 833 { 834 Packet4d tmp = _mm256_min_pd(a, _mm256_permute2f128_pd(a,a,1)); 835 return pfirst(_mm256_min_pd(tmp, _mm256_shuffle_pd(tmp, tmp, 1))); 836 } 837 838 template<> EIGEN_STRONG_INLINE float predux_max<Packet8f>(const Packet8f& a) 839 { 840 Packet8f tmp = _mm256_max_ps(a, _mm256_permute2f128_ps(a,a,1)); 841 tmp = _mm256_max_ps(tmp, _mm256_shuffle_ps(tmp,tmp,_MM_SHUFFLE(1,0,3,2))); 842 return pfirst(_mm256_max_ps(tmp, _mm256_shuffle_ps(tmp,tmp,1))); 843 } 844 845 template<> EIGEN_STRONG_INLINE double predux_max<Packet4d>(const Packet4d& a) 846 { 847 Packet4d tmp = _mm256_max_pd(a, _mm256_permute2f128_pd(a,a,1)); 848 return pfirst(_mm256_max_pd(tmp, _mm256_shuffle_pd(tmp, tmp, 1))); 849 } 850 851 // not needed yet 852 // template<> EIGEN_STRONG_INLINE bool predux_all(const Packet8f& x) 853 // { 854 // return _mm256_movemask_ps(x)==0xFF; 855 // } 856 857 template<> EIGEN_STRONG_INLINE bool predux_any(const Packet8f& x) 858 { 859 return _mm256_movemask_ps(x)!=0; 860 } 861 862 EIGEN_DEVICE_FUNC inline void 863 ptranspose(PacketBlock<Packet8f,8>& kernel) { 864 __m256 T0 = _mm256_unpacklo_ps(kernel.packet[0], kernel.packet[1]); 865 __m256 T1 = _mm256_unpackhi_ps(kernel.packet[0], kernel.packet[1]); 866 __m256 T2 = _mm256_unpacklo_ps(kernel.packet[2], kernel.packet[3]); 867 __m256 T3 = _mm256_unpackhi_ps(kernel.packet[2], kernel.packet[3]); 868 __m256 T4 = _mm256_unpacklo_ps(kernel.packet[4], kernel.packet[5]); 869 __m256 T5 = _mm256_unpackhi_ps(kernel.packet[4], kernel.packet[5]); 870 __m256 T6 = _mm256_unpacklo_ps(kernel.packet[6], kernel.packet[7]); 871 __m256 T7 = _mm256_unpackhi_ps(kernel.packet[6], kernel.packet[7]); 872 __m256 S0 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(1,0,1,0)); 873 __m256 S1 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(3,2,3,2)); 874 __m256 S2 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(1,0,1,0)); 875 __m256 S3 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(3,2,3,2)); 876 __m256 S4 = _mm256_shuffle_ps(T4,T6,_MM_SHUFFLE(1,0,1,0)); 877 __m256 S5 = _mm256_shuffle_ps(T4,T6,_MM_SHUFFLE(3,2,3,2)); 878 __m256 S6 = _mm256_shuffle_ps(T5,T7,_MM_SHUFFLE(1,0,1,0)); 879 __m256 S7 = _mm256_shuffle_ps(T5,T7,_MM_SHUFFLE(3,2,3,2)); 880 kernel.packet[0] = _mm256_permute2f128_ps(S0, S4, 0x20); 881 kernel.packet[1] = _mm256_permute2f128_ps(S1, S5, 0x20); 882 kernel.packet[2] = _mm256_permute2f128_ps(S2, S6, 0x20); 883 kernel.packet[3] = _mm256_permute2f128_ps(S3, S7, 0x20); 884 kernel.packet[4] = _mm256_permute2f128_ps(S0, S4, 0x31); 885 kernel.packet[5] = _mm256_permute2f128_ps(S1, S5, 0x31); 886 kernel.packet[6] = _mm256_permute2f128_ps(S2, S6, 0x31); 887 kernel.packet[7] = _mm256_permute2f128_ps(S3, S7, 0x31); 888 } 889 890 EIGEN_DEVICE_FUNC inline void 891 ptranspose(PacketBlock<Packet8f,4>& kernel) { 892 __m256 T0 = _mm256_unpacklo_ps(kernel.packet[0], kernel.packet[1]); 893 __m256 T1 = _mm256_unpackhi_ps(kernel.packet[0], kernel.packet[1]); 894 __m256 T2 = _mm256_unpacklo_ps(kernel.packet[2], kernel.packet[3]); 895 __m256 T3 = _mm256_unpackhi_ps(kernel.packet[2], kernel.packet[3]); 896 897 __m256 S0 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(1,0,1,0)); 898 __m256 S1 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(3,2,3,2)); 899 __m256 S2 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(1,0,1,0)); 900 __m256 S3 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(3,2,3,2)); 901 902 kernel.packet[0] = _mm256_permute2f128_ps(S0, S1, 0x20); 903 kernel.packet[1] = _mm256_permute2f128_ps(S2, S3, 0x20); 904 kernel.packet[2] = _mm256_permute2f128_ps(S0, S1, 0x31); 905 kernel.packet[3] = _mm256_permute2f128_ps(S2, S3, 0x31); 906 } 907 908 EIGEN_DEVICE_FUNC inline void 909 ptranspose(PacketBlock<Packet4d,4>& kernel) { 910 __m256d T0 = _mm256_shuffle_pd(kernel.packet[0], kernel.packet[1], 15); 911 __m256d T1 = _mm256_shuffle_pd(kernel.packet[0], kernel.packet[1], 0); 912 __m256d T2 = _mm256_shuffle_pd(kernel.packet[2], kernel.packet[3], 15); 913 __m256d T3 = _mm256_shuffle_pd(kernel.packet[2], kernel.packet[3], 0); 914 915 kernel.packet[1] = _mm256_permute2f128_pd(T0, T2, 32); 916 kernel.packet[3] = _mm256_permute2f128_pd(T0, T2, 49); 917 kernel.packet[0] = _mm256_permute2f128_pd(T1, T3, 32); 918 kernel.packet[2] = _mm256_permute2f128_pd(T1, T3, 49); 919 } 920 921 template<> EIGEN_STRONG_INLINE Packet8f pblend(const Selector<8>& ifPacket, const Packet8f& thenPacket, const Packet8f& elsePacket) { 922 const __m256 zero = _mm256_setzero_ps(); 923 const __m256 select = _mm256_set_ps(ifPacket.select[7], ifPacket.select[6], ifPacket.select[5], ifPacket.select[4], ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]); 924 __m256 false_mask = _mm256_cmp_ps(select, zero, _CMP_EQ_UQ); 925 return _mm256_blendv_ps(thenPacket, elsePacket, false_mask); 926 } 927 template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, const Packet4d& thenPacket, const Packet4d& elsePacket) { 928 const __m256d zero = _mm256_setzero_pd(); 929 const __m256d select = _mm256_set_pd(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]); 930 __m256d false_mask = _mm256_cmp_pd(select, zero, _CMP_EQ_UQ); 931 return _mm256_blendv_pd(thenPacket, elsePacket, false_mask); 932 } 933 934 // Packet math for Eigen::half 935 936 template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; }; 937 938 template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) { 939 return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from)); 940 } 941 942 template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8h>(const Packet8h& from) { 943 return numext::bit_cast<Eigen::half>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0))); 944 } 945 946 template<> EIGEN_STRONG_INLINE Packet8h pload<Packet8h>(const Eigen::half* from) { 947 return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); 948 } 949 950 template<> EIGEN_STRONG_INLINE Packet8h ploadu<Packet8h>(const Eigen::half* from) { 951 return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from)); 952 } 953 954 template<> EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8h& from) { 955 _mm_store_si128(reinterpret_cast<__m128i*>(to), from); 956 } 957 958 template<> EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8h& from) { 959 _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); 960 } 961 962 template<> EIGEN_STRONG_INLINE Packet8h 963 ploaddup<Packet8h>(const Eigen::half* from) { 964 const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); 965 const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); 966 const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]); 967 const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]); 968 return _mm_set_epi16(d, d, c, c, b, b, a, a); 969 } 970 971 template<> EIGEN_STRONG_INLINE Packet8h 972 ploadquad<Packet8h>(const Eigen::half* from) { 973 const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); 974 const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); 975 return _mm_set_epi16(b, b, b, b, a, a, a, a); 976 } 977 978 template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) { 979 return _mm_cmpeq_epi32(a, a); 980 } 981 982 template <> 983 EIGEN_STRONG_INLINE Packet8h pabs(const Packet8h& a) { 984 const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); 985 return _mm_andnot_si128(sign_mask, a); 986 } 987 988 EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) { 989 #ifdef EIGEN_HAS_FP16_C 990 return _mm256_cvtph_ps(a); 991 #else 992 EIGEN_ALIGN32 Eigen::half aux[8]; 993 pstore(aux, a); 994 float f0(aux[0]); 995 float f1(aux[1]); 996 float f2(aux[2]); 997 float f3(aux[3]); 998 float f4(aux[4]); 999 float f5(aux[5]); 1000 float f6(aux[6]); 1001 float f7(aux[7]); 1002 1003 return _mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0); 1004 #endif 1005 } 1006 1007 EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { 1008 #ifdef EIGEN_HAS_FP16_C 1009 return _mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC); 1010 #else 1011 EIGEN_ALIGN32 float aux[8]; 1012 pstore(aux, a); 1013 const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[0])); 1014 const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[1])); 1015 const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[2])); 1016 const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[3])); 1017 const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[4])); 1018 const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[5])); 1019 const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[6])); 1020 const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[7])); 1021 return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); 1022 #endif 1023 } 1024 1025 template <> 1026 EIGEN_STRONG_INLINE Packet8h pmin<Packet8h>(const Packet8h& a, 1027 const Packet8h& b) { 1028 return float2half(pmin<Packet8f>(half2float(a), half2float(b))); 1029 } 1030 1031 template <> 1032 EIGEN_STRONG_INLINE Packet8h pmax<Packet8h>(const Packet8h& a, 1033 const Packet8h& b) { 1034 return float2half(pmax<Packet8f>(half2float(a), half2float(b))); 1035 } 1036 1037 template <> 1038 EIGEN_STRONG_INLINE Packet8h plset<Packet8h>(const half& a) { 1039 return float2half(plset<Packet8f>(static_cast<float>(a))); 1040 } 1041 1042 template<> EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a,const Packet8h& b) { 1043 // in some cases Packet4i is a wrapper around __m128i, so we either need to 1044 // cast to Packet4i to directly call the intrinsics as below: 1045 return _mm_or_si128(a,b); 1046 } 1047 template<> EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a,const Packet8h& b) { 1048 return _mm_xor_si128(a,b); 1049 } 1050 template<> EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a,const Packet8h& b) { 1051 return _mm_and_si128(a,b); 1052 } 1053 template<> EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a,const Packet8h& b) { 1054 return _mm_andnot_si128(b,a); 1055 } 1056 1057 template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) { 1058 return _mm_blendv_epi8(b, a, mask); 1059 } 1060 1061 template<> EIGEN_STRONG_INLINE Packet8h pround<Packet8h>(const Packet8h& a) { 1062 return float2half(pround<Packet8f>(half2float(a))); 1063 } 1064 1065 template<> EIGEN_STRONG_INLINE Packet8h print<Packet8h>(const Packet8h& a) { 1066 return float2half(print<Packet8f>(half2float(a))); 1067 } 1068 1069 template<> EIGEN_STRONG_INLINE Packet8h pceil<Packet8h>(const Packet8h& a) { 1070 return float2half(pceil<Packet8f>(half2float(a))); 1071 } 1072 1073 template<> EIGEN_STRONG_INLINE Packet8h pfloor<Packet8h>(const Packet8h& a) { 1074 return float2half(pfloor<Packet8f>(half2float(a))); 1075 } 1076 1077 template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) { 1078 return Pack16To8(pcmp_eq(half2float(a), half2float(b))); 1079 } 1080 1081 template<> EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a,const Packet8h& b) { 1082 return Pack16To8(pcmp_le(half2float(a), half2float(b))); 1083 } 1084 1085 template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a,const Packet8h& b) { 1086 return Pack16To8(pcmp_lt(half2float(a), half2float(b))); 1087 } 1088 1089 template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a,const Packet8h& b) { 1090 return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b))); 1091 } 1092 1093 template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; } 1094 1095 template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) { 1096 Packet8h sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); 1097 return _mm_xor_si128(a, sign_mask); 1098 } 1099 1100 template<> EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) { 1101 Packet8f af = half2float(a); 1102 Packet8f bf = half2float(b); 1103 Packet8f rf = padd(af, bf); 1104 return float2half(rf); 1105 } 1106 1107 template<> EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) { 1108 Packet8f af = half2float(a); 1109 Packet8f bf = half2float(b); 1110 Packet8f rf = psub(af, bf); 1111 return float2half(rf); 1112 } 1113 1114 template<> EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) { 1115 Packet8f af = half2float(a); 1116 Packet8f bf = half2float(b); 1117 Packet8f rf = pmul(af, bf); 1118 return float2half(rf); 1119 } 1120 1121 template<> EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) { 1122 Packet8f af = half2float(a); 1123 Packet8f bf = half2float(b); 1124 Packet8f rf = pdiv(af, bf); 1125 return float2half(rf); 1126 } 1127 1128 template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride) 1129 { 1130 const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]); 1131 const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]); 1132 const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]); 1133 const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]); 1134 const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]); 1135 const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]); 1136 const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]); 1137 const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]); 1138 return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); 1139 } 1140 1141 template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const Packet8h& from, Index stride) 1142 { 1143 EIGEN_ALIGN32 Eigen::half aux[8]; 1144 pstore(aux, from); 1145 to[stride*0] = aux[0]; 1146 to[stride*1] = aux[1]; 1147 to[stride*2] = aux[2]; 1148 to[stride*3] = aux[3]; 1149 to[stride*4] = aux[4]; 1150 to[stride*5] = aux[5]; 1151 to[stride*6] = aux[6]; 1152 to[stride*7] = aux[7]; 1153 } 1154 1155 template<> EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) { 1156 Packet8f af = half2float(a); 1157 float reduced = predux<Packet8f>(af); 1158 return Eigen::half(reduced); 1159 } 1160 1161 template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) { 1162 Packet8f af = half2float(a); 1163 float reduced = predux_max<Packet8f>(af); 1164 return Eigen::half(reduced); 1165 } 1166 1167 template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8h>(const Packet8h& a) { 1168 Packet8f af = half2float(a); 1169 float reduced = predux_min<Packet8f>(af); 1170 return Eigen::half(reduced); 1171 } 1172 1173 template<> EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8h>(const Packet8h& a) { 1174 Packet8f af = half2float(a); 1175 float reduced = predux_mul<Packet8f>(af); 1176 return Eigen::half(reduced); 1177 } 1178 1179 template<> EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a) 1180 { 1181 __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); 1182 return _mm_shuffle_epi8(a,m); 1183 } 1184 1185 EIGEN_STRONG_INLINE void 1186 ptranspose(PacketBlock<Packet8h,8>& kernel) { 1187 __m128i a = kernel.packet[0]; 1188 __m128i b = kernel.packet[1]; 1189 __m128i c = kernel.packet[2]; 1190 __m128i d = kernel.packet[3]; 1191 __m128i e = kernel.packet[4]; 1192 __m128i f = kernel.packet[5]; 1193 __m128i g = kernel.packet[6]; 1194 __m128i h = kernel.packet[7]; 1195 1196 __m128i a03b03 = _mm_unpacklo_epi16(a, b); 1197 __m128i c03d03 = _mm_unpacklo_epi16(c, d); 1198 __m128i e03f03 = _mm_unpacklo_epi16(e, f); 1199 __m128i g03h03 = _mm_unpacklo_epi16(g, h); 1200 __m128i a47b47 = _mm_unpackhi_epi16(a, b); 1201 __m128i c47d47 = _mm_unpackhi_epi16(c, d); 1202 __m128i e47f47 = _mm_unpackhi_epi16(e, f); 1203 __m128i g47h47 = _mm_unpackhi_epi16(g, h); 1204 1205 __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03); 1206 __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03); 1207 __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03); 1208 __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03); 1209 __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47); 1210 __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47); 1211 __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47); 1212 __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47); 1213 1214 __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01); 1215 __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01); 1216 __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23); 1217 __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23); 1218 __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45); 1219 __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45); 1220 __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67); 1221 __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67); 1222 1223 kernel.packet[0] = a0b0c0d0e0f0g0h0; 1224 kernel.packet[1] = a1b1c1d1e1f1g1h1; 1225 kernel.packet[2] = a2b2c2d2e2f2g2h2; 1226 kernel.packet[3] = a3b3c3d3e3f3g3h3; 1227 kernel.packet[4] = a4b4c4d4e4f4g4h4; 1228 kernel.packet[5] = a5b5c5d5e5f5g5h5; 1229 kernel.packet[6] = a6b6c6d6e6f6g6h6; 1230 kernel.packet[7] = a7b7c7d7e7f7g7h7; 1231 } 1232 1233 EIGEN_STRONG_INLINE void 1234 ptranspose(PacketBlock<Packet8h,4>& kernel) { 1235 EIGEN_ALIGN32 Eigen::half in[4][8]; 1236 pstore<Eigen::half>(in[0], kernel.packet[0]); 1237 pstore<Eigen::half>(in[1], kernel.packet[1]); 1238 pstore<Eigen::half>(in[2], kernel.packet[2]); 1239 pstore<Eigen::half>(in[3], kernel.packet[3]); 1240 1241 EIGEN_ALIGN32 Eigen::half out[4][8]; 1242 1243 for (int i = 0; i < 4; ++i) { 1244 for (int j = 0; j < 4; ++j) { 1245 out[i][j] = in[j][2*i]; 1246 } 1247 for (int j = 0; j < 4; ++j) { 1248 out[i][j+4] = in[j][2*i+1]; 1249 } 1250 } 1251 1252 kernel.packet[0] = pload<Packet8h>(out[0]); 1253 kernel.packet[1] = pload<Packet8h>(out[1]); 1254 kernel.packet[2] = pload<Packet8h>(out[2]); 1255 kernel.packet[3] = pload<Packet8h>(out[3]); 1256 } 1257 1258 // BFloat16 implementation. 1259 1260 EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) { 1261 #ifdef EIGEN_VECTORIZE_AVX2 1262 __m256i extend = _mm256_cvtepu16_epi32(a); 1263 return _mm256_castsi256_ps(_mm256_slli_epi32(extend, 16)); 1264 #else 1265 __m128i lo = _mm_cvtepu16_epi32(a); 1266 __m128i hi = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8)); 1267 __m128i lo_shift = _mm_slli_epi32(lo, 16); 1268 __m128i hi_shift = _mm_slli_epi32(hi, 16); 1269 return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo_shift), hi_shift, 1)); 1270 #endif 1271 } 1272 1273 // Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. 1274 EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { 1275 Packet8bf r; 1276 1277 __m256i input = _mm256_castps_si256(a); 1278 1279 #ifdef EIGEN_VECTORIZE_AVX2 1280 // uint32_t lsb = (input >> 16); 1281 __m256i t = _mm256_srli_epi32(input, 16); 1282 // uint32_t lsb = lsb & 1; 1283 t = _mm256_and_si256(t, _mm256_set1_epi32(1)); 1284 // uint32_t rounding_bias = 0x7fff + lsb; 1285 t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff)); 1286 // input += rounding_bias; 1287 t = _mm256_add_epi32(t, input); 1288 // input = input >> 16; 1289 t = _mm256_srli_epi32(t, 16); 1290 // Check NaN before converting back to bf16 1291 __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q); 1292 __m256i nan = _mm256_set1_epi32(0x7fc0); 1293 t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask)); 1294 // output = numext::bit_cast<uint16_t>(input); 1295 return _mm_packus_epi32(_mm256_extractf128_si256(t, 0), 1296 _mm256_extractf128_si256(t, 1)); 1297 #else 1298 // uint32_t lsb = (input >> 16); 1299 __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16); 1300 __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16); 1301 // uint32_t lsb = lsb & 1; 1302 lo = _mm_and_si128(lo, _mm_set1_epi32(1)); 1303 hi = _mm_and_si128(hi, _mm_set1_epi32(1)); 1304 // uint32_t rounding_bias = 0x7fff + lsb; 1305 lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff)); 1306 hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff)); 1307 // input += rounding_bias; 1308 lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0)); 1309 hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1)); 1310 // input = input >> 16; 1311 lo = _mm_srli_epi32(lo, 16); 1312 hi = _mm_srli_epi32(hi, 16); 1313 // Check NaN before converting back to bf16 1314 __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q); 1315 __m128i nan = _mm_set1_epi32(0x7fc0); 1316 lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask))); 1317 hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1))); 1318 // output = numext::bit_cast<uint16_t>(input); 1319 return _mm_packus_epi32(lo, hi); 1320 #endif 1321 } 1322 1323 template<> EIGEN_STRONG_INLINE Packet8bf pset1<Packet8bf>(const bfloat16& from) { 1324 return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from)); 1325 } 1326 1327 template<> EIGEN_STRONG_INLINE bfloat16 pfirst<Packet8bf>(const Packet8bf& from) { 1328 return numext::bit_cast<bfloat16>(static_cast<numext::uint16_t>(_mm_extract_epi16(from, 0))); 1329 } 1330 1331 template<> EIGEN_STRONG_INLINE Packet8bf pload<Packet8bf>(const bfloat16* from) { 1332 return _mm_load_si128(reinterpret_cast<const __m128i*>(from)); 1333 } 1334 1335 template<> EIGEN_STRONG_INLINE Packet8bf ploadu<Packet8bf>(const bfloat16* from) { 1336 return _mm_loadu_si128(reinterpret_cast<const __m128i*>(from)); 1337 } 1338 1339 template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from) { 1340 _mm_store_si128(reinterpret_cast<__m128i*>(to), from); 1341 } 1342 1343 template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from) { 1344 _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); 1345 } 1346 1347 template<> EIGEN_STRONG_INLINE Packet8bf 1348 ploaddup<Packet8bf>(const bfloat16* from) { 1349 const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); 1350 const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); 1351 const numext::uint16_t c = numext::bit_cast<numext::uint16_t>(from[2]); 1352 const numext::uint16_t d = numext::bit_cast<numext::uint16_t>(from[3]); 1353 return _mm_set_epi16(d, d, c, c, b, b, a, a); 1354 } 1355 1356 template<> EIGEN_STRONG_INLINE Packet8bf 1357 ploadquad<Packet8bf>(const bfloat16* from) { 1358 const numext::uint16_t a = numext::bit_cast<numext::uint16_t>(from[0]); 1359 const numext::uint16_t b = numext::bit_cast<numext::uint16_t>(from[1]); 1360 return _mm_set_epi16(b, b, b, b, a, a, a, a); 1361 } 1362 1363 template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) { 1364 return _mm_cmpeq_epi32(a, a); 1365 } 1366 1367 template <> 1368 EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { 1369 const __m128i sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); 1370 return _mm_andnot_si128(sign_mask, a); 1371 } 1372 1373 template <> 1374 EIGEN_STRONG_INLINE Packet8bf pmin<Packet8bf>(const Packet8bf& a, 1375 const Packet8bf& b) { 1376 return F32ToBf16(pmin<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); 1377 } 1378 1379 template <> 1380 EIGEN_STRONG_INLINE Packet8bf pmax<Packet8bf>(const Packet8bf& a, 1381 const Packet8bf& b) { 1382 return F32ToBf16(pmax<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); 1383 } 1384 1385 template <> 1386 EIGEN_STRONG_INLINE Packet8bf plset<Packet8bf>(const bfloat16& a) { 1387 return F32ToBf16(plset<Packet8f>(static_cast<float>(a))); 1388 } 1389 1390 template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) { 1391 return _mm_or_si128(a,b); 1392 } 1393 template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) { 1394 return _mm_xor_si128(a,b); 1395 } 1396 template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) { 1397 return _mm_and_si128(a,b); 1398 } 1399 template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) { 1400 return _mm_andnot_si128(b,a); 1401 } 1402 1403 template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) { 1404 return _mm_blendv_epi8(b, a, mask); 1405 } 1406 1407 template<> EIGEN_STRONG_INLINE Packet8bf pround<Packet8bf>(const Packet8bf& a) 1408 { 1409 return F32ToBf16(pround<Packet8f>(Bf16ToF32(a))); 1410 } 1411 1412 template<> EIGEN_STRONG_INLINE Packet8bf print<Packet8bf>(const Packet8bf& a) { 1413 return F32ToBf16(print<Packet8f>(Bf16ToF32(a))); 1414 } 1415 1416 template<> EIGEN_STRONG_INLINE Packet8bf pceil<Packet8bf>(const Packet8bf& a) { 1417 return F32ToBf16(pceil<Packet8f>(Bf16ToF32(a))); 1418 } 1419 1420 template<> EIGEN_STRONG_INLINE Packet8bf pfloor<Packet8bf>(const Packet8bf& a) { 1421 return F32ToBf16(pfloor<Packet8f>(Bf16ToF32(a))); 1422 } 1423 1424 template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) { 1425 return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); 1426 } 1427 1428 template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) { 1429 return Pack16To8(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); 1430 } 1431 1432 template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) { 1433 return Pack16To8(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); 1434 } 1435 1436 template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) { 1437 return Pack16To8(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); 1438 } 1439 1440 template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; } 1441 1442 template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) { 1443 Packet8bf sign_mask = _mm_set1_epi16(static_cast<numext::uint16_t>(0x8000)); 1444 return _mm_xor_si128(a, sign_mask); 1445 } 1446 1447 template<> EIGEN_STRONG_INLINE Packet8bf padd<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { 1448 return F32ToBf16(padd<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); 1449 } 1450 1451 template<> EIGEN_STRONG_INLINE Packet8bf psub<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { 1452 return F32ToBf16(psub<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); 1453 } 1454 1455 template<> EIGEN_STRONG_INLINE Packet8bf pmul<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { 1456 return F32ToBf16(pmul<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); 1457 } 1458 1459 template<> EIGEN_STRONG_INLINE Packet8bf pdiv<Packet8bf>(const Packet8bf& a, const Packet8bf& b) { 1460 return F32ToBf16(pdiv<Packet8f>(Bf16ToF32(a), Bf16ToF32(b))); 1461 } 1462 1463 1464 template<> EIGEN_STRONG_INLINE Packet8bf pgather<bfloat16, Packet8bf>(const bfloat16* from, Index stride) 1465 { 1466 const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(from[0*stride]); 1467 const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(from[1*stride]); 1468 const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(from[2*stride]); 1469 const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(from[3*stride]); 1470 const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(from[4*stride]); 1471 const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(from[5*stride]); 1472 const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(from[6*stride]); 1473 const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(from[7*stride]); 1474 return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); 1475 } 1476 1477 template<> EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packet8bf& from, Index stride) 1478 { 1479 EIGEN_ALIGN32 bfloat16 aux[8]; 1480 pstore(aux, from); 1481 to[stride*0] = aux[0]; 1482 to[stride*1] = aux[1]; 1483 to[stride*2] = aux[2]; 1484 to[stride*3] = aux[3]; 1485 to[stride*4] = aux[4]; 1486 to[stride*5] = aux[5]; 1487 to[stride*6] = aux[6]; 1488 to[stride*7] = aux[7]; 1489 } 1490 1491 template<> EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) { 1492 return static_cast<bfloat16>(predux<Packet8f>(Bf16ToF32(a))); 1493 } 1494 1495 template<> EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) { 1496 return static_cast<bfloat16>(predux_max<Packet8f>(Bf16ToF32(a))); 1497 } 1498 1499 template<> EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) { 1500 return static_cast<bfloat16>(predux_min<Packet8f>(Bf16ToF32(a))); 1501 } 1502 1503 template<> EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) { 1504 return static_cast<bfloat16>(predux_mul<Packet8f>(Bf16ToF32(a))); 1505 } 1506 1507 template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a) 1508 { 1509 __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); 1510 return _mm_shuffle_epi8(a,m); 1511 } 1512 1513 EIGEN_STRONG_INLINE void 1514 ptranspose(PacketBlock<Packet8bf,8>& kernel) { 1515 __m128i a = kernel.packet[0]; 1516 __m128i b = kernel.packet[1]; 1517 __m128i c = kernel.packet[2]; 1518 __m128i d = kernel.packet[3]; 1519 __m128i e = kernel.packet[4]; 1520 __m128i f = kernel.packet[5]; 1521 __m128i g = kernel.packet[6]; 1522 __m128i h = kernel.packet[7]; 1523 1524 __m128i a03b03 = _mm_unpacklo_epi16(a, b); 1525 __m128i c03d03 = _mm_unpacklo_epi16(c, d); 1526 __m128i e03f03 = _mm_unpacklo_epi16(e, f); 1527 __m128i g03h03 = _mm_unpacklo_epi16(g, h); 1528 __m128i a47b47 = _mm_unpackhi_epi16(a, b); 1529 __m128i c47d47 = _mm_unpackhi_epi16(c, d); 1530 __m128i e47f47 = _mm_unpackhi_epi16(e, f); 1531 __m128i g47h47 = _mm_unpackhi_epi16(g, h); 1532 1533 __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03); 1534 __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03); 1535 __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03); 1536 __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03); 1537 __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47); 1538 __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47); 1539 __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47); 1540 __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47); 1541 1542 kernel.packet[0] = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01); 1543 kernel.packet[1] = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01); 1544 kernel.packet[2] = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23); 1545 kernel.packet[3] = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23); 1546 kernel.packet[4] = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45); 1547 kernel.packet[5] = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45); 1548 kernel.packet[6] = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67); 1549 kernel.packet[7] = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67); 1550 } 1551 1552 EIGEN_STRONG_INLINE void 1553 ptranspose(PacketBlock<Packet8bf,4>& kernel) { 1554 __m128i a = kernel.packet[0]; 1555 __m128i b = kernel.packet[1]; 1556 __m128i c = kernel.packet[2]; 1557 __m128i d = kernel.packet[3]; 1558 1559 __m128i ab_03 = _mm_unpacklo_epi16(a, b); 1560 __m128i cd_03 = _mm_unpacklo_epi16(c, d); 1561 __m128i ab_47 = _mm_unpackhi_epi16(a, b); 1562 __m128i cd_47 = _mm_unpackhi_epi16(c, d); 1563 1564 kernel.packet[0] = _mm_unpacklo_epi32(ab_03, cd_03); 1565 kernel.packet[1] = _mm_unpackhi_epi32(ab_03, cd_03); 1566 kernel.packet[2] = _mm_unpacklo_epi32(ab_47, cd_47); 1567 kernel.packet[3] = _mm_unpackhi_epi32(ab_47, cd_47); 1568 } 1569 1570 } // end namespace internal 1571 1572 } // end namespace Eigen 1573 1574 #endif // EIGEN_PACKET_MATH_AVX_H 1575