1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/backends/cadence/reference/kernels/kernels.h>
#include <math.h>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <numeric>
15
namespace impl {
namespace reference {
namespace kernels {

// Quantize one fp32 value to an integer code of type T.
//
// Computes roundf(x * inv_scale + zero_point) and saturates the result to the
// representable range of T. `inv_scale` is used as a multiplier (the
// reciprocal of the `scale` that dequantize() multiplies by). roundf() rounds
// halfway cases away from zero.
//
// Saturation is done by comparison, not by clamping in float and converting
// back: for 32-bit T, numeric_limits<T>::max() is not exactly representable
// as a float and rounds up to 2^31, so converting that clamped float back to
// T would be an out-of-range (undefined) conversion.
//
// NOTE(review): a NaN result (NaN x or NaN inv_scale) still reaches the final
// float-to-integer conversion, which is undefined; callers are presumably
// expected never to pass NaN — confirm.
template <typename T>
T quantize(const float x, float inv_scale, int32_t zero_point) {
  constexpr float min_val = static_cast<float>(std::numeric_limits<T>::min());
  constexpr float max_val = static_cast<float>(std::numeric_limits<T>::max());
  const float rounded = roundf(x * inv_scale + zero_point);
  if (rounded >= max_val) {
    return std::numeric_limits<T>::max();
  }
  if (rounded <= min_val) {
    return std::numeric_limits<T>::min();
  }
  return static_cast<T>(rounded);
}

// Quantize `size` fp32 values from `x` into `y`, element by element, with
// the scalar quantize() above (same inv_scale / zero_point for every element).
// Both pointers are __restrict__: the two buffers must not overlap.
template <typename T>
void quantize(
    T* __restrict__ y,
    const float* __restrict__ x,
    float inv_scale,
    int32_t zero_point,
    size_t size) {
  for (size_t i = 0; i < size; ++i) {
    y[i] = quantize<T>(x[i], inv_scale, zero_point);
  }
}

// Dequantize one integer code to fp32: scale * (x - zero_point).
//
// The subtraction is done in 64 bits because, for T = int32_t,
// x - zero_point can overflow a 32-bit int (e.g. x = INT32_MAX,
// zero_point = -1), which is undefined behavior. Whenever the 32-bit
// subtraction would not overflow, the result is bit-identical to it.
template <typename T>
float dequantize(const T x, float scale, int32_t zero_point) {
  const int64_t centered =
      static_cast<int64_t>(x) - static_cast<int64_t>(zero_point);
  return scale * static_cast<float>(centered);
}

// Dequantize `size` integer codes from `x` into `y`, element by element,
// with the scalar dequantize() above. Both pointers are __restrict__: the
// two buffers must not overlap.
template <typename T>
void dequantize(
    float* __restrict__ y,
    const T* __restrict__ x,
    float scale,
    int32_t zero_point,
    size_t size) {
  for (size_t i = 0; i < size; ++i) {
    y[i] = dequantize<T>(x[i], scale, zero_point);
  }
}

// Explicit template instantiations for every supported integer code type.
// T of the scalar quantize() appears only in the return type, so each
// instantiation spells the return type to select it.

#define typed_quantize_val(dtype) \
  template dtype quantize(const float x, float inv_scale, int32_t zero_point);
typed_quantize_val(int8_t);
typed_quantize_val(uint8_t);
typed_quantize_val(int16_t);
typed_quantize_val(uint16_t);
typed_quantize_val(int32_t);
#undef typed_quantize_val

#define typed_quantize_vec(dtype)  \
  template void quantize(          \
      dtype* __restrict__ y,       \
      const float* __restrict__ x, \
      float inv_scale,             \
      int32_t zero_point,          \
      size_t size);
typed_quantize_vec(int8_t);
typed_quantize_vec(uint8_t);
typed_quantize_vec(int16_t);
typed_quantize_vec(uint16_t);
typed_quantize_vec(int32_t);
#undef typed_quantize_vec

#define typed_dequantize_val(dtype) \
  template float dequantize(const dtype x, float scale, int32_t zero_point);
typed_dequantize_val(int8_t);
typed_dequantize_val(uint8_t);
typed_dequantize_val(int16_t);
typed_dequantize_val(uint16_t);
typed_dequantize_val(int32_t);
#undef typed_dequantize_val

#define typed_dequantize_vec(dtype)  \
  template void dequantize(          \
      float* __restrict__ y,         \
      const dtype* __restrict__ x,   \
      float scale,                   \
      int32_t zero_point,            \
      size_t size);
typed_dequantize_vec(int8_t);
typed_dequantize_vec(uint8_t);
typed_dequantize_vec(int16_t);
typed_dequantize_vec(uint16_t);
typed_dequantize_vec(int32_t);
#undef typed_dequantize_vec

} // namespace kernels
} // namespace reference
} // namespace impl
112