xref: /aosp_15_r20/external/executorch/backends/cadence/reference/kernels/kernels.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <executorch/backends/cadence/reference/kernels/kernels.h>
#include <math.h>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <numeric>
15 
16 namespace impl {
17 namespace reference {
18 namespace kernels {
19 
// Quantize a single fp32 value to an integer type T (instantiated in this file
// for int8_t, uint8_t, int16_t, uint16_t and int32_t).
//
// Computes roundf(x * scale + zero_point) and saturates the result to the
// representable range of T. `scale` is multiplied into `x`, i.e. callers pass
// the *inverse* quantization scale (the array overload and the explicit
// instantiations name this parameter `inv_scale`). roundf rounds halfway
// cases away from zero.
//
// NOTE(review): a NaN input is not handled (converting NaN to T is undefined);
// callers presumably pass finite values — confirm.
template <typename T>
T quantize(const float x, float scale, int32_t zero_point) {
  constexpr T kMin = std::numeric_limits<T>::min();
  constexpr T kMax = std::numeric_limits<T>::max();
  const float rounded = roundf(x * scale + zero_point);
  // Compare in float but return the exact integer limits. For wide types the
  // float image of a limit is rounded (int32_t max 2147483647 becomes
  // 2147483648.0f), so clamping to that float and then converting it to T
  // would overflow, which is undefined behavior.
  if (rounded >= static_cast<float>(kMax)) {
    return kMax;
  }
  if (rounded <= static_cast<float>(kMin)) {
    return kMin;
  }
  // Strictly inside (kMin, kMax) and integral after rounding: fits in T.
  return static_cast<T>(rounded);
}
28 
// Quantize `size` fp32 values read from `x` into `y`, one element at a time,
// via the scalar quantize<T>(). `inv_scale` is the reciprocal of the
// quantization scale and is multiplied into each input. `x` and `y` are
// declared __restrict__, so the two buffers must not overlap.
template <typename T>
void quantize(
    T* __restrict__ y,
    const float* __restrict__ x,
    float inv_scale,
    int32_t zero_point,
    size_t size) {
  const float* const last = x + size;
  std::transform(x, last, y, [inv_scale, zero_point](float value) {
    return quantize<T>(value, inv_scale, zero_point);
  });
}
41 
// Dequantize a single integer value of type T back to fp32:
//   scale * (x - zero_point)
// The difference is formed in int64_t, so it cannot overflow for any
// instantiated T. With int32_t inputs, e.g. INT32_MIN - 1, computing it in
// int would be undefined behavior. For narrower types the result is identical
// to a plain int subtraction.
template <typename T>
float dequantize(const T x, float scale, int32_t zero_point) {
  const int64_t centered =
      static_cast<int64_t>(x) - static_cast<int64_t>(zero_point);
  return scale * static_cast<float>(centered);
}
47 
// Dequantize `size` integer values of type T read from `x` into fp32 values
// in `y`, element by element, via the scalar dequantize<T>(). Instantiated in
// this file for int8_t, uint8_t, int16_t, uint16_t and int32_t. `x` and `y`
// are declared __restrict__, so the two buffers must not overlap.
template <typename T>
void dequantize(
    float* __restrict__ y,
    const T* __restrict__ x,
    float scale,
    int32_t zero_point,
    size_t size) {
  const T* const last = x + size;
  std::transform(x, last, y, [scale, zero_point](T value) {
    return dequantize<T>(value, scale, zero_point);
  });
}
60 
61 // explicit template instantiation
62 
63 #define typed_quantize_val(dtype) \
64   template dtype quantize(const float x, float inv_scale, int32_t zero_point);
65 typed_quantize_val(int8_t);
66 typed_quantize_val(uint8_t);
67 typed_quantize_val(int16_t);
68 typed_quantize_val(uint16_t);
69 typed_quantize_val(int32_t);
70 #undef typed_quantize_val
71 
72 #define typed_quantize_vec(dtype)  \
73   template void quantize(          \
74       dtype* __restrict__ y,       \
75       const float* __restrict__ x, \
76       float inv_scale,             \
77       int32_t zero_point,          \
78       size_t size);
79 typed_quantize_vec(int8_t);
80 typed_quantize_vec(uint8_t);
81 typed_quantize_vec(int16_t);
82 typed_quantize_vec(uint16_t);
83 typed_quantize_vec(int32_t);
84 #undef typed_quantize_vec
85 
86 #define typed_dequantize_val(dtype) \
87   template float dequantize(const dtype x, float scale, int32_t zero_point);
88 typed_dequantize_val(int8_t);
89 typed_dequantize_val(uint8_t);
90 typed_dequantize_val(int16_t);
91 typed_dequantize_val(uint16_t);
92 typed_dequantize_val(int32_t);
93 #undef typed_dequantize_val
94 
95 #define typed_dequantize_vec(dtype) \
96   template void dequantize(         \
97       float* __restrict__ y,        \
98       const dtype* __restrict__ x,  \
99       float scale,                  \
100       int32_t zero_point,           \
101       size_t size);
102 typed_dequantize_vec(int8_t);
103 typed_dequantize_vec(uint8_t);
104 typed_dequantize_vec(int16_t);
105 typed_dequantize_vec(uint16_t);
106 typed_dequantize_vec(int32_t);
107 #undef typed_dequantize_vec
108 
109 }; // namespace kernels
110 }; // namespace reference
111 }; // namespace impl
112