/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#include <unistd.h>
#include <cmath>
#include <limits>
#include <vector>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

#ifdef __aarch64__
#include <arm_neon.h>
#endif

namespace executorch {
namespace backends {
namespace xnnpack {
namespace utils {

struct QuantizationParams {
  double scale;
  int32_t zero_point;
};

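// Computes quantization parameters (scale, zero_point) that map the observed
// floating-point range [min, max] onto the integer range [qmin, qmax]; the
// boolean flags tweak how that mapping is chosen (e.g. forcing the scale to a
// power of two).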
executorch::runtime::Error ChooseQuantizationParams(
    float min,
    float max,
    int32_t qmin,
    int32_t qmax,
    QuantizationParams& result,
    bool preserve_sparsity,
    bool force_scale_power_of_two,
    bool reduce_range);

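// Round to the nearest integer, honoring the current floating-point rounding
// mode. The non-std overloads below work around older Android toolchains where
// std::nearbyint may not be available in <cmath>.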
#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
template <class T>
inline float Round(const float x) {
  return ::nearbyintf(x);
}
inline double Round(const double x) {
  return ::nearbyint(x);
}
#else
template <class T>
inline T Round(const T x) {
  return std::nearbyint(x);
}
#endif

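// Quantizes a single float value:
// q = clamp(zero_point + round(value * (1 / scale)),
//           numeric_limits<T>::min(), numeric_limits<T>::max()).
// Purely illustrative example: quantize_val<uint8_t>(0.5, 128, 1.0f) yields
// 128 + round(1.0f / 0.5f) = 130.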
template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
  // std::nearbyint rounds to the nearest integer according to the current
  // rounding mode, and the default rounding mode rounds half-way cases to even
  // on the most popular processor architectures such as x86 and ARM. This is
  // typically faster than alternatives like std::round, which rounds half-way
  // cases away from zero, and it can be consistent with SIMD implementations,
  // for example x86's _mm512_cvtps_epi32 or _mm512_round_ps with the
  // _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding
  // mode.
  int64_t qvalue;
  constexpr int64_t qmin = std::numeric_limits<T>::min();
  constexpr int64_t qmax = std::numeric_limits<T>::max();
  float inv_scale = 1.0f / static_cast<float>(scale);
  qvalue = static_cast<int64_t>(zero_point + Round(value * inv_scale));
  qvalue = std::max<int64_t>(qvalue, qmin);
  qvalue = std::min<int64_t>(qvalue, qmax);
  return static_cast<T>(qvalue);
}

#ifdef __aarch64__
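// NEON helpers, specialized per 8-bit output type in the implementation:
// vqmov saturates and narrows a packed int16x8_t to the 8-lane quantized
// vector type (e.g. uint8x8_t or int8x8_t), and vst1 stores that vector to
// memory.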
template <typename Tx8>
Tx8 vqmov(int16x8_t vraw);

template <typename T, typename Tx8>
void vst1(T* out, Tx8 vout);

template <typename underlying_t, typename underlying_x8_t>
void quantize_tensor_arm64_q8(
    const float* __restrict__ in,
    underlying_t* __restrict__ out,
    const int64_t N,
    const float scale,
    const int32_t zero_point) {
  const float inv_scale = 1.0f / scale;
  uint32_t i = 0;
  underlying_t* out_underlying = reinterpret_cast<underlying_t*>(out);
  const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);

  const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t)zero_point);
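  // Vectorized main loop: each iteration loads 8 floats as two float32x4
  // vectors, multiplies by 1/scale, rounds to nearest-even (vcvtnq_s32_f32),
  // narrows to int16 with saturation, adds the zero point, and finally narrows
  // to the 8-bit output type.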
  for (i = 0; i + 8 <= N; i += 8) {
    const float32x4_t vin0123 = vld1q_f32(in);
    in += 4;
    const float32x4_t vin4567 = vld1q_f32(in);
    in += 4;
    const int32x4_t v0123_rounded =
        vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
    const int32x4_t v4567_rounded =
        vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
    const int16x8_t v01234567_packed = vqaddq_s16(
        vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), vzero_point);
    const underlying_x8_t vout01234567 =
        vqmov<underlying_x8_t>(v01234567_packed);
    vst1<underlying_t, underlying_x8_t>(out_underlying, vout01234567);
    out_underlying += 8;
  }
  // Quantize any remaining (< 8) elements with the scalar path.
  for (; i < N; ++i) {
    (*out_underlying++) =
        quantize_val<underlying_t>(scale, zero_point, (*in++));
  }
}

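// Type-dispatching wrapper around quantize_tensor_arm64_q8; declared here and
// defined out of line for the supported 8-bit output types.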
template <typename T>
void quantize_tensor_arm64_q8_wrapper(
    const float* __restrict__ in,
    T* __restrict__ out,
    const int64_t N,
    const float scale,
    const int32_t zero_point);

#endif /* __aarch64__ */

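// Quantizes the float tensor rtensor into qtensor element-wise using the given
// scale and zero_point. On aarch64 this dispatches to the NEON kernel above;
// otherwise it falls back to the scalar quantize_val loop.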
template <typename T = uint8_t>
executorch::runtime::Error QuantizePerTensor(
    const executorch::aten::Tensor& rtensor,
    executorch::aten::Tensor& qtensor,
    double scale,
    int zero_point) {
  const float* rdata = rtensor.const_data_ptr<float>();
  int numel = rtensor.numel();
  ET_CHECK_OR_RETURN_ERROR(
      (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value),
      Internal,
      "Expecting quantized output tensor of dtype uint8_t or int8_t");
  ET_CHECK_OR_RETURN_ERROR(
      rtensor.numel() <= qtensor.numel(),
      Internal,
      "Expecting quantized output tensor of the same or larger size as input, %zd vs. %zd",
      qtensor.numel(),
      rtensor.numel());
  T* qdata = qtensor.mutable_data_ptr<T>();

#if defined(__aarch64__)
  quantize_tensor_arm64_q8_wrapper<T>(rdata, qdata, numel, scale, zero_point);
#else
  for (int i = 0; i < numel; ++i) {
    qdata[i] = quantize_val<T>(scale, zero_point, rdata[i]);
  }
#endif /* __aarch64__ */
  return executorch::runtime::Error::Ok;
}

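// Fills requant_scales with one requantization scale per entry of
// weight_scales, derived from input_scale, output_scale, and the corresponding
// weight scale (conventionally input_scale * weight_scale / output_scale).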
executorch::runtime::Error GenerateRequantizationScale(
    const executorch::aten::Tensor& weight_scales,
    float input_scale,
    float output_scale,
    std::vector<float>& requant_scales);

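// Returns the minimum and maximum element values of the float tensor ft as a
// (min, max) pair.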
std::pair<float, float> GetMinMax(const executorch::aten::Tensor& ft);

} // namespace utils
} // namespace xnnpack
} // namespace backends
} // namespace executorch