// xref: /aosp_15_r20/external/XNNPACK/test/vcvt-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <limits>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/math.h>
#include <xnnpack/microfnptr.h>
#include <xnnpack/microparams-init.h>


28*4bdc9457SAndroid Build Coastguard Worker class VCvtMicrokernelTester {
29*4bdc9457SAndroid Build Coastguard Worker  public:
batch_size(size_t batch_size)30*4bdc9457SAndroid Build Coastguard Worker   inline VCvtMicrokernelTester& batch_size(size_t batch_size) {
31*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size != 0);
32*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
33*4bdc9457SAndroid Build Coastguard Worker     return *this;
34*4bdc9457SAndroid Build Coastguard Worker   }
35*4bdc9457SAndroid Build Coastguard Worker 
batch_size()36*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
37*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
38*4bdc9457SAndroid Build Coastguard Worker   }
39*4bdc9457SAndroid Build Coastguard Worker 
scale(float scale)40*4bdc9457SAndroid Build Coastguard Worker   inline VCvtMicrokernelTester& scale(float scale) {
41*4bdc9457SAndroid Build Coastguard Worker     assert(scale > 0.0f);
42*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(scale));
43*4bdc9457SAndroid Build Coastguard Worker     this->scale_ = scale;
44*4bdc9457SAndroid Build Coastguard Worker     return *this;
45*4bdc9457SAndroid Build Coastguard Worker   }
46*4bdc9457SAndroid Build Coastguard Worker 
scale()47*4bdc9457SAndroid Build Coastguard Worker   inline float scale() const {
48*4bdc9457SAndroid Build Coastguard Worker     return this->scale_;
49*4bdc9457SAndroid Build Coastguard Worker   }
50*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point(int16_t input_zero_point)51*4bdc9457SAndroid Build Coastguard Worker   inline VCvtMicrokernelTester& input_zero_point(int16_t input_zero_point) {
52*4bdc9457SAndroid Build Coastguard Worker     this->input_zero_point_ = input_zero_point;
53*4bdc9457SAndroid Build Coastguard Worker     return *this;
54*4bdc9457SAndroid Build Coastguard Worker   }
55*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point()56*4bdc9457SAndroid Build Coastguard Worker   inline int16_t input_zero_point() const {
57*4bdc9457SAndroid Build Coastguard Worker     return this->input_zero_point_;
58*4bdc9457SAndroid Build Coastguard Worker   }
59*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point(int16_t output_zero_point)60*4bdc9457SAndroid Build Coastguard Worker   inline VCvtMicrokernelTester& output_zero_point(int16_t output_zero_point) {
61*4bdc9457SAndroid Build Coastguard Worker     this->output_zero_point_ = output_zero_point;
62*4bdc9457SAndroid Build Coastguard Worker     return *this;
63*4bdc9457SAndroid Build Coastguard Worker   }
64*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point()65*4bdc9457SAndroid Build Coastguard Worker   inline int16_t output_zero_point() const {
66*4bdc9457SAndroid Build Coastguard Worker     return this->output_zero_point_;
67*4bdc9457SAndroid Build Coastguard Worker   }
68*4bdc9457SAndroid Build Coastguard Worker 
qmin(int16_t qmin)69*4bdc9457SAndroid Build Coastguard Worker   inline VCvtMicrokernelTester& qmin(int16_t qmin) {
70*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
71*4bdc9457SAndroid Build Coastguard Worker     return *this;
72*4bdc9457SAndroid Build Coastguard Worker   }
73*4bdc9457SAndroid Build Coastguard Worker 
qmin()74*4bdc9457SAndroid Build Coastguard Worker   inline int16_t qmin() const {
75*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
76*4bdc9457SAndroid Build Coastguard Worker   }
77*4bdc9457SAndroid Build Coastguard Worker 
qmax(int16_t qmax)78*4bdc9457SAndroid Build Coastguard Worker   inline VCvtMicrokernelTester& qmax(int16_t qmax) {
79*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
80*4bdc9457SAndroid Build Coastguard Worker     return *this;
81*4bdc9457SAndroid Build Coastguard Worker   }
82*4bdc9457SAndroid Build Coastguard Worker 
qmax()83*4bdc9457SAndroid Build Coastguard Worker   inline int16_t qmax() const {
84*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
85*4bdc9457SAndroid Build Coastguard Worker   }
86*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)87*4bdc9457SAndroid Build Coastguard Worker   inline VCvtMicrokernelTester& iterations(size_t iterations) {
88*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
89*4bdc9457SAndroid Build Coastguard Worker     return *this;
90*4bdc9457SAndroid Build Coastguard Worker   }
91*4bdc9457SAndroid Build Coastguard Worker 
iterations()92*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
93*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
94*4bdc9457SAndroid Build Coastguard Worker   }
95*4bdc9457SAndroid Build Coastguard Worker 
96*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_f32_vcvt_ukernel_function vcvt, xnn_init_f16_f32_cvt_params_fn init_params = nullptr) const {
97*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
98*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
99*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-100.0f, 100.0f);
100*4bdc9457SAndroid Build Coastguard Worker 
101*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
102*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(batch_size());
103*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
104*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
105*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
106*4bdc9457SAndroid Build Coastguard Worker 
107*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_f32_cvt_params params;
108*4bdc9457SAndroid Build Coastguard Worker       if (init_params) {
109*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
110*4bdc9457SAndroid Build Coastguard Worker       }
111*4bdc9457SAndroid Build Coastguard Worker 
112*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
113*4bdc9457SAndroid Build Coastguard Worker       vcvt(batch_size() * sizeof(uint16_t), input.data(), output.data(), &params);
114*4bdc9457SAndroid Build Coastguard Worker 
115*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
116*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
117*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(float_as_uint32(output[i]), float_as_uint32(fp16_ieee_to_fp32_value(input[i])))
118*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
119*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "] = 0x" << std::hex << std::setw(4) << std::setfill('0') << input[i];
120*4bdc9457SAndroid Build Coastguard Worker       }
121*4bdc9457SAndroid Build Coastguard Worker     }
122*4bdc9457SAndroid Build Coastguard Worker   }
123*4bdc9457SAndroid Build Coastguard Worker 
124*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_f16_vcvt_ukernel_function vcvt, xnn_init_f32_f16_cvt_params_fn init_params = nullptr) const {
125*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
126*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
127*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-100.0f, 100.0f);
128*4bdc9457SAndroid Build Coastguard Worker 
129*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
130*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(batch_size());
131*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
132*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
133*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
134*4bdc9457SAndroid Build Coastguard Worker 
135*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_f16_cvt_params params;
136*4bdc9457SAndroid Build Coastguard Worker       if (init_params) {
137*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
138*4bdc9457SAndroid Build Coastguard Worker       }
139*4bdc9457SAndroid Build Coastguard Worker 
140*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
141*4bdc9457SAndroid Build Coastguard Worker       vcvt(batch_size() * sizeof(float), input.data(), output.data(), &params);
142*4bdc9457SAndroid Build Coastguard Worker 
143*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
144*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
145*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(output[i], fp16_ieee_from_fp32_value(input[i]))
146*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
147*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "] = 0x" << std::hex << std::setw(8) << std::setfill('0') << float_as_uint32(input[i])
148*4bdc9457SAndroid Build Coastguard Worker           << " (" << input[i] << ")";
149*4bdc9457SAndroid Build Coastguard Worker       }
150*4bdc9457SAndroid Build Coastguard Worker     }
151*4bdc9457SAndroid Build Coastguard Worker   }
152*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_qs8_vcvt_ukernel_function vcvt,xnn_init_f32_qs8_cvt_params_fn init_params)153*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_qs8_vcvt_ukernel_function vcvt, xnn_init_f32_qs8_cvt_params_fn init_params) const {
154*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(qmin(), std::numeric_limits<int8_t>::min());
155*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(qmax(), std::numeric_limits<int8_t>::max());
156*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LT(qmin(), qmax());
157*4bdc9457SAndroid Build Coastguard Worker 
158*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(output_zero_point(), std::numeric_limits<int8_t>::min());
159*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(output_zero_point(), std::numeric_limits<int8_t>::max());
160*4bdc9457SAndroid Build Coastguard Worker 
161*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
162*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
163*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
164*4bdc9457SAndroid Build Coastguard Worker 
165*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
166*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(batch_size());
167*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output_ref(batch_size());
168*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
169*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
170*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
171*4bdc9457SAndroid Build Coastguard Worker 
172*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_qs8_cvt_params params;
173*4bdc9457SAndroid Build Coastguard Worker       if (init_params) {
174*4bdc9457SAndroid Build Coastguard Worker         init_params(&params, scale(), output_zero_point(), qmin(), qmax());
175*4bdc9457SAndroid Build Coastguard Worker       }
176*4bdc9457SAndroid Build Coastguard Worker 
177*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
178*4bdc9457SAndroid Build Coastguard Worker       vcvt(batch_size() * sizeof(float), input.data(), output.data(), &params);
179*4bdc9457SAndroid Build Coastguard Worker 
180*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results
181*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
182*4bdc9457SAndroid Build Coastguard Worker         float scaled_input = input[i] * scale();
183*4bdc9457SAndroid Build Coastguard Worker         scaled_input = std::min<float>(scaled_input, float(qmax() - output_zero_point()));
184*4bdc9457SAndroid Build Coastguard Worker         scaled_input = std::max<float>(scaled_input, float(qmin() - output_zero_point()));
185*4bdc9457SAndroid Build Coastguard Worker         output_ref[i] = int8_t(std::lrintf(scaled_input) + long(output_zero_point()));
186*4bdc9457SAndroid Build Coastguard Worker       }
187*4bdc9457SAndroid Build Coastguard Worker 
188*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
189*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
190*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(int32_t(output[i]), int32_t(output_ref[i]))
191*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
192*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "] = 0x" << std::hex << std::setw(8) << std::setfill('0') << float_as_uint32(input[i])
193*4bdc9457SAndroid Build Coastguard Worker           << " (" << input[i] << ")";
194*4bdc9457SAndroid Build Coastguard Worker       }
195*4bdc9457SAndroid Build Coastguard Worker     }
196*4bdc9457SAndroid Build Coastguard Worker   }
197*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_qu8_vcvt_ukernel_function vcvt,xnn_init_f32_qu8_cvt_params_fn init_params)198*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_qu8_vcvt_ukernel_function vcvt, xnn_init_f32_qu8_cvt_params_fn init_params) const {
199*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(qmin(), std::numeric_limits<uint8_t>::min());
200*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(qmax(), std::numeric_limits<uint8_t>::max());
201*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LT(qmin(), qmax());
202*4bdc9457SAndroid Build Coastguard Worker 
203*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(output_zero_point(), std::numeric_limits<uint8_t>::min());
204*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(output_zero_point(), std::numeric_limits<uint8_t>::max());
205*4bdc9457SAndroid Build Coastguard Worker 
206*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
207*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
208*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
209*4bdc9457SAndroid Build Coastguard Worker 
210*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
211*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(batch_size());
212*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(batch_size());
213*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
214*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
215*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
216*4bdc9457SAndroid Build Coastguard Worker 
217*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_qu8_cvt_params params;
218*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, scale(), output_zero_point(), qmin(), qmax());
219*4bdc9457SAndroid Build Coastguard Worker 
220*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
221*4bdc9457SAndroid Build Coastguard Worker       vcvt(batch_size() * sizeof(float), input.data(), output.data(), &params);
222*4bdc9457SAndroid Build Coastguard Worker 
223*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results
224*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
225*4bdc9457SAndroid Build Coastguard Worker         float scaled_input = input[i] * scale();
226*4bdc9457SAndroid Build Coastguard Worker         scaled_input = std::min<float>(scaled_input, float(qmax() - output_zero_point()));
227*4bdc9457SAndroid Build Coastguard Worker         scaled_input = std::max<float>(scaled_input, float(qmin() - output_zero_point()));
228*4bdc9457SAndroid Build Coastguard Worker         output_ref[i] = uint8_t(std::lrintf(scaled_input) + long(output_zero_point()));
229*4bdc9457SAndroid Build Coastguard Worker       }
230*4bdc9457SAndroid Build Coastguard Worker 
231*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
232*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
233*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(int32_t(output[i]), int32_t(output_ref[i]))
234*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
235*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "] = 0x" << std::hex << std::setw(8) << std::setfill('0') << float_as_uint32(input[i])
236*4bdc9457SAndroid Build Coastguard Worker           << " (" << input[i] << ")";
237*4bdc9457SAndroid Build Coastguard Worker       }
238*4bdc9457SAndroid Build Coastguard Worker     }
239*4bdc9457SAndroid Build Coastguard Worker   }
240*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qs8_vcvt_ukernel_function vcvt,xnn_init_qs8_cvt_params_fn init_params)241*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_qs8_vcvt_ukernel_function vcvt, xnn_init_qs8_cvt_params_fn init_params) const {
242*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(input_zero_point(), std::numeric_limits<int8_t>::min());
243*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(input_zero_point(), std::numeric_limits<int8_t>::max());
244*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(output_zero_point(), std::numeric_limits<int8_t>::min());
245*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(output_zero_point(), std::numeric_limits<int8_t>::max());
246*4bdc9457SAndroid Build Coastguard Worker 
247*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
248*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
249*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
250*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
251*4bdc9457SAndroid Build Coastguard Worker 
252*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(batch_size() + XNN_EXTRA_BYTES / sizeof(int8_t));
253*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(batch_size());
254*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output_ref(batch_size());
255*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
256*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
257*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
258*4bdc9457SAndroid Build Coastguard Worker 
259*4bdc9457SAndroid Build Coastguard Worker       union xnn_qs8_cvt_params params;
260*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, scale(), input_zero_point(), output_zero_point());
261*4bdc9457SAndroid Build Coastguard Worker 
262*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
263*4bdc9457SAndroid Build Coastguard Worker       vcvt(batch_size() * sizeof(int8_t), input.data(), output.data(), &params);
264*4bdc9457SAndroid Build Coastguard Worker 
265*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results
266*4bdc9457SAndroid Build Coastguard Worker       const int32_t multiplier = (int32_t) lrintf(-256.0f * scale());
267*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
268*4bdc9457SAndroid Build Coastguard Worker         const int32_t input_value = (input_zero_point() - input[i]) << 7;
269*4bdc9457SAndroid Build Coastguard Worker         int32_t output_value = math_asr_s32(input_value * multiplier + INT32_C(0x4000), 15) + output_zero_point();
270*4bdc9457SAndroid Build Coastguard Worker         output_value = std::min<int32_t>(output_value, std::numeric_limits<int8_t>::max());
271*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max<int32_t>(output_value, std::numeric_limits<int8_t>::min());
272*4bdc9457SAndroid Build Coastguard Worker         output_ref[i] = static_cast<int8_t>(output_value);
273*4bdc9457SAndroid Build Coastguard Worker       }
274*4bdc9457SAndroid Build Coastguard Worker 
275*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
276*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
277*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(int32_t(output[i]), int32_t(output_ref[i]))
278*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
279*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "] = " << int32_t(input[i]);
280*4bdc9457SAndroid Build Coastguard Worker       }
281*4bdc9457SAndroid Build Coastguard Worker     }
282*4bdc9457SAndroid Build Coastguard Worker   }
283*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qs8_f32_vcvt_ukernel_function vcvt,xnn_init_qs8_f32_cvt_params_fn init_params)284*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_qs8_f32_vcvt_ukernel_function vcvt, xnn_init_qs8_f32_cvt_params_fn init_params) const {
285*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(input_zero_point(), std::numeric_limits<int8_t>::min());
286*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(input_zero_point(), std::numeric_limits<int8_t>::max());
287*4bdc9457SAndroid Build Coastguard Worker 
288*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
289*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
290*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
291*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
292*4bdc9457SAndroid Build Coastguard Worker 
293*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(batch_size() + XNN_EXTRA_BYTES / sizeof(int8_t));
294*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(batch_size());
295*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size());
296*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
297*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
298*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
299*4bdc9457SAndroid Build Coastguard Worker 
300*4bdc9457SAndroid Build Coastguard Worker       union xnn_qs8_f32_cvt_params params;
301*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, scale(), input_zero_point());
302*4bdc9457SAndroid Build Coastguard Worker 
303*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
304*4bdc9457SAndroid Build Coastguard Worker       vcvt(batch_size() * sizeof(int8_t), input.data(), output.data(), &params);
305*4bdc9457SAndroid Build Coastguard Worker 
306*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results
307*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
308*4bdc9457SAndroid Build Coastguard Worker         output_ref[i] = float(int16_t(input[i]) - input_zero_point()) * scale();
309*4bdc9457SAndroid Build Coastguard Worker       }
310*4bdc9457SAndroid Build Coastguard Worker 
311*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
312*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
313*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(output[i], output_ref[i])
314*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
315*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "] = " << int32_t(input[i]);
316*4bdc9457SAndroid Build Coastguard Worker       }
317*4bdc9457SAndroid Build Coastguard Worker     }
318*4bdc9457SAndroid Build Coastguard Worker   }
319*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qu8_vcvt_ukernel_function vcvt,xnn_init_qu8_cvt_params_fn init_params)320*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_qu8_vcvt_ukernel_function vcvt, xnn_init_qu8_cvt_params_fn init_params) const {
321*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(input_zero_point(), std::numeric_limits<uint8_t>::min());
322*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(input_zero_point(), std::numeric_limits<uint8_t>::max());
323*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(output_zero_point(), std::numeric_limits<uint8_t>::min());
324*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(output_zero_point(), std::numeric_limits<uint8_t>::max());
325*4bdc9457SAndroid Build Coastguard Worker 
326*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
327*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
328*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
329*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
330*4bdc9457SAndroid Build Coastguard Worker 
331*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(batch_size() + XNN_EXTRA_BYTES / sizeof(uint8_t));
332*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(batch_size());
333*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(batch_size());
334*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
335*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
336*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
337*4bdc9457SAndroid Build Coastguard Worker 
338*4bdc9457SAndroid Build Coastguard Worker       union xnn_qu8_cvt_params params;
339*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, scale(), input_zero_point(), output_zero_point());
340*4bdc9457SAndroid Build Coastguard Worker 
341*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
342*4bdc9457SAndroid Build Coastguard Worker       vcvt(batch_size() * sizeof(uint8_t), input.data(), output.data(), &params);
343*4bdc9457SAndroid Build Coastguard Worker 
344*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results
345*4bdc9457SAndroid Build Coastguard Worker       const int32_t multiplier = (int32_t) lrintf(-256.0f * scale());
346*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
347*4bdc9457SAndroid Build Coastguard Worker         const int32_t input_value = (input_zero_point() - input[i]) << 7;
348*4bdc9457SAndroid Build Coastguard Worker         int32_t output_value = math_asr_s32(input_value * multiplier + INT32_C(0x4000), 15) + output_zero_point();
349*4bdc9457SAndroid Build Coastguard Worker         output_value = std::min<int32_t>(output_value, std::numeric_limits<uint8_t>::max());
350*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max<int32_t>(output_value, std::numeric_limits<uint8_t>::min());
351*4bdc9457SAndroid Build Coastguard Worker         output_ref[i] = static_cast<uint8_t>(output_value);
352*4bdc9457SAndroid Build Coastguard Worker       }
353*4bdc9457SAndroid Build Coastguard Worker 
354*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
355*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
356*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(int32_t(output[i]), int32_t(output_ref[i]))
357*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
358*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "] = " << int32_t(input[i]);
359*4bdc9457SAndroid Build Coastguard Worker       }
360*4bdc9457SAndroid Build Coastguard Worker     }
361*4bdc9457SAndroid Build Coastguard Worker   }
362*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qu8_f32_vcvt_ukernel_function vcvt,xnn_init_qu8_f32_cvt_params_fn init_params)363*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_qu8_f32_vcvt_ukernel_function vcvt, xnn_init_qu8_f32_cvt_params_fn init_params) const {
364*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(input_zero_point(), std::numeric_limits<uint8_t>::min());
365*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(input_zero_point(), std::numeric_limits<uint8_t>::max());
366*4bdc9457SAndroid Build Coastguard Worker 
367*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
368*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
369*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
370*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
371*4bdc9457SAndroid Build Coastguard Worker 
372*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(batch_size() + XNN_EXTRA_BYTES / sizeof(uint8_t));
373*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(batch_size());
374*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size());
375*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
376*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
377*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
378*4bdc9457SAndroid Build Coastguard Worker 
379*4bdc9457SAndroid Build Coastguard Worker       union xnn_qu8_f32_cvt_params params;
380*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, scale(), input_zero_point());
381*4bdc9457SAndroid Build Coastguard Worker 
382*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
383*4bdc9457SAndroid Build Coastguard Worker       vcvt(batch_size() * sizeof(uint8_t), input.data(), output.data(), &params);
384*4bdc9457SAndroid Build Coastguard Worker 
385*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results
386*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
387*4bdc9457SAndroid Build Coastguard Worker         output_ref[i] = float(int16_t(input[i]) - input_zero_point()) * scale();
388*4bdc9457SAndroid Build Coastguard Worker       }
389*4bdc9457SAndroid Build Coastguard Worker 
390*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
391*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
392*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(output[i], output_ref[i])
393*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
394*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "] = " << int32_t(input[i]);
395*4bdc9457SAndroid Build Coastguard Worker       }
396*4bdc9457SAndroid Build Coastguard Worker     }
397*4bdc9457SAndroid Build Coastguard Worker   }
398*4bdc9457SAndroid Build Coastguard Worker 
399*4bdc9457SAndroid Build Coastguard Worker  private:
400*4bdc9457SAndroid Build Coastguard Worker   float scale_ = 1.75f;
401*4bdc9457SAndroid Build Coastguard Worker   int16_t input_zero_point_ = 1;
402*4bdc9457SAndroid Build Coastguard Worker   int16_t output_zero_point_ = 5;
403*4bdc9457SAndroid Build Coastguard Worker   int16_t qmin_ = std::numeric_limits<int16_t>::min();
404*4bdc9457SAndroid Build Coastguard Worker   int16_t qmax_ = std::numeric_limits<int16_t>::max();
405*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size_ = 1;
406*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_ = 15;
407*4bdc9457SAndroid Build Coastguard Worker };
408