xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec_base.h>
5 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
6 #include <c10/util/qint32.h>
7 #include <array>
8 
// This file defines the VSX (POWER) specialization Vectorized<c10::qint32>.
10 //
11 //
12 // Currently, we simply use these classes as efficient converters between
13 // the quantized types and Vectorized<float>, usually in bandwidth-bound cases
14 // where doing the arithmetic in full-precision is acceptable (e.g.
15 // elementwise operators).
16 //
17 //
18 // Conversions are as follows:
19 //  Vectorized<qint32> -> 1x Vectorized<float>
20 //
21 // The size of the returned float vector is specified by the special
22 // constexpr function float_num_vecs. The type of the value returned
23 // from dequantize (and expected as an argument to quantize) is
24 // specified by float_vec_return_type.
25 //
26 // When writing kernels with these vectors, it is expected that floating-
27 // point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
28 // iterations.
29 
30 namespace at {
31 namespace vec {
32 inline namespace CPU_CAPABILITY {
33 
34 template <>
35 struct Vectorized<c10::qint32> {
36  private:
37   union {
38     struct {
39       vint32 _vec0;
40       vint32 _vec1;
41     };
42     struct {
43       vbool32 _vecb0;
44       vbool32 _vecb1;
45     };
46 
47   } __attribute__((__may_alias__));
48 
49  public:
50   Vectorized() {}
51 
52   using size_type = int;
53   static constexpr size_type size() {
54     return 8;
55   }
56 
57   static constexpr size_t float_num_vecs() {
58     return 1;
59   }
60   static constexpr int int_num_vecs() {
61     return 1;
62   }
63   using float_vec_return_type = std::array<Vectorized<float>, 1>;
64   using int_vec_return_type = std::array<Vectorized<c10::qint32>, 1>;
65   using value_type = c10::qint32::underlying;
66   using vec_internal_type = vint32;
67   using vec_internal_mask_type = vbool32;
68   C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {}
69   C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
70   C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {}
71   C10_ALWAYS_INLINE Vectorized(vbool32 v1, vbool32 v2) : _vecb0{v1}, _vecb1{v2} {}
72 
73   Vectorized(const c10::qint32& val)
74       : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {}
75 
76   static Vectorized<c10::qint32> C10_ALWAYS_INLINE
77   loadu(const void* ptr, int count = size()) {
78     if (count == size()) {
79       return {
80           vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
81           vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
82     }
83 
84     __at_align__ value_type tmp_values[size()] = {};
85     std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
86 
87     return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
88   }
89   void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
90     if (count == size()) {
91       vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
92       vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
93     } else if (count > 0) {
94       __at_align__ value_type tmp_values[size()];
95       vec_vsx_st(_vec0, offset0, tmp_values);
96       vec_vsx_st(_vec1, offset16, tmp_values);
97       std::memcpy(
98           ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
99     }
100   }
101 
102   C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
103     return _vec0;
104   }
105   C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
106     return _vec1;
107   }
108 
109   float_vec_return_type dequantize(
110       Vectorized<float> scale,
111       Vectorized<float> zero_point,
112       Vectorized<float> scale_zp_premul) const {
113     vfloat32 float_vals0 = vec_float(_vec0);
114     vfloat32 float_vals1 = vec_float(_vec1);
115     vfloat32 scale_vec0 = scale.vec0();
116     vfloat32 scale_vec1 = scale.vec1();
117     vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
118     vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
119     return {Vectorized<float>{
120         vec_madd(scale_vec0, float_vals0, scale_zp_premul0),
121         vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}};
122   }
123 
124   float_vec_return_type dequantize(
125       Vectorized<float> scale,
126       Vectorized<float> zero_point) const {
127     vfloat32 float_vals0 = vec_float(_vec0);
128     vfloat32 float_vals1 = vec_float(_vec1);
129     vfloat32 scale_vec0 = scale.vec0();
130     vfloat32 scale_vec1 = scale.vec1();
131     vfloat32 zero_point0 = zero_point.vec0();
132     vfloat32 zero_point1 = zero_point.vec1();
133     return {Vectorized<float>{
134         (float_vals0 - zero_point0) * scale_vec0,
135         (float_vals1 - zero_point1) * scale_vec1}};
136   }
137 
138   static Vectorized<c10::qint32> quantize(
139       const float_vec_return_type& rhs,
140       float scale,
141       int32_t zero_point,
142       float inverse_scale) {
143     Vectorized<c10::qint32> retval;
144 
145     const vint32 vmin = vec_splats(std::numeric_limits<value_type>::min());
146     const vint32 vmax = vec_splats(std::numeric_limits<value_type>::max());
147     vfloat32 inverse_scale_v = vec_splats(inverse_scale);
148     vfloat32 vec_zero_point = vec_splats((float)(zero_point));
149     Vectorized<float> vf0 = rhs[0];
150 
151     vfloat32 vecf0 = vf0.vec0();
152     vfloat32 vecf1 = vf0.vec1();
153     vecf0 = vec_mul(vecf0, inverse_scale_v);
154     vecf1 = vec_mul(vecf1, inverse_scale_v);
155     vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
156     vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
157     vint32 veci0  = vec_signed(vecf0);
158     vint32 veci1  = vec_signed(vecf1);
159 
160     veci0 = vec_max(veci0, vmin);
161     veci1 = vec_max(veci1, vmin);
162     veci0 = vec_min(veci0, vmax);
163     veci1 = vec_min(veci1, vmax);
164 
165     return {veci0, veci1};
166   }
167 
168   Vectorized<c10::qint32> relu(Vectorized<c10::qint32> zero_point) const {
169     return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
170   }
171 
172   Vectorized<c10::qint32> relu6(
173       Vectorized<c10::qint32> zero_point,
174       Vectorized<c10::qint32> q_six) const {
175     vint32 max0 = vec_max(_vec0, zero_point._vec0);
176     vint32 max1 = vec_max(_vec1, zero_point._vec1);
177     return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
178   }
179 
180   int_vec_return_type widening_subtract(Vectorized<c10::qint32> b) const {
181     return {*this - b};
182   }
183 
184   static Vectorized<c10::qint32> requantize_from_int(
185       const int_vec_return_type& inp,
186       float multiplier,
187       int32_t zero_point) {
188     const vint32 vmin = vec_splats(std::numeric_limits<value_type>::min());
189     const vint32 vmax = vec_splats(std::numeric_limits<value_type>::max());
190     vfloat32 vec_mult = vec_splats(multiplier);
191     vint32 vec_zero_point = vec_splats(zero_point);
192     Vectorized<c10::qint32> vi = inp[0];
193     vfloat32 vecf0 = vec_float(vi.vec0());
194     vfloat32 vecf1 = vec_float(vi.vec1());
195 
196     vecf0 = vec_mul(vecf0, vec_mult);
197     vecf1 = vec_mul(vecf1, vec_mult);
198 
199     vecf0 = vec_rint(vecf0);
200     vecf1 = vec_rint(vecf1);
201 
202     vint32 veci0  = vec_add(vec_signed(vecf0),vec_zero_point);
203     vint32 veci1  = vec_add(vec_signed(vecf1),vec_zero_point);
204 
205     veci0 = vec_max(veci0, vmin);
206     veci1 = vec_max(veci1, vmin);
207     veci0 = vec_min(veci0, vmax);
208     veci1 = vec_min(veci1, vmax);
209 
210     return {veci0, veci1};
211   }
212 
213   DEFINE_MEMBER_OP(operator==, c10::qint32, vec_cmpeq)
214   DEFINE_MEMBER_OP(operator!=, c10::qint32, vec_cmpne)
215   DEFINE_MEMBER_OP(operator<, c10::qint32, vec_cmplt)
216   DEFINE_MEMBER_OP(operator<=, c10::qint32, vec_cmple)
217   DEFINE_MEMBER_OP(operator>, c10::qint32, vec_cmpgt)
218   DEFINE_MEMBER_OP(operator>=, c10::qint32, vec_cmpge)
219   DEFINE_MEMBER_OP(operator+, c10::qint32, vec_add)
220   DEFINE_MEMBER_OP(operator-, c10::qint32, vec_sub)
221   DEFINE_MEMBER_OP(operator*, c10::qint32, vec_mul)
222   DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::qint32, /)
223   DEFINE_MEMBER_OP(maximum, c10::qint32, vec_max)
224   DEFINE_MEMBER_OP(minimum, c10::qint32, vec_min)
225   DEFINE_MEMBER_OP(operator&, c10::qint32, vec_and)
226   DEFINE_MEMBER_OP(operator|, c10::qint32, vec_or)
227   DEFINE_MEMBER_OP(operator^, c10::qint32, vec_xor)
228 };
229 
230 template <>
231 Vectorized<c10::qint32> inline maximum(
232     const Vectorized<c10::qint32>& a,
233     const Vectorized<c10::qint32>& b) {
234   return a.maximum(b);
235 }
236 
237 template <>
238 Vectorized<c10::qint32> inline minimum(
239     const Vectorized<c10::qint32>& a,
240     const Vectorized<c10::qint32>& b) {
241   return a.minimum(b);
242 }
243 
244 template <>
245 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator+(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
246   return Vectorized<c10::qint32>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
247 }
248 
249 template <>
250 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator-(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
251   return Vectorized<c10::qint32>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
252 }
253 
254 template <>
255 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator*(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
256   return Vectorized<c10::qint32>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
257 }
258 
259 template <>
260 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator/(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
261   return Vectorized<c10::qint32>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
262 }
263 
264 template <>
265 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator&(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
266   return Vectorized<c10::qint32>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
267 }
268 
269 template <>
270 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator|(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
271   return Vectorized<c10::qint32>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
272 }
273 
274 template <>
275 Vectorized<c10::qint32> C10_ALWAYS_INLINE operator^(const Vectorized<c10::qint32>& a, const Vectorized<c10::qint32>& b) {
276   return Vectorized<c10::qint32>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
277 }
278 
279 } // namespace
280 } // namespace vec
281 } // namespace at
282