1 #pragma once 2 3 #include <ATen/cpu/vec/intrinsics.h> 4 #include <ATen/cpu/vec/vec_base.h> 5 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h> 6 #include <c10/util/qint32.h> 7 #include <array> 8 9 // This file defines Vectorized<> for the quantized types. 10 // 11 // 12 // Currently, we simply use these classes as efficient converters between 13 // the quantized types and Vectorized<float>, usually in bandwidth-bound cases 14 // where doing the arithmetic in full-precision is acceptable (e.g. 15 // elementwise operators). 16 // 17 // 18 // Conversions are as follows: 19 // Vectorized<qint32> -> 1x Vectorized<float> 20 // 21 // The size of the returned float vector is specified by the special 22 // constexpr function float_num_vecs. The type of the value returned 23 // from dequantize (and expected as an argument to quantize) is 24 // specified by float_vec_return_type. 25 // 26 // When writing kernels with these vectors, it is expected that floating- 27 // point operations will be carried out in a loop over Vectorized<T>::float_num_vecs 28 // iterations. 29 30 namespace at { 31 namespace vec { 32 inline namespace CPU_CAPABILITY { 33 34 template <> 35 struct Vectorized<c10::qint32> { 36 private: 37 union { 38 struct { 39 vint32 _vec0; 40 vint32 _vec1; 41 }; 42 struct { 43 vbool32 _vecb0; 44 vbool32 _vecb1; 45 }; 46 47 } __attribute__((__may_alias__)); 48 49 public: 50 Vectorized() {} 51 52 using size_type = int; 53 static constexpr size_type size() { 54 return 8; 55 } 56 57 static constexpr size_t float_num_vecs() { 58 return 1; 59 } 60 static constexpr int int_num_vecs() { 61 return 1; 62 } 63 using float_vec_return_type = std::array<Vectorized<float>, 1>; 64 using int_vec_return_type = std::array<Vectorized<c10::qint32>, 1>; 65 using value_type = c10::qint32::underlying; 66 using vec_internal_type = vint32; 67 using vec_internal_mask_type = vbool32; 68 C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {} 69 C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} 70 C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} 71 C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {} 72 73 Vectorized(const c10::qint32& val) 74 : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {} 75 76 static Vectorized<c10::qint32> C10_ALWAYS_INLINE 77 loadu(const void* ptr, int count = size()) { 78 if (count == size()) { 79 return { 80 vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)), 81 vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))}; 82 } 83 84 __at_align__ value_type tmp_values[size()] = {}; 85 std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); 86 87 return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; 88 } 89 void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const { 90 if (count == size()) { 91 vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr)); 92 vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr)); 93 } else if (count > 0) { 94 __at_align__ value_type tmp_values[size()]; 95 vec_vsx_st(_vec0, offset0, tmp_values); 96 vec_vsx_st(_vec1, offset16, tmp_values); 97 std::memcpy( 98 ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); 99 } 100 } 101 102 C10_ALWAYS_INLINE const vec_internal_type& vec0() const { 103 return _vec0; 104 } 105 C10_ALWAYS_INLINE const vec_internal_type& vec1() const { 106 return _vec1; 107 } 108 109 float_vec_return_type dequantize( 110 Vectorized<float> scale, 111 Vectorized<float> zero_point, 112 Vectorized<float> scale_zp_premul) const { 113 vfloat32 float_vals0 = vec_float(_vec0); 114 vfloat32 float_vals1 = vec_float(_vec1); 115 vfloat32 scale_vec0 = scale.vec0(); 116 vfloat32 scale_vec1 = scale.vec1(); 117 vfloat32 scale_zp_premul0 = scale_zp_premul.vec0(); 118 vfloat32 scale_zp_premul1 = scale_zp_premul.vec1(); 119 return {Vectorized<float>{ 120 vec_madd(scale_vec0, float_vals0, scale_zp_premul0), 121 vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}}; 122 } 123 124 float_vec_return_type dequantize( 125 Vectorized<float> scale, 126 Vectorized<float> zero_point) const { 127 vfloat32 float_vals0 = vec_float(_vec0); 128 vfloat32 float_vals1 = vec_float(_vec1); 129 vfloat32 scale_vec0 = scale.vec0(); 130 vfloat32 scale_vec1 = scale.vec1(); 131 vfloat32 zero_point0 = zero_point.vec0(); 132 vfloat32 zero_point1 = zero_point.vec1(); 133 return {Vectorized<float>{ 134 (float_vals0 - zero_point0) * scale_vec0, 135 (float_vals1 - zero_point1) * scale_vec1}}; 136 } 137 138 static Vectorized<c10::qint32> quantize( 139 const float_vec_return_type& rhs, 140 float scale, 141 int32_t zero_point, 142 float inverse_scale) { 143 Vectorized<c10::qint32> retval; 144 145 const vint32 vmin = vec_splats(std::numeric_limits<value_type>::min()); 146 const vint32 vmax = vec_splats(std::numeric_limits<value_type>::max()); 147 vfloat32 inverse_scale_v = vec_splats(inverse_scale); 148 vfloat32 vec_zero_point = vec_splats((float)(zero_point)); 149 Vectorized<float> vf0 = rhs[0]; 150 151 vfloat32 vecf0 = vf0.vec0(); 152 vfloat32 vecf1 = vf0.vec1(); 153 vecf0 = vec_mul(vecf0, inverse_scale_v); 154 vecf1 = vec_mul(vecf1, inverse_scale_v); 155 vecf0 = vec_add(vec_rint(vecf0), vec_zero_point); 156 vecf1 = vec_add(vec_rint(vecf1), vec_zero_point); 157 vint32 veci0 = vec_signed(vecf0); 158 vint32 veci1 = vec_signed(vecf1); 159 160 veci0 = vec_max(veci0, vmin); 161 veci1 = vec_max(veci1, vmin); 162 veci0 = vec_min(veci0, vmax); 163 veci1 = vec_min(veci1, vmax); 164 165 return {veci0, veci1}; 166 } 167 168 Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const { 169 return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)}; 170 } 171 172 Vectorized<c10::qint32> relu6( 173 Vectorized<c10::qint32> zero_point, 174 Vectorized<c10::qint32> q_six) const { 175 vint32 max0 = vec_max(_vec0, zero_point._vec0); 176 vint32 max1 = vec_max(_vec1, zero_point._vec1); 177 return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)}; 178 } 179 180 int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const { 181 return {*this - b}; 182 } 183 184 static Vectorized<c10::qint32> requantize_from_int( 185 const int_vec_return_type& inp, 186 float multiplier, 187 int32_t zero_point) { 188 const vint32 vmin = vec_splats(std::numeric_limits<value_type>::min()); 189 const vint32 vmax = vec_splats(std::numeric_limits<value_type>::max()); 190 vfloat32 vec_mult = vec_splats(multiplier); 191 vint32 vec_zero_point = vec_splats(zero_point); 192 Vectorized<c10::qint32> vi = inp[0]; 193 vfloat32 vecf0 = vec_float(vi.vec0()); 194 vfloat32 vecf1 = vec_float(vi.vec1()); 195 196 vecf0 = vec_mul(vecf0, vec_mult); 197 vecf1 = vec_mul(vecf1, vec_mult); 198 199 vecf0 = vec_rint(vecf0); 200 vecf1 = vec_rint(vecf1); 201 202 vint32 veci0 = vec_add(vec_signed(vecf0),vec_zero_point); 203 vint32 veci1 = vec_add(vec_signed(vecf1),vec_zero_point); 204 205 veci0 = vec_max(veci0, vmin); 206 veci1 = vec_max(veci1, vmin); 207 veci0 = vec_min(veci0, vmax); 208 veci1 = vec_min(veci1, vmax); 209 210 return {veci0, veci1}; 211 } 212 213 DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq) 214 DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne) 215 DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt) 216 DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple) 217 DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt) 218 DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge) 219 DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add) 220 DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub) 221 DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul) 222 DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /) 223 DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max) 224 DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min) 225 DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and) 226 DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or) 227 DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor) 228 }; 229 230 template <> 231 Vectorized<c10::qint32> inline maximum( 232 const Vectorized<c10::qint32>& a, 233 const Vectorized<c10::qint32>& b) { 234 return a.maximum(b); 235 } 236 237 template <> 238 Vectorized<c10::qint32> inline minimum( 239 const Vectorized<c10::qint32>& a, 240 const Vectorized<c10::qint32>& b) { 241 return a.minimum(b); 242 } 243 244 template <> 245 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator+(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) { 246 return Vectorized<c10::qint32>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())}; 247 } 248 249 template <> 250 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator-(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) { 251 return Vectorized<c10::qint32>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())}; 252 } 253 254 template <> 255 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator*(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) { 256 return Vectorized<c10::qint32>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())}; 257 } 258 259 template <> 260 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator/(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) { 261 return Vectorized<c10::qint32>{a.vec0()/b.vec0(), a.vec1()/b.vec1()}; 262 } 263 264 template <> 265 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator&(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) { 266 return Vectorized<c10::qint32>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())}; 267 } 268 269 template <> 270 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator|(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) { 271 return Vectorized<c10::qint32>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())}; 272 } 273 274 template <> 275 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator^(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) { 276 return Vectorized<c10::qint32>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())}; 277 } 278 279 } // namespace 280 } // namespace vec 281 } // namespace at 282