xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec_base.h>
5 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
6 #include <c10/util/qint8.h>
7 #include <array>
8 
9 // This file defines Vectorized<> for the quantized types.
10 //
11 //
12 // Currently, we simply use these classes as efficient converters between
13 // the quantized types and Vectorized<float>, usually in bandwidth-bound cases
14 // where doing the arithmetic in full-precision is acceptable (e.g.
15 // elementwise operators).
16 //
17 //
18 // Conversions are as follows:
19 //  Vectorized<qint8> -> 4x Vectorized<float>
20 //
21 // The size of the returned float vector is specified by the special
22 // constexpr function float_num_vecs. The type of the value returned
23 // from dequantize (and expected as an argument to quantize) is
24 // specified by float_vec_return_type.
25 //
26 // When writing kernels with these vectors, it is expected that floating-
27 // point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
28 // iterations.
29 
30 namespace at {
31 namespace vec {
32 inline namespace CPU_CAPABILITY {
33 
34 template <>
35 struct Vectorized<c10::qint8> {
36  private:
37   union {
38     struct {
39       vint8 _vec0;
40       vint8 _vec1;
41     };
42     struct {
43       vbool8 _vecb0;
44       vbool8 _vecb1;
45     };
46 
47   } __attribute__((__may_alias__));
48 
49  public:
50   Vectorized() {}
51   using size_type = int;
52   static constexpr size_type size() {
53     return 32;
54   }
55 
56   static constexpr size_t float_num_vecs() {
57     return 4;
58   }
59   static constexpr int int_num_vecs() {
60     return 4;
61   }
62   using float_vec_return_type = std::array<Vectorized<float>, 4>;
63   using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
64   using value_type = typename c10::qint8::underlying;
65   using vec_internal_type = vint8;
66   using vec_internal_mask_type = vbool8;
67   // Broadcast constructor
68   C10_ALWAYS_INLINE Vectorized(const c10::qint8& val)
69       : _vec0{vec_splats(val.val_)}, _vec1{vec_splats(val.val_)} {}
70 
71   C10_ALWAYS_INLINE Vectorized(const Vectorized<c10::qint8>& other)
72       : _vec0{other._vec0}, _vec1(other._vec1) {}
73 
74   C10_ALWAYS_INLINE Vectorized(vint8 v) : _vec0{v}, _vec1{v} {}
75   C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
76   C10_ALWAYS_INLINE Vectorized(vint8 v1, vint8 v2) : _vec0{v1}, _vec1{v2} {}
77   C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {}
78 
79   C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
80     return _vec0;
81   }
82   C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
83     return _vec1;
84   }
85 
86   static C10_ALWAYS_INLINE Vectorized<c10::qint8> loadu(
87       const void* ptr,
88       int count = size()) {
89     if (count == size()) {
90       return {
91           vec_vsx_ld(offset0, reinterpret_cast<const vint8*>(ptr)),
92           vec_vsx_ld(offset16, reinterpret_cast<const vint8*>(ptr))};
93     }
94     __at_align__ value_type tmp_values[size()] = {};
95     std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
96     return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
97   }
98   void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
99     if (count == size()) {
100       vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
101       vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
102     } else if (count > 0) {
103       __at_align__ value_type tmp_values[size()];
104       vec_vsx_st(_vec0, offset0, tmp_values);
105       vec_vsx_st(_vec1, offset16, tmp_values);
106       std::memcpy(
107           ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
108     }
109   }
110 
111  public:
112   float_vec_return_type C10_ALWAYS_INLINE dequantize(
113       Vectorized<float> scale,
114       Vectorized<float> zero_point,
115       Vectorized<float> scale_zp_premul) const {
116     vint16 vecshi0 = vec_unpackh(_vec0);
117     vint16 vecshi1 = vec_unpackl(_vec0);
118 
119     vint16 vecshi2 = vec_unpackh(_vec1);
120     vint16 vecshi3 = vec_unpackl(_vec1);
121 
122     vint32 veci0 = vec_unpackh(vecshi0);
123     vint32 veci1 = vec_unpackl(vecshi0);
124 
125     vint32 veci2 = vec_unpackh(vecshi1);
126     vint32 veci3 = vec_unpackl(vecshi1);
127 
128     vint32 veci4 = vec_unpackh(vecshi2);
129     vint32 veci5 = vec_unpackl(vecshi2);
130 
131     vint32 veci6 = vec_unpackh(vecshi3);
132     vint32 veci7 = vec_unpackl(vecshi3);
133 
134     vfloat32 vecf0_0 = vec_float(veci0);
135     vfloat32 vecf1_0 = vec_float(veci1);
136 
137     vfloat32 vecf0_1 = vec_float(veci2);
138     vfloat32 vecf1_1 = vec_float(veci3);
139 
140     vfloat32 vecf0_2 = vec_float(veci4);
141     vfloat32 vecf1_2 = vec_float(veci5);
142 
143     vfloat32 vecf0_3 = vec_float(veci6);
144     vfloat32 vecf1_3 = vec_float(veci7);
145     vfloat32 scale_vec0 = scale.vec0();
146     vfloat32 scale_vec1 = scale.vec1();
147     vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
148     vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
149     return {
150         Vectorized<float>{
151             vec_madd(scale_vec0, vecf0_0, scale_zp_premul0),
152             vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)},
153         Vectorized<float>{
154             vec_madd(scale_vec0, vecf0_1, scale_zp_premul0),
155             vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)},
156         Vectorized<float>{
157             vec_madd(scale_vec0, vecf0_2, scale_zp_premul0),
158             vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)},
159         Vectorized<float>{
160             vec_madd(scale_vec0, vecf0_3, scale_zp_premul0),
161             vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}};
162   }
163 
164   float_vec_return_type C10_ALWAYS_INLINE dequantize(
165       Vectorized<float> scale,
166       Vectorized<float> zero_point) const {
167     vint16 vecshi0 = vec_unpackh(_vec0);
168     vint16 vecshi1 = vec_unpackl(_vec0);
169 
170     vint16 vecshi2 = vec_unpackh(_vec1);
171     vint16 vecshi3 = vec_unpackl(_vec1);
172 
173     vint32 veci0 = vec_unpackh(vecshi0);
174     vint32 veci1 = vec_unpackl(vecshi0);
175 
176     vint32 veci2 = vec_unpackh(vecshi1);
177     vint32 veci3 = vec_unpackl(vecshi1);
178 
179     vint32 veci4 = vec_unpackh(vecshi2);
180     vint32 veci5 = vec_unpackl(vecshi2);
181 
182     vint32 veci6 = vec_unpackh(vecshi3);
183     vint32 veci7 = vec_unpackl(vecshi3);
184 
185     vfloat32 vecf0_0 = vec_float(veci0);
186     vfloat32 vecf1_0 = vec_float(veci1);
187 
188     vfloat32 vecf0_1 = vec_float(veci2);
189     vfloat32 vecf1_1 = vec_float(veci3);
190 
191     vfloat32 vecf0_2 = vec_float(veci4);
192     vfloat32 vecf1_2 = vec_float(veci5);
193 
194     vfloat32 vecf0_3 = vec_float(veci6);
195     vfloat32 vecf1_3 = vec_float(veci7);
196     vfloat32 scale_vec0 = scale.vec0();
197     vfloat32 scale_vec1 = scale.vec1();
198     vfloat32 zero_point0 = zero_point.vec0();
199     vfloat32 zero_point1 = zero_point.vec1();
200     return {
201         Vectorized<float>{
202             (vecf0_0 - zero_point0) * scale_vec0,
203             (vecf1_0 - zero_point1) * scale_vec1},
204         Vectorized<float>{
205             (vecf0_1 - zero_point0) * scale_vec0,
206             (vecf1_1 - zero_point1) * scale_vec1},
207         Vectorized<float>{
208             (vecf0_2 - zero_point0) * scale_vec0,
209             (vecf1_2 - zero_point1) * scale_vec1},
210         Vectorized<float>{
211             (vecf0_3 - zero_point0) * scale_vec0,
212             (vecf1_3 - zero_point1) * scale_vec1}};
213   }
214 
215   static Vectorized<c10::qint8> quantize(
216       const float_vec_return_type& rhs,
217       float scale,
218       int32_t zero_point,
219       float inverse_scale) {
220     // constexpr int32_t min_val = std::numeric_limits<value_type>::min();
221     // constexpr int32_t max_val = std::numeric_limits<value_type>::max();
222 
223     vfloat32 inverse_scale_v = vec_splats(inverse_scale);
224     vfloat32 vec_zero_point = vec_splats((float)zero_point);
225     // vint32 vmin = vec_splats(min_val);
226     // vint32 vmax = vec_splats(max_val);
227 
228     Vectorized<float> vf0 = rhs[0];
229     Vectorized<float> vf1 = rhs[1];
230     Vectorized<float> vf2 = rhs[2];
231     Vectorized<float> vf3 = rhs[3];
232     vfloat32 vecf0 = vf0.vec0();
233     vfloat32 vecf1 = vf0.vec1();
234     vfloat32 vecf2 = vf1.vec0();
235     vfloat32 vecf3 = vf1.vec1();
236 
237     vfloat32 vecf4 = vf2.vec0();
238     vfloat32 vecf5 = vf2.vec1();
239     vfloat32 vecf6 = vf3.vec0();
240     vfloat32 vecf7 = vf3.vec1();
241 
242     vecf0 = vec_mul(vecf0, inverse_scale_v);
243     vecf1 = vec_mul(vecf1, inverse_scale_v);
244     vecf2 = vec_mul(vecf2, inverse_scale_v);
245     vecf3 = vec_mul(vecf3, inverse_scale_v);
246 
247     vecf4 = vec_mul(vecf4, inverse_scale_v);
248     vecf5 = vec_mul(vecf5, inverse_scale_v);
249     vecf6 = vec_mul(vecf6, inverse_scale_v);
250     vecf7 = vec_mul(vecf7, inverse_scale_v);
251 
252     vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
253     vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
254     vecf2 = vec_add(vec_rint(vecf2), vec_zero_point);
255     vecf3 = vec_add(vec_rint(vecf3), vec_zero_point);
256 
257     vecf4 = vec_add(vec_rint(vecf4), vec_zero_point);
258     vecf5 = vec_add(vec_rint(vecf5), vec_zero_point);
259     vecf6 = vec_add(vec_rint(vecf6), vec_zero_point);
260     vecf7 = vec_add(vec_rint(vecf7), vec_zero_point);
261 
262     vint32 veci0 = vec_signed(vecf0);
263     vint32 veci1 = vec_signed(vecf1);
264     vint32 veci2 = vec_signed(vecf2);
265     vint32 veci3 = vec_signed(vecf3);
266 
267     vint32 veci4 = vec_signed(vecf4);
268     vint32 veci5 = vec_signed(vecf5);
269     vint32 veci6 = vec_signed(vecf6);
270     vint32 veci7 = vec_signed(vecf7);
271 
272     // veci0 = vec_min(vmax, vec_max( vmin, vecf0)) ;
273     // veci1 = vec_min(vmax, vec_max( vmin, vecf1)) ;
274     // veci2 = vec_min(vmax, vec_max( vmin, vecf2)) ;
275     // veci3 = vec_min(vmax, vec_max( vmin, vecf3)) ;
276 
277     // veci4 = vec_min(vmax, vec_max( vmin, vecf4)) ;
278     // veci5 = vec_min(vmax, vec_max( vmin, vecf5)) ;
279     // veci6 = vec_min(vmax, vec_max( vmin, vecf6)) ;
280     // veci7 = vec_min(vmax, vec_max( vmin, vecf7)) ;
281     // vec_packs CLAMP already
282     vint16 vecshi0 = vec_packs(veci0, veci1);
283     vint16 vecshi1 = vec_packs(veci2, veci3);
284     vint16 vecshi2 = vec_packs(veci4, veci5);
285     vint16 vecshi3 = vec_packs(veci6, veci7);
286 
287     vint8 vec0 = vec_packs(vecshi0, vecshi1);
288     vint8 vec1 = vec_packs(vecshi2, vecshi3);
289 
290     return {vec0, vec1};
291   }
292 
293   Vectorized<c10::qint8> C10_ALWAYS_INLINE relu(Vectorized<c10::qint8> zero_point) const {
294     return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
295   }
296 
297   Vectorized<c10::qint8> C10_ALWAYS_INLINE
298   relu6(Vectorized<c10::qint8> zero_point, Vectorized<c10::qint8> q_six) const {
299     vint8 max0 = vec_max(_vec0, zero_point._vec0);
300     vint8 max1 = vec_max(_vec1, zero_point._vec1);
301     return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
302   }
303 
304   int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
305     vint16 vecshi0 = vec_unpackh(_vec0);
306     vint16 vecBshi0 = vec_unpackh(b._vec0);
307     vint16 vecshi1 = vec_unpackl(_vec0);
308     vint16 vecBshi1 = vec_unpackl(b._vec0);
309 
310     vint16 vecshi2 = vec_unpackh(_vec1);
311     vint16 vecBshi2 = vec_unpackh(b._vec1);
312     vint16 vecshi3 = vec_unpackl(_vec1);
313     vint16 vecBshi3 = vec_unpackl(b._vec1);
314 
315     vint32 veci0 = vec_unpackh(vecshi0);
316     vint32 vecBi0 = vec_unpackh(vecBshi0);
317     vint32 veci1 = vec_unpackl(vecshi0);
318     vint32 vecBi1 = vec_unpackl(vecBshi0);
319 
320     vint32 veci2 = vec_unpackh(vecshi1);
321     vint32 vecBi2 = vec_unpackh(vecBshi1);
322     vint32 veci3 = vec_unpackl(vecshi1);
323     vint32 vecBi3 = vec_unpackl(vecBshi1);
324 
325     vint32 veci4 = vec_unpackh(vecshi2);
326     vint32 vecBi4 = vec_unpackh(vecBshi2);
327     vint32 veci5 = vec_unpackl(vecshi2);
328     vint32 vecBi5 = vec_unpackl(vecBshi2);
329 
330     vint32 veci6 = vec_unpackh(vecshi3);
331     vint32 vecBi6 = vec_unpackh(vecBshi3);
332     vint32 veci7 = vec_unpackl(vecshi3);
333     vint32 vecBi7 = vec_unpackl(vecBshi3);
334 
335     return {
336         Vectorized<c10::qint32>(veci0 - vecBi0, veci1 - vecBi1),
337         Vectorized<c10::qint32>(veci2 - vecBi2, veci3 - vecBi3),
338         Vectorized<c10::qint32>(veci4 - vecBi4, veci5 - vecBi5),
339         Vectorized<c10::qint32>(veci6 - vecBi6, veci7 - vecBi7)};
340   }
341 
342   static Vectorized<c10::qint8> requantize_from_int(
343       const int_vec_return_type& inp,
344       float multiplier,
345       int32_t zero_point) {
346     vfloat32 vec_multiplier = vec_splats(multiplier);
347     vint32 vec_zero_point = vec_splats(zero_point);
348 
349     Vectorized<c10::qint32> vi0 = inp[0];
350     Vectorized<c10::qint32> vi1 = inp[1];
351     Vectorized<c10::qint32> vi2 = inp[2];
352     Vectorized<c10::qint32> vi3 = inp[3];
353 
354     vfloat32 vecf0 = vec_float(vi0.vec0());
355     vfloat32 vecf1 = vec_float(vi0.vec1());
356     vfloat32 vecf2 = vec_float(vi1.vec0());
357     vfloat32 vecf3 = vec_float(vi1.vec1());
358 
359     vfloat32 vecf4 = vec_float(vi2.vec0());
360     vfloat32 vecf5 = vec_float(vi2.vec1());
361     vfloat32 vecf6 = vec_float(vi3.vec0());
362     vfloat32 vecf7 = vec_float(vi3.vec1());
363 
364     vecf0 = vec_mul(vecf0, vec_multiplier);
365     vecf1 = vec_mul(vecf1, vec_multiplier);
366     vecf2 = vec_mul(vecf2, vec_multiplier);
367     vecf3 = vec_mul(vecf3, vec_multiplier);
368 
369     vecf4 = vec_mul(vecf4, vec_multiplier);
370     vecf5 = vec_mul(vecf5, vec_multiplier);
371     vecf6 = vec_mul(vecf6, vec_multiplier);
372     vecf7 = vec_mul(vecf7, vec_multiplier);
373 
374     vecf0 = vec_rint(vecf0);
375     vecf1 = vec_rint(vecf1);
376     vecf2 = vec_rint(vecf2);
377     vecf3 = vec_rint(vecf3);
378 
379     vecf4 = vec_rint(vecf4);
380     vecf5 = vec_rint(vecf5);
381     vecf6 = vec_rint(vecf6);
382     vecf7 = vec_rint(vecf7);
383 
384     vint32 veci0 = vec_signed(vecf0);
385     vint32 veci1 = vec_signed(vecf1);
386     vint32 veci2 = vec_signed(vecf2);
387     vint32 veci3 = vec_signed(vecf3);
388 
389     vint32 veci4 = vec_signed(vecf4);
390     vint32 veci5 = vec_signed(vecf5);
391     vint32 veci6 = vec_signed(vecf6);
392     vint32 veci7 = vec_signed(vecf7);
393 
394     veci0 = vec_add(veci0, vec_zero_point);
395     veci1 = vec_add(veci1, vec_zero_point);
396     veci2 = vec_add(veci2, vec_zero_point);
397     veci3 = vec_add(veci3, vec_zero_point);
398 
399     veci4 = vec_add(veci4, vec_zero_point);
400     veci5 = vec_add(veci5, vec_zero_point);
401     veci6 = vec_add(veci6, vec_zero_point);
402     veci7 = vec_add(veci7, vec_zero_point);
403 
404     vint16 vecshi0 = vec_packs(veci0, veci1);
405     vint16 vecshi1 = vec_packs(veci2, veci3);
406     vint16 vecshi2 = vec_packs(veci4, veci5);
407     vint16 vecshi3 = vec_packs(veci6, veci7);
408 
409     vint8 vec0 = vec_packs(vecshi0, vecshi1);
410     vint8 vec1 = vec_packs(vecshi2, vecshi3);
411 
412     return {vec0, vec1};
413   }
414 
415   DEFINE_MEMBER_OP(operator==, c10::qint8, vec_cmpeq)
416   DEFINE_MEMBER_OP(operator!=, c10::qint8, vec_cmpne)
417   DEFINE_MEMBER_OP(operator<, c10::qint8, vec_cmplt)
418   DEFINE_MEMBER_OP(operator<=, c10::qint8, vec_cmple)
419   DEFINE_MEMBER_OP(operator>, c10::qint8, vec_cmpgt)
420   DEFINE_MEMBER_OP(operator>=, c10::qint8, vec_cmpge)
421   DEFINE_MEMBER_OP(operator+, c10::qint8, vec_add)
422   DEFINE_MEMBER_OP(operator-, c10::qint8, vec_sub)
423   DEFINE_MEMBER_OP(operator*, c10::qint8, vec_mul)
424   DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint8, /)
425   DEFINE_MEMBER_OP(maximum, c10::qint8, vec_max)
426   DEFINE_MEMBER_OP(minimum, c10::qint8, vec_min)
427   DEFINE_MEMBER_OP(operator&, c10::qint8, vec_and)
428   DEFINE_MEMBER_OP(operator|, c10::qint8, vec_or)
429   DEFINE_MEMBER_OP(operator^, c10::qint8, vec_xor)
430 };
431 
432 template <>
433 Vectorized<c10::qint8> inline maximum(
434     const Vectorized<c10::qint8>& a,
435     const Vectorized<c10::qint8>& b) {
436   return a.maximum(b);
437 }
438 
439 template <>
440 Vectorized<c10::qint8> inline minimum(
441     const Vectorized<c10::qint8>& a,
442     const Vectorized<c10::qint8>& b) {
443   return a.minimum(b);
444 }
445 
446 template <>
447 Vectorized<c10::qint8> C10_ALWAYS_INLINE operator+(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
448   return Vectorized<c10::qint8>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
449 }
450 
451 template <>
452 Vectorized<c10::qint8> C10_ALWAYS_INLINE operator-(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
453   return Vectorized<c10::qint8>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
454 }
455 
456 template <>
457 Vectorized<c10::qint8> C10_ALWAYS_INLINE operator*(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
458   return Vectorized<c10::qint8>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
459 }
460 
461 template <>
462 Vectorized<c10::qint8> C10_ALWAYS_INLINE operator/(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
463   return Vectorized<c10::qint8>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
464 }
465 
466 template <>
467 Vectorized<c10::qint8> C10_ALWAYS_INLINE operator&(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
468   return Vectorized<c10::qint8>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
469 }
470 
471 template <>
472 Vectorized<c10::qint8> C10_ALWAYS_INLINE operator|(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
473   return Vectorized<c10::qint8>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
474 }
475 
476 template <>
477 Vectorized<c10::qint8> C10_ALWAYS_INLINE operator^(const Vectorized<c10::qint8>& a, const Vectorized<c10::qint8>& b) {
478   return Vectorized<c10::qint8>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
479 }
480 
481 } // namespace
482 } // namespace vec
483 } // namespace at
484