#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
#include <c10/util/qint8.h>
#include <array>
#include <cstring>

// This file defines Vectorized<> for the quantized types.
//
// Currently, we simply use these classes as efficient converters between
// the quantized types and Vectorized<float>, usually in bandwidth-bound
// cases where doing the arithmetic in full precision is acceptable (e.g.
// elementwise operators).
//
// Conversions are as follows:
//  Vectorized<qint8> -> 4x Vectorized<float>
//
// The size of the returned float vector is specified by the special
// constexpr function float_num_vecs. The type of the value returned
// from dequantize (and expected as an argument to quantize) is
// specified by float_vec_return_type.
//
// When writing kernels with these vectors, it is expected that floating-
// point operations will be carried out in a loop over
// Vectorized<T>::float_num_vecs iterations.
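//
// A minimal sketch of that pattern (illustrative only: `src`, `dst`, `op`,
// and the scale/zero-point values are placeholders, not defined here):
//
//   Vectorized<c10::qint8> qv = Vectorized<c10::qint8>::loadu(src);
//   auto floats = qv.dequantize(scale_vec, zero_point_vec, scale_zp_premul);
//   for (size_t i = 0; i < Vectorized<c10::qint8>::float_num_vecs(); ++i) {
//     floats[i] = op(floats[i]); // arbitrary float-domain computation
//   }
//   Vectorized<c10::qint8>::quantize(floats, scale, zero_point, inv_scale)
//       .store(dst);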

namespace at {
namespace vec {
inline namespace CPU_CAPABILITY {

template <>
struct Vectorized<c10::qint8> {
 private:
  union {
    struct {
      vint8 _vec0;
      vint8 _vec1;
    };
    struct {
      vbool8 _vecb0;
      vbool8 _vecb1;
    };
  } __attribute__((__may_alias__));

 public:
  Vectorized() {}
  using size_type = int;
  static constexpr size_type size() {
    return 32;
  }

  static constexpr size_t float_num_vecs() {
    return 4;
  }
  static constexpr int int_num_vecs() {
    return 4;
  }
  using float_vec_return_type = std::array<Vectorized<float>, 4>;
  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
  using value_type = typename c10::qint8::underlying;
  using vec_internal_type = vint8;
  using vec_internal_mask_type = vbool8;

  // Broadcast constructor
  C10_ALWAYS_INLINE Vectorized(const c10::qint8& val)
      : _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {}

  C10_ALWAYS_INLINE Vectorized(const Vectorized<c10::qint8>& other)
      : _vec0{other._vec0}, _vec1{other._vec1} {}

  C10_ALWAYS_INLINE Vectorized(vint8 v) : _vec0{v}, _vec1{v} {}
  C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
  C10_ALWAYS_INLINE Vectorized(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {}
  C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2)
      : _vecb0{v1}, _vecb1{v2} {}

  C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
    return _vec0;
  }
  C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
    return _vec1;
  }

  static C10_ALWAYS_INLINE Vectorized<c10::qint8> loadu(
      const void* ptr,
      int count = size()) {
    if (count == size()) {
      return {
          vec_vsx_ld(offset0, reinterpret_cast<const vint8*>(ptr)),
          vec_vsx_ld(offset16, reinterpret_cast<const vint8*>(ptr))};
    }
    // Partial load: stage through a zero-initialized aligned buffer.
    __at_align__ value_type tmp_values[size()] = {};
    std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
    return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
  }
  void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
    if (count == size()) {
      vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
      vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
    } else if (count > 0) {
      // Partial store: stage through an aligned buffer, then copy out only
      // the first `count` elements.
      __at_align__ value_type tmp_values[size()];
      vec_vsx_st(_vec0, offset0, tmp_values);
      vec_vsx_st(_vec1, offset16, tmp_values);
      std::memcpy(
          ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
    }
  }

  float_vec_return_type C10_ALWAYS_INLINE dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point,
      Vectorized<float> scale_zp_premul) const {
    // Sign-extend the 32 int8 lanes to 16-bit, then to 32-bit.
    vint16 vecshi0 = vec_unpackh(_vec0);
    vint16 vecshi1 = vec_unpackl(_vec0);

    vint16 vecshi2 = vec_unpackh(_vec1);
    vint16 vecshi3 = vec_unpackl(_vec1);

    vint32 veci0 = vec_unpackh(vecshi0);
    vint32 veci1 = vec_unpackl(vecshi0);

    vint32 veci2 = vec_unpackh(vecshi1);
    vint32 veci3 = vec_unpackl(vecshi1);

    vint32 veci4 = vec_unpackh(vecshi2);
    vint32 veci5 = vec_unpackl(vecshi2);

    vint32 veci6 = vec_unpackh(vecshi3);
    vint32 veci7 = vec_unpackl(vecshi3);

    // Convert to float, then compute x * scale + (-zero_point * scale)
    // with one fused multiply-add per vector.
    vfloat32 vecf0_0 = vec_float(veci0);
    vfloat32 vecf1_0 = vec_float(veci1);

    vfloat32 vecf0_1 = vec_float(veci2);
    vfloat32 vecf1_1 = vec_float(veci3);

    vfloat32 vecf0_2 = vec_float(veci4);
    vfloat32 vecf1_2 = vec_float(veci5);

    vfloat32 vecf0_3 = vec_float(veci6);
    vfloat32 vecf1_3 = vec_float(veci7);
    vfloat32 scale_vec0 = scale.vec0();
    vfloat32 scale_vec1 = scale.vec1();
    vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
    vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
    return {
        Vectorized<float>{
            vec_madd(scale_vec0, vecf0_0, scale_zp_premul0),
            vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)},
        Vectorized<float>{
            vec_madd(scale_vec0, vecf0_1, scale_zp_premul0),
            vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)},
        Vectorized<float>{
            vec_madd(scale_vec0, vecf0_2, scale_zp_premul0),
            vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)},
        Vectorized<float>{
            vec_madd(scale_vec0, vecf0_3, scale_zp_premul0),
            vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}};
  }

  float_vec_return_type C10_ALWAYS_INLINE dequantize(
      Vectorized<float> scale,
      Vectorized<float> zero_point) const {
    // Same widening as above, but computes (x - zero_point) * scale directly.
    vint16 vecshi0 = vec_unpackh(_vec0);
    vint16 vecshi1 = vec_unpackl(_vec0);

    vint16 vecshi2 = vec_unpackh(_vec1);
    vint16 vecshi3 = vec_unpackl(_vec1);

    vint32 veci0 = vec_unpackh(vecshi0);
    vint32 veci1 = vec_unpackl(vecshi0);

    vint32 veci2 = vec_unpackh(vecshi1);
    vint32 veci3 = vec_unpackl(vecshi1);

    vint32 veci4 = vec_unpackh(vecshi2);
    vint32 veci5 = vec_unpackl(vecshi2);

    vint32 veci6 = vec_unpackh(vecshi3);
    vint32 veci7 = vec_unpackl(vecshi3);

    vfloat32 vecf0_0 = vec_float(veci0);
    vfloat32 vecf1_0 = vec_float(veci1);

    vfloat32 vecf0_1 = vec_float(veci2);
    vfloat32 vecf1_1 = vec_float(veci3);

    vfloat32 vecf0_2 = vec_float(veci4);
    vfloat32 vecf1_2 = vec_float(veci5);

    vfloat32 vecf0_3 = vec_float(veci6);
    vfloat32 vecf1_3 = vec_float(veci7);
    vfloat32 scale_vec0 = scale.vec0();
    vfloat32 scale_vec1 = scale.vec1();
    vfloat32 zero_point0 = zero_point.vec0();
    vfloat32 zero_point1 = zero_point.vec1();
    return {
        Vectorized<float>{
            (vecf0_0 - zero_point0) * scale_vec0,
            (vecf1_0 - zero_point1) * scale_vec1},
        Vectorized<float>{
            (vecf0_1 - zero_point0) * scale_vec0,
            (vecf1_1 - zero_point1) * scale_vec1},
        Vectorized<float>{
            (vecf0_2 - zero_point0) * scale_vec0,
            (vecf1_2 - zero_point1) * scale_vec1},
        Vectorized<float>{
            (vecf0_3 - zero_point0) * scale_vec0,
            (vecf1_3 - zero_point1) * scale_vec1}};
  }

  static Vectorized<c10::qint8> quantize(
      const float_vec_return_type& rhs,
      float scale,
      int32_t zero_point,
      float inverse_scale) {
    vfloat32 inverse_scale_v = vec_splats(inverse_scale);
    vfloat32 vec_zero_point = vec_splats((float)zero_point);

    Vectorized<float> vf0 = rhs[0];
    Vectorized<float> vf1 = rhs[1];
    Vectorized<float> vf2 = rhs[2];
    Vectorized<float> vf3 = rhs[3];
    vfloat32 vecf0 = vf0.vec0();
    vfloat32 vecf1 = vf0.vec1();
    vfloat32 vecf2 = vf1.vec0();
    vfloat32 vecf3 = vf1.vec1();

    vfloat32 vecf4 = vf2.vec0();
    vfloat32 vecf5 = vf2.vec1();
    vfloat32 vecf6 = vf3.vec0();
    vfloat32 vecf7 = vf3.vec1();

    // Scale by 1/scale, round to nearest, and shift by the zero point.
    vecf0 = vec_mul(vecf0, inverse_scale_v);
    vecf1 = vec_mul(vecf1, inverse_scale_v);
    vecf2 = vec_mul(vecf2, inverse_scale_v);
    vecf3 = vec_mul(vecf3, inverse_scale_v);

    vecf4 = vec_mul(vecf4, inverse_scale_v);
    vecf5 = vec_mul(vecf5, inverse_scale_v);
    vecf6 = vec_mul(vecf6, inverse_scale_v);
    vecf7 = vec_mul(vecf7, inverse_scale_v);

    vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
    vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
    vecf2 = vec_add(vec_rint(vecf2), vec_zero_point);
    vecf3 = vec_add(vec_rint(vecf3), vec_zero_point);

    vecf4 = vec_add(vec_rint(vecf4), vec_zero_point);
    vecf5 = vec_add(vec_rint(vecf5), vec_zero_point);
    vecf6 = vec_add(vec_rint(vecf6), vec_zero_point);
    vecf7 = vec_add(vec_rint(vecf7), vec_zero_point);

    vint32 veci0 = vec_signed(vecf0);
    vint32 veci1 = vec_signed(vecf1);
    vint32 veci2 = vec_signed(vecf2);
    vint32 veci3 = vec_signed(vecf3);

    vint32 veci4 = vec_signed(vecf4);
    vint32 veci5 = vec_signed(vecf5);
    vint32 veci6 = vec_signed(vecf6);
    vint32 veci7 = vec_signed(vecf7);

    // No explicit clamp to [std::numeric_limits<value_type>::min(),
    // std::numeric_limits<value_type>::max()] is needed here: vec_packs
    // saturates while narrowing 32-bit -> 16-bit -> 8-bit.
    vint16 vecshi0 = vec_packs(veci0, veci1);
    vint16 vecshi1 = vec_packs(veci2, veci3);
    vint16 vecshi2 = vec_packs(veci4, veci5);
    vint16 vecshi3 = vec_packs(veci6, veci7);

    vint8 vec0 = vec_packs(vecshi0, vecshi1);
    vint8 vec1 = vec_packs(vecshi2, vecshi3);

    return {vec0, vec1};
  }
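
  // A scalar model of the quantize() path above, for reference only
  // (`quantize_one` is an illustrative helper, not part of this header):
  //
  //   int8_t quantize_one(float x, float inverse_scale, int32_t zero_point) {
  //     int32_t r = (int32_t)std::nearbyint(x * inverse_scale) + zero_point;
  //     return (int8_t)std::min(127, std::max(-128, r)); // vec_packs' clamp
  //   }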

  Vectorized<c10::qint8> C10_ALWAYS_INLINE
  relu(Vectorized<c10::qint8> zero_point) const {
    // max(x, zero_point): the quantized encoding of max(x_float, 0).
    return {
        vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
  }

  Vectorized<c10::qint8> C10_ALWAYS_INLINE
  relu6(Vectorized<c10::qint8> zero_point, Vectorized<c10::qint8> q_six) const {
    // Clamp to [zero_point, q_six], the quantized encodings of 0 and 6.
    vint8 max0 = vec_max(_vec0, zero_point._vec0);
    vint8 max1 = vec_max(_vec1, zero_point._vec1);
    return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
  }

  int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
    // Sign-extend both operands to 32-bit lanes so the subtraction cannot
    // overflow the int8 range.
    vint16 vecshi0 = vec_unpackh(_vec0);
    vint16 vecBshi0 = vec_unpackh(b._vec0);
    vint16 vecshi1 = vec_unpackl(_vec0);
    vint16 vecBshi1 = vec_unpackl(b._vec0);

    vint16 vecshi2 = vec_unpackh(_vec1);
    vint16 vecBshi2 = vec_unpackh(b._vec1);
    vint16 vecshi3 = vec_unpackl(_vec1);
    vint16 vecBshi3 = vec_unpackl(b._vec1);

    vint32 veci0 = vec_unpackh(vecshi0);
    vint32 vecBi0 = vec_unpackh(vecBshi0);
    vint32 veci1 = vec_unpackl(vecshi0);
    vint32 vecBi1 = vec_unpackl(vecBshi0);

    vint32 veci2 = vec_unpackh(vecshi1);
    vint32 vecBi2 = vec_unpackh(vecBshi1);
    vint32 veci3 = vec_unpackl(vecshi1);
    vint32 vecBi3 = vec_unpackl(vecBshi1);

    vint32 veci4 = vec_unpackh(vecshi2);
    vint32 vecBi4 = vec_unpackh(vecBshi2);
    vint32 veci5 = vec_unpackl(vecshi2);
    vint32 vecBi5 = vec_unpackl(vecBshi2);

    vint32 veci6 = vec_unpackh(vecshi3);
    vint32 vecBi6 = vec_unpackh(vecBshi3);
    vint32 veci7 = vec_unpackl(vecshi3);
    vint32 vecBi7 = vec_unpackl(vecBshi3);

    return {
        Vectorized<c10::qint32>(veci0 - vecBi0, veci1 - vecBi1),
        Vectorized<c10::qint32>(veci2 - vecBi2, veci3 - vecBi3),
        Vectorized<c10::qint32>(veci4 - vecBi4, veci5 - vecBi5),
        Vectorized<c10::qint32>(veci6 - vecBi6, veci7 - vecBi7)};
  }
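
  // widening_subtract() and requantize_from_int() are intended to be used as
  // a pair: subtract in the widened 32-bit domain, then rescale and narrow
  // back to qint8. A rough sketch (the surrounding kernel and the
  // `multiplier` computation are assumptions, not defined in this header):
  //
  //   auto wide = a.widening_subtract(b); // 4x Vectorized<c10::qint32>
  //   auto requantized = Vectorized<c10::qint8>::requantize_from_int(
  //       wide, multiplier, zero_point);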

  static Vectorized<c10::qint8> requantize_from_int(
      const int_vec_return_type& inp,
      float multiplier,
      int32_t zero_point) {
    vfloat32 vec_multiplier = vec_splats(multiplier);
    vint32 vec_zero_point = vec_splats(zero_point);

    Vectorized<c10::qint32> vi0 = inp[0];
    Vectorized<c10::qint32> vi1 = inp[1];
    Vectorized<c10::qint32> vi2 = inp[2];
    Vectorized<c10::qint32> vi3 = inp[3];

    // Convert the widened values to float and apply the requantization
    // multiplier.
    vfloat32 vecf0 = vec_float(vi0.vec0());
    vfloat32 vecf1 = vec_float(vi0.vec1());
    vfloat32 vecf2 = vec_float(vi1.vec0());
    vfloat32 vecf3 = vec_float(vi1.vec1());

    vfloat32 vecf4 = vec_float(vi2.vec0());
    vfloat32 vecf5 = vec_float(vi2.vec1());
    vfloat32 vecf6 = vec_float(vi3.vec0());
    vfloat32 vecf7 = vec_float(vi3.vec1());

    vecf0 = vec_mul(vecf0, vec_multiplier);
    vecf1 = vec_mul(vecf1, vec_multiplier);
    vecf2 = vec_mul(vecf2, vec_multiplier);
    vecf3 = vec_mul(vecf3, vec_multiplier);

    vecf4 = vec_mul(vecf4, vec_multiplier);
    vecf5 = vec_mul(vecf5, vec_multiplier);
    vecf6 = vec_mul(vecf6, vec_multiplier);
    vecf7 = vec_mul(vecf7, vec_multiplier);

    // Round to nearest.
    vecf0 = vec_rint(vecf0);
    vecf1 = vec_rint(vecf1);
    vecf2 = vec_rint(vecf2);
    vecf3 = vec_rint(vecf3);

    vecf4 = vec_rint(vecf4);
    vecf5 = vec_rint(vecf5);
    vecf6 = vec_rint(vecf6);
    vecf7 = vec_rint(vecf7);

    vint32 veci0 = vec_signed(vecf0);
    vint32 veci1 = vec_signed(vecf1);
    vint32 veci2 = vec_signed(vecf2);
    vint32 veci3 = vec_signed(vecf3);

    vint32 veci4 = vec_signed(vecf4);
    vint32 veci5 = vec_signed(vecf5);
    vint32 veci6 = vec_signed(vecf6);
    vint32 veci7 = vec_signed(vecf7);

    // Add the zero point, then saturate-pack back down to int8.
    veci0 = vec_add(veci0, vec_zero_point);
    veci1 = vec_add(veci1, vec_zero_point);
    veci2 = vec_add(veci2, vec_zero_point);
    veci3 = vec_add(veci3, vec_zero_point);

    veci4 = vec_add(veci4, vec_zero_point);
    veci5 = vec_add(veci5, vec_zero_point);
    veci6 = vec_add(veci6, vec_zero_point);
    veci7 = vec_add(veci7, vec_zero_point);

    vint16 vecshi0 = vec_packs(veci0, veci1);
    vint16 vecshi1 = vec_packs(veci2, veci3);
    vint16 vecshi2 = vec_packs(veci4, veci5);
    vint16 vecshi3 = vec_packs(veci6, veci7);

    vint8 vec0 = vec_packs(vecshi0, vecshi1);
    vint8 vec1 = vec_packs(vecshi2, vecshi3);

    return {vec0, vec1};
  }

  // Lane-wise comparisons and arithmetic, applied to both 16-byte halves.
  DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq)
  DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne)
  DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt)
  DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple)
  DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt)
  DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge)
  DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add)
  DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub)
  DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul)
  DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /)
  DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max)
  DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min)
  DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and)
  DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or)
  DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor)
};

template <>
Vectorized<c10::qint8> inline maximum(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return a.maximum(b);
}

template <>
Vectorized<c10::qint8> inline minimum(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return a.minimum(b);
}

template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator+(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return Vectorized<c10::qint8>{
      vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
}

template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator-(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return Vectorized<c10::qint8>{
      vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
}

template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator*(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return Vectorized<c10::qint8>{
      vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
}

template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator/(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return Vectorized<c10::qint8>{a.vec0() / b.vec0(), a.vec1() / b.vec1()};
}

template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator&(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return Vectorized<c10::qint8>{
      vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
}

template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator|(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return Vectorized<c10::qint8>{
      vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
}

template <>
Vectorized<c10::qint8> C10_ALWAYS_INLINE operator^(
    const Vectorized<c10::qint8>& a,
    const Vectorized<c10::qint8>& b) {
  return Vectorized<c10::qint8>{
      vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
}

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at