xref: /aosp_15_r20/external/XNNPACK/test/dwconv-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #pragma once
10*4bdc9457SAndroid Build Coastguard Worker 
11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
14*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
15*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
18*4bdc9457SAndroid Build Coastguard Worker #include <limits>
19*4bdc9457SAndroid Build Coastguard Worker #include <random>
20*4bdc9457SAndroid Build Coastguard Worker #include <vector>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h>
26*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/pack.h>
27*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
28*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
29*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/requantization.h>
30*4bdc9457SAndroid Build Coastguard Worker 
31*4bdc9457SAndroid Build Coastguard Worker 
32*4bdc9457SAndroid Build Coastguard Worker class DWConvMicrokernelTester {
33*4bdc9457SAndroid Build Coastguard Worker  public:
width(uint32_t width)34*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& width(uint32_t width) {
35*4bdc9457SAndroid Build Coastguard Worker     assert(width >= 1);
36*4bdc9457SAndroid Build Coastguard Worker     this->width_ = width;
37*4bdc9457SAndroid Build Coastguard Worker     return *this;
38*4bdc9457SAndroid Build Coastguard Worker   }
39*4bdc9457SAndroid Build Coastguard Worker 
width()40*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t width() const {
41*4bdc9457SAndroid Build Coastguard Worker     return this->width_;
42*4bdc9457SAndroid Build Coastguard Worker   }
43*4bdc9457SAndroid Build Coastguard Worker 
step(uint32_t step)44*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& step(uint32_t step) {
45*4bdc9457SAndroid Build Coastguard Worker     assert(step >= 1);
46*4bdc9457SAndroid Build Coastguard Worker     this->step_ = step;
47*4bdc9457SAndroid Build Coastguard Worker     return *this;
48*4bdc9457SAndroid Build Coastguard Worker   }
49*4bdc9457SAndroid Build Coastguard Worker 
step()50*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t step() const {
51*4bdc9457SAndroid Build Coastguard Worker     return this->step_;
52*4bdc9457SAndroid Build Coastguard Worker   }
53*4bdc9457SAndroid Build Coastguard Worker 
channels(uint32_t channels)54*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& channels(uint32_t channels) {
55*4bdc9457SAndroid Build Coastguard Worker     assert(channels >= 1);
56*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
57*4bdc9457SAndroid Build Coastguard Worker     return *this;
58*4bdc9457SAndroid Build Coastguard Worker   }
59*4bdc9457SAndroid Build Coastguard Worker 
channels()60*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t channels() const {
61*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
62*4bdc9457SAndroid Build Coastguard Worker   }
63*4bdc9457SAndroid Build Coastguard Worker 
cr(uint32_t cr)64*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& cr(uint32_t cr) {
65*4bdc9457SAndroid Build Coastguard Worker     assert(cr != 0);
66*4bdc9457SAndroid Build Coastguard Worker     this->cr_ = cr;
67*4bdc9457SAndroid Build Coastguard Worker     return *this;
68*4bdc9457SAndroid Build Coastguard Worker   }
69*4bdc9457SAndroid Build Coastguard Worker 
cr()70*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t cr() const {
71*4bdc9457SAndroid Build Coastguard Worker     return this->cr_;
72*4bdc9457SAndroid Build Coastguard Worker   }
73*4bdc9457SAndroid Build Coastguard Worker 
kr(uint32_t kr)74*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& kr(uint32_t kr) {
75*4bdc9457SAndroid Build Coastguard Worker     assert(kr != 0);
76*4bdc9457SAndroid Build Coastguard Worker     this->kr_ = kr;
77*4bdc9457SAndroid Build Coastguard Worker     return *this;
78*4bdc9457SAndroid Build Coastguard Worker   }
79*4bdc9457SAndroid Build Coastguard Worker 
kr()80*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t kr() const {
81*4bdc9457SAndroid Build Coastguard Worker     return this->kr_;
82*4bdc9457SAndroid Build Coastguard Worker   }
83*4bdc9457SAndroid Build Coastguard Worker 
packed_channels()84*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t packed_channels() const {
85*4bdc9457SAndroid Build Coastguard Worker     return (channels() / cr() + !!(channels() % cr())) * cr();
86*4bdc9457SAndroid Build Coastguard Worker   }
87*4bdc9457SAndroid Build Coastguard Worker 
output_stride(uint32_t output_stride)88*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& output_stride(uint32_t output_stride) {
89*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride != 0);
90*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
91*4bdc9457SAndroid Build Coastguard Worker     return *this;
92*4bdc9457SAndroid Build Coastguard Worker   }
93*4bdc9457SAndroid Build Coastguard Worker 
output_stride()94*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t output_stride() const {
95*4bdc9457SAndroid Build Coastguard Worker     if (this->output_stride_ == 0) {
96*4bdc9457SAndroid Build Coastguard Worker       return channels();
97*4bdc9457SAndroid Build Coastguard Worker     } else {
98*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_stride_ >= channels());
99*4bdc9457SAndroid Build Coastguard Worker       return this->output_stride_;
100*4bdc9457SAndroid Build Coastguard Worker     }
101*4bdc9457SAndroid Build Coastguard Worker   }
102*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point(uint8_t input_zero_point)103*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& input_zero_point(uint8_t input_zero_point) {
104*4bdc9457SAndroid Build Coastguard Worker     this->input_zero_point_ = input_zero_point;
105*4bdc9457SAndroid Build Coastguard Worker     return *this;
106*4bdc9457SAndroid Build Coastguard Worker   }
107*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point()108*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t input_zero_point() const {
109*4bdc9457SAndroid Build Coastguard Worker     return this->input_zero_point_;
110*4bdc9457SAndroid Build Coastguard Worker   }
111*4bdc9457SAndroid Build Coastguard Worker 
kernel_zero_point(uint8_t kernel_zero_point)112*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& kernel_zero_point(uint8_t kernel_zero_point) {
113*4bdc9457SAndroid Build Coastguard Worker     this->kernel_zero_point_ = kernel_zero_point;
114*4bdc9457SAndroid Build Coastguard Worker     return *this;
115*4bdc9457SAndroid Build Coastguard Worker   }
116*4bdc9457SAndroid Build Coastguard Worker 
kernel_zero_point()117*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t kernel_zero_point() const {
118*4bdc9457SAndroid Build Coastguard Worker     return this->kernel_zero_point_;
119*4bdc9457SAndroid Build Coastguard Worker   }
120*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)121*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& qmin(uint8_t qmin) {
122*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
123*4bdc9457SAndroid Build Coastguard Worker     return *this;
124*4bdc9457SAndroid Build Coastguard Worker   }
125*4bdc9457SAndroid Build Coastguard Worker 
qmin()126*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
127*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
128*4bdc9457SAndroid Build Coastguard Worker   }
129*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)130*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& qmax(uint8_t qmax) {
131*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
132*4bdc9457SAndroid Build Coastguard Worker     return *this;
133*4bdc9457SAndroid Build Coastguard Worker   }
134*4bdc9457SAndroid Build Coastguard Worker 
qmax()135*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
136*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
137*4bdc9457SAndroid Build Coastguard Worker   }
138*4bdc9457SAndroid Build Coastguard Worker 
input_offset(size_t input_offset)139*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& input_offset(size_t input_offset) {
140*4bdc9457SAndroid Build Coastguard Worker     this->input_offset_ = input_offset;
141*4bdc9457SAndroid Build Coastguard Worker     return *this;
142*4bdc9457SAndroid Build Coastguard Worker   }
143*4bdc9457SAndroid Build Coastguard Worker 
input_offset()144*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_offset() const {
145*4bdc9457SAndroid Build Coastguard Worker     return this->input_offset_;
146*4bdc9457SAndroid Build Coastguard Worker   }
147*4bdc9457SAndroid Build Coastguard Worker 
zero_index(size_t zero_index)148*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& zero_index(size_t zero_index) {
149*4bdc9457SAndroid Build Coastguard Worker     this->zero_index_ = zero_index;
150*4bdc9457SAndroid Build Coastguard Worker     return *this;
151*4bdc9457SAndroid Build Coastguard Worker   }
152*4bdc9457SAndroid Build Coastguard Worker 
zero_index()153*4bdc9457SAndroid Build Coastguard Worker   inline size_t zero_index() const {
154*4bdc9457SAndroid Build Coastguard Worker     return this->zero_index_;
155*4bdc9457SAndroid Build Coastguard Worker   }
156*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)157*4bdc9457SAndroid Build Coastguard Worker   inline DWConvMicrokernelTester& iterations(size_t iterations) {
158*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
159*4bdc9457SAndroid Build Coastguard Worker     return *this;
160*4bdc9457SAndroid Build Coastguard Worker   }
161*4bdc9457SAndroid Build Coastguard Worker 
iterations()162*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
163*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
164*4bdc9457SAndroid Build Coastguard Worker   }
165*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qu8_dwconv_minmax_unipass_ukernel_function dwconv_minmax,xnn_init_qu8_conv_minmax_params_fn init_params,xnn_qu8_requantize_fn requantize)166*4bdc9457SAndroid Build Coastguard Worker   void Test(
167*4bdc9457SAndroid Build Coastguard Worker     xnn_qu8_dwconv_minmax_unipass_ukernel_function dwconv_minmax,
168*4bdc9457SAndroid Build Coastguard Worker     xnn_init_qu8_conv_minmax_params_fn init_params,
169*4bdc9457SAndroid Build Coastguard Worker     xnn_qu8_requantize_fn requantize) const
170*4bdc9457SAndroid Build Coastguard Worker   {
171*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
172*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
173*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
174*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
175*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
176*4bdc9457SAndroid Build Coastguard Worker 
177*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint8_t*> indirection((width() - 1) * step() + kr());
178*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + indirection.size() * channels());
179*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> kernel(channels() * kr());
180*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(channels());
181*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_weights((kr() + sizeof(int32_t) / sizeof(uint8_t)) * packed_channels());
182*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
183*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output((width() - 1) * output_stride() + channels());
184*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(width() * channels());
185*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(width() * channels());
186*4bdc9457SAndroid Build Coastguard Worker 
187*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
188*4bdc9457SAndroid Build Coastguard Worker       do {
189*4bdc9457SAndroid Build Coastguard Worker         std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
190*4bdc9457SAndroid Build Coastguard Worker       } while (input.size() > 1 && *std::max_element(input.cbegin(), input.cend()) == *std::min_element(input.cbegin(), input.cend()));
191*4bdc9457SAndroid Build Coastguard Worker       do {
192*4bdc9457SAndroid Build Coastguard Worker         std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
193*4bdc9457SAndroid Build Coastguard Worker       } while (kernel.size() > 1 && *std::max_element(kernel.cbegin(), kernel.cend()) == *std::min_element(kernel.cbegin(), kernel.cend()));
194*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
195*4bdc9457SAndroid Build Coastguard Worker       std::fill(zero.begin(), zero.end(), input_zero_point());
196*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
197*4bdc9457SAndroid Build Coastguard Worker 
198*4bdc9457SAndroid Build Coastguard Worker       std::fill(packed_weights.begin(), packed_weights.end(), 0);
199*4bdc9457SAndroid Build Coastguard Worker       const xnn_qu8_packing_params packing_params = { input_zero_point(), kernel_zero_point() };
200*4bdc9457SAndroid Build Coastguard Worker       xnn_pack_qu8_dwconv_ghw_w(
201*4bdc9457SAndroid Build Coastguard Worker         kr(), kr(), 1, channels(), cr(),
202*4bdc9457SAndroid Build Coastguard Worker         kernel.data(), bias.data(), packed_weights.data(),
203*4bdc9457SAndroid Build Coastguard Worker         0 /* extra bytes */, &packing_params);
204*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
205*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
206*4bdc9457SAndroid Build Coastguard Worker       }
207*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
208*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
209*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < indirection.size(); i += kr()) {
210*4bdc9457SAndroid Build Coastguard Worker           indirection[i + zero_index()] = zero.data();
211*4bdc9457SAndroid Build Coastguard Worker         }
212*4bdc9457SAndroid Build Coastguard Worker       }
213*4bdc9457SAndroid Build Coastguard Worker 
214*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
215*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
216*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
217*4bdc9457SAndroid Build Coastguard Worker           float acc = bias[c];
218*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < kr(); k++) {
219*4bdc9457SAndroid Build Coastguard Worker             if (indirection[x * step() + k] != zero.data()) {
220*4bdc9457SAndroid Build Coastguard Worker               acc +=
221*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(indirection[x * step() + k][c + input_offset()]) - int32_t(input_zero_point())) *
222*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(kernel[c * kr() + k]) - int32_t(kernel_zero_point()));
223*4bdc9457SAndroid Build Coastguard Worker             }
224*4bdc9457SAndroid Build Coastguard Worker           }
225*4bdc9457SAndroid Build Coastguard Worker           accumulators[x * channels() + c] = acc;
226*4bdc9457SAndroid Build Coastguard Worker         }
227*4bdc9457SAndroid Build Coastguard Worker       }
228*4bdc9457SAndroid Build Coastguard Worker 
229*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
230*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
231*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
232*4bdc9457SAndroid Build Coastguard Worker       const uint32_t accumulated_range = uint32_t(accumulated_max) - uint32_t(accumulated_min);
233*4bdc9457SAndroid Build Coastguard Worker       const double output_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
234*4bdc9457SAndroid Build Coastguard Worker       const uint8_t output_zero_point = uint8_t(std::max(std::min(
235*4bdc9457SAndroid Build Coastguard Worker         lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
236*4bdc9457SAndroid Build Coastguard Worker         long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
237*4bdc9457SAndroid Build Coastguard Worker 
238*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
239*4bdc9457SAndroid Build Coastguard Worker       const float requantization_scale = 1.0f / float(output_scale);
240*4bdc9457SAndroid Build Coastguard Worker       union xnn_qu8_conv_minmax_params quantization_params;
241*4bdc9457SAndroid Build Coastguard Worker       init_params(&quantization_params,
242*4bdc9457SAndroid Build Coastguard Worker         kernel_zero_point(), requantization_scale, output_zero_point, qmin(), qmax());
243*4bdc9457SAndroid Build Coastguard Worker 
244*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
245*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
246*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
247*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = requantize(
248*4bdc9457SAndroid Build Coastguard Worker             accumulators[x * channels() + c], requantization_scale, output_zero_point, qmin(), qmax());
249*4bdc9457SAndroid Build Coastguard Worker         }
250*4bdc9457SAndroid Build Coastguard Worker       }
251*4bdc9457SAndroid Build Coastguard Worker 
252*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
253*4bdc9457SAndroid Build Coastguard Worker       dwconv_minmax(
254*4bdc9457SAndroid Build Coastguard Worker         channels(), width(),
255*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), packed_weights.data(), output.data(),
256*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
257*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint8_t),
258*4bdc9457SAndroid Build Coastguard Worker         input_offset() * sizeof(uint8_t), zero.data(),
259*4bdc9457SAndroid Build Coastguard Worker         &quantization_params);
260*4bdc9457SAndroid Build Coastguard Worker 
261*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
262*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
263*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
264*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(uint32_t(output[x * output_stride() + c]), uint32_t(qmin()))
265*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
266*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(uint32_t(output[x * output_stride() + c]), uint32_t(qmax()))
267*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
268*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(uint32_t(output[x * output_stride() + c]), uint32_t(output_ref[x * channels() + c]))
269*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c << ", accumulator = " << accumulators[x * channels() + c];
270*4bdc9457SAndroid Build Coastguard Worker         }
271*4bdc9457SAndroid Build Coastguard Worker       }
272*4bdc9457SAndroid Build Coastguard Worker     }
273*4bdc9457SAndroid Build Coastguard Worker   }
274*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qc8_dwconv_minmax_unipass_ukernel_function dwconv_minmax,xnn_init_qc8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize)275*4bdc9457SAndroid Build Coastguard Worker   void Test(
276*4bdc9457SAndroid Build Coastguard Worker     xnn_qc8_dwconv_minmax_unipass_ukernel_function dwconv_minmax,
277*4bdc9457SAndroid Build Coastguard Worker     xnn_init_qc8_conv_minmax_params_fn init_params,
278*4bdc9457SAndroid Build Coastguard Worker     xnn_qs8_requantize_fn requantize) const
279*4bdc9457SAndroid Build Coastguard Worker   {
280*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
281*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
282*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
283*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
284*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
285*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> w8dist(
286*4bdc9457SAndroid Build Coastguard Worker       -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
287*4bdc9457SAndroid Build Coastguard Worker 
288*4bdc9457SAndroid Build Coastguard Worker     std::vector<const int8_t*> indirection((width() - 1) * step() + kr());
289*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + indirection.size() * channels());
290*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> kernel(channels() * kr());
291*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(channels());
292*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_weights((kr() + (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t)) * packed_channels());
293*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
294*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output((width() - 1) * output_stride() + channels());
295*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(width() * channels());
296*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> scale(channels());
297*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output_ref(width() * channels());
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
300*4bdc9457SAndroid Build Coastguard Worker       do {
301*4bdc9457SAndroid Build Coastguard Worker         std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
302*4bdc9457SAndroid Build Coastguard Worker       } while (input.size() > 1 && *std::max_element(input.cbegin(), input.cend()) == *std::min_element(input.cbegin(), input.cend()));
303*4bdc9457SAndroid Build Coastguard Worker       do {
304*4bdc9457SAndroid Build Coastguard Worker         std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
305*4bdc9457SAndroid Build Coastguard Worker       } while (kernel.size() > 1 && *std::max_element(kernel.cbegin(), kernel.cend()) == *std::min_element(kernel.cbegin(), kernel.cend()));
306*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
307*4bdc9457SAndroid Build Coastguard Worker       std::fill(zero.begin(), zero.end(), int8_t(input_zero_point() - 0x80));
308*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
309*4bdc9457SAndroid Build Coastguard Worker 
310*4bdc9457SAndroid Build Coastguard Worker       std::fill(packed_weights.begin(), packed_weights.end(), 0);
311*4bdc9457SAndroid Build Coastguard Worker       const xnn_qs8_packing_params packing_params = { int8_t(input_zero_point() - 0x80) };
312*4bdc9457SAndroid Build Coastguard Worker       xnn_pack_qs8_dwconv_ghw_w(
313*4bdc9457SAndroid Build Coastguard Worker         kr(), kr(), 1, channels(), cr(),
314*4bdc9457SAndroid Build Coastguard Worker         kernel.data(), bias.data(), packed_weights.data(), cr() * sizeof(float),
315*4bdc9457SAndroid Build Coastguard Worker         &packing_params);
316*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
317*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
318*4bdc9457SAndroid Build Coastguard Worker       }
319*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
320*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
321*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < indirection.size(); i += kr()) {
322*4bdc9457SAndroid Build Coastguard Worker           indirection[i + zero_index()] = zero.data();
323*4bdc9457SAndroid Build Coastguard Worker         }
324*4bdc9457SAndroid Build Coastguard Worker       }
325*4bdc9457SAndroid Build Coastguard Worker 
326*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
327*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
328*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
329*4bdc9457SAndroid Build Coastguard Worker           float acc = bias[c];
330*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < kr(); k++) {
331*4bdc9457SAndroid Build Coastguard Worker             if (indirection[x * step() + k] != zero.data()) {
332*4bdc9457SAndroid Build Coastguard Worker               acc +=
333*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(indirection[x * step() + k][c + input_offset()]) - int32_t(input_zero_point() - 0x80)) *
334*4bdc9457SAndroid Build Coastguard Worker                 int32_t(kernel[c * kr() + k]);
335*4bdc9457SAndroid Build Coastguard Worker             }
336*4bdc9457SAndroid Build Coastguard Worker           }
337*4bdc9457SAndroid Build Coastguard Worker           accumulators[x * channels() + c] = acc;
338*4bdc9457SAndroid Build Coastguard Worker         }
339*4bdc9457SAndroid Build Coastguard Worker       }
340*4bdc9457SAndroid Build Coastguard Worker 
341*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
342*4bdc9457SAndroid Build Coastguard Worker       const int8_t output_zero_point = -1;
343*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
344*4bdc9457SAndroid Build Coastguard Worker         int32_t accumulated_min = accumulators[c];
345*4bdc9457SAndroid Build Coastguard Worker         int32_t accumulated_max = accumulators[c];
346*4bdc9457SAndroid Build Coastguard Worker         for (size_t x = 0; x < width(); x++) {
347*4bdc9457SAndroid Build Coastguard Worker           accumulated_min = std::min(accumulated_min, accumulators[x * channels() + c]);
348*4bdc9457SAndroid Build Coastguard Worker           accumulated_max = std::max(accumulated_max, accumulators[x * channels() + c]);
349*4bdc9457SAndroid Build Coastguard Worker         }
350*4bdc9457SAndroid Build Coastguard Worker         const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
351*4bdc9457SAndroid Build Coastguard Worker         const float output_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
352*4bdc9457SAndroid Build Coastguard Worker         scale[c] = 1.0f / output_scale;
353*4bdc9457SAndroid Build Coastguard Worker       }
354*4bdc9457SAndroid Build Coastguard Worker       xnn_init_qc8_scale_fp32_params(
355*4bdc9457SAndroid Build Coastguard Worker         channels(), cr(),
356*4bdc9457SAndroid Build Coastguard Worker         cr() * (kr() * sizeof(int8_t) + sizeof(int32_t) + sizeof(float)), scale.data(),
357*4bdc9457SAndroid Build Coastguard Worker         (void*) ((uintptr_t) packed_weights.data() + cr() * (kr() * sizeof(int8_t) + sizeof(int32_t))));
358*4bdc9457SAndroid Build Coastguard Worker 
359*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
360*4bdc9457SAndroid Build Coastguard Worker       union xnn_qc8_conv_minmax_params minmax_params;
361*4bdc9457SAndroid Build Coastguard Worker       init_params(&minmax_params,
362*4bdc9457SAndroid Build Coastguard Worker         output_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
363*4bdc9457SAndroid Build Coastguard Worker 
364*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
365*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
366*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
367*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = requantize(
368*4bdc9457SAndroid Build Coastguard Worker             accumulators[x * channels() + c], scale[c], output_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
369*4bdc9457SAndroid Build Coastguard Worker         }
370*4bdc9457SAndroid Build Coastguard Worker       }
371*4bdc9457SAndroid Build Coastguard Worker 
372*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
373*4bdc9457SAndroid Build Coastguard Worker       dwconv_minmax(
374*4bdc9457SAndroid Build Coastguard Worker         channels(), width(),
375*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), packed_weights.data(), output.data(),
376*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
377*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(int8_t),
378*4bdc9457SAndroid Build Coastguard Worker         input_offset() * sizeof(int8_t), zero.data(),
379*4bdc9457SAndroid Build Coastguard Worker         &minmax_params);
380*4bdc9457SAndroid Build Coastguard Worker 
381*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
382*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
383*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
384*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(int32_t(output[x * output_stride() + c]), int32_t(qmin()) - 0x80)
385*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
386*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(int32_t(output[x * output_stride() + c]), int32_t(qmax()) - 0x80)
387*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
388*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(int32_t(output[x * output_stride() + c]), int32_t(output_ref[x * channels() + c]))
389*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c << ", accumulator = " << accumulators[x * channels() + c];
390*4bdc9457SAndroid Build Coastguard Worker         }
391*4bdc9457SAndroid Build Coastguard Worker       }
392*4bdc9457SAndroid Build Coastguard Worker     }
393*4bdc9457SAndroid Build Coastguard Worker   }
394*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qs8_dwconv_minmax_unipass_ukernel_function dwconv_minmax,xnn_init_qs8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize)395*4bdc9457SAndroid Build Coastguard Worker   void Test(
396*4bdc9457SAndroid Build Coastguard Worker     xnn_qs8_dwconv_minmax_unipass_ukernel_function dwconv_minmax,
397*4bdc9457SAndroid Build Coastguard Worker     xnn_init_qs8_conv_minmax_params_fn init_params,
398*4bdc9457SAndroid Build Coastguard Worker     xnn_qs8_requantize_fn requantize) const
399*4bdc9457SAndroid Build Coastguard Worker   {
400*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
401*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
402*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
403*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
404*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
405*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> w8dist(
406*4bdc9457SAndroid Build Coastguard Worker       -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
407*4bdc9457SAndroid Build Coastguard Worker 
408*4bdc9457SAndroid Build Coastguard Worker     std::vector<const int8_t*> indirection((width() - 1) * step() + kr());
409*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + indirection.size() * channels());
410*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> kernel(channels() * kr());
411*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(channels());
412*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_weights((kr() + sizeof(int32_t) / sizeof(int8_t)) * packed_channels());
413*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
414*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output((width() - 1) * output_stride() + channels());
415*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(width() * channels());
416*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output_ref(width() * channels());
417*4bdc9457SAndroid Build Coastguard Worker 
418*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
419*4bdc9457SAndroid Build Coastguard Worker       do {
420*4bdc9457SAndroid Build Coastguard Worker         std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
421*4bdc9457SAndroid Build Coastguard Worker       } while (input.size() > 1 && *std::max_element(input.cbegin(), input.cend()) == *std::min_element(input.cbegin(), input.cend()));
422*4bdc9457SAndroid Build Coastguard Worker       do {
423*4bdc9457SAndroid Build Coastguard Worker         std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
424*4bdc9457SAndroid Build Coastguard Worker       } while (kernel.size() > 1 && *std::max_element(kernel.cbegin(), kernel.cend()) == *std::min_element(kernel.cbegin(), kernel.cend()));
425*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
426*4bdc9457SAndroid Build Coastguard Worker       std::fill(zero.begin(), zero.end(), int8_t(input_zero_point() - 0x80));
427*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
428*4bdc9457SAndroid Build Coastguard Worker 
429*4bdc9457SAndroid Build Coastguard Worker       std::fill(packed_weights.begin(), packed_weights.end(), 0);
430*4bdc9457SAndroid Build Coastguard Worker       const xnn_qs8_packing_params packing_params = { int8_t(input_zero_point() - 0x80) };
431*4bdc9457SAndroid Build Coastguard Worker       xnn_pack_qs8_dwconv_ghw_w(
432*4bdc9457SAndroid Build Coastguard Worker         kr(), kr(), 1, channels(), cr(),
433*4bdc9457SAndroid Build Coastguard Worker         kernel.data(), bias.data(), packed_weights.data(),
434*4bdc9457SAndroid Build Coastguard Worker         0 /* extra bytes */, &packing_params);
435*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
436*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
437*4bdc9457SAndroid Build Coastguard Worker       }
438*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
439*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
440*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < indirection.size(); i += kr()) {
441*4bdc9457SAndroid Build Coastguard Worker           indirection[i + zero_index()] = zero.data();
442*4bdc9457SAndroid Build Coastguard Worker         }
443*4bdc9457SAndroid Build Coastguard Worker       }
444*4bdc9457SAndroid Build Coastguard Worker 
445*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
446*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
447*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
448*4bdc9457SAndroid Build Coastguard Worker           float acc = bias[c];
449*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < kr(); k++) {
450*4bdc9457SAndroid Build Coastguard Worker             if (indirection[x * step() + k] != zero.data()) {
451*4bdc9457SAndroid Build Coastguard Worker               acc +=
452*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(indirection[x * step() + k][c + input_offset()]) - int32_t(input_zero_point() - 0x80)) *
453*4bdc9457SAndroid Build Coastguard Worker                 int32_t(kernel[c * kr() + k]);
454*4bdc9457SAndroid Build Coastguard Worker             }
455*4bdc9457SAndroid Build Coastguard Worker           }
456*4bdc9457SAndroid Build Coastguard Worker           accumulators[x * channels() + c] = acc;
457*4bdc9457SAndroid Build Coastguard Worker         }
458*4bdc9457SAndroid Build Coastguard Worker       }
459*4bdc9457SAndroid Build Coastguard Worker 
460*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
461*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
462*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
463*4bdc9457SAndroid Build Coastguard Worker       const uint32_t accumulated_range = uint32_t(accumulated_max) - uint32_t(accumulated_min);
464*4bdc9457SAndroid Build Coastguard Worker       const double output_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
465*4bdc9457SAndroid Build Coastguard Worker       const int8_t output_zero_point = int8_t(std::max(std::min(
466*4bdc9457SAndroid Build Coastguard Worker         lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
467*4bdc9457SAndroid Build Coastguard Worker         long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
468*4bdc9457SAndroid Build Coastguard Worker 
469*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
470*4bdc9457SAndroid Build Coastguard Worker       const float requantization_scale = 1.0f / float(output_scale);
471*4bdc9457SAndroid Build Coastguard Worker       union xnn_qs8_conv_minmax_params quantization_params;
472*4bdc9457SAndroid Build Coastguard Worker       init_params(&quantization_params,
473*4bdc9457SAndroid Build Coastguard Worker         requantization_scale, output_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
474*4bdc9457SAndroid Build Coastguard Worker 
475*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
476*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
477*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
478*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = requantize(
479*4bdc9457SAndroid Build Coastguard Worker             accumulators[x * channels() + c], requantization_scale, output_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
480*4bdc9457SAndroid Build Coastguard Worker         }
481*4bdc9457SAndroid Build Coastguard Worker       }
482*4bdc9457SAndroid Build Coastguard Worker 
483*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
484*4bdc9457SAndroid Build Coastguard Worker       dwconv_minmax(
485*4bdc9457SAndroid Build Coastguard Worker         channels(), width(),
486*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), packed_weights.data(), output.data(),
487*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
488*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(int8_t),
489*4bdc9457SAndroid Build Coastguard Worker         input_offset() * sizeof(int8_t), zero.data(),
490*4bdc9457SAndroid Build Coastguard Worker         &quantization_params);
491*4bdc9457SAndroid Build Coastguard Worker 
492*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
493*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
494*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
495*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(int32_t(output[x * output_stride() + c]), int32_t(qmin()) - 0x80)
496*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
497*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(int32_t(output[x * output_stride() + c]), int32_t(qmax()) - 0x80)
498*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
499*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(int32_t(output[x * output_stride() + c]), int32_t(output_ref[x * channels() + c]))
500*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c << ", accumulator = " << accumulators[x * channels() + c];
501*4bdc9457SAndroid Build Coastguard Worker         }
502*4bdc9457SAndroid Build Coastguard Worker       }
503*4bdc9457SAndroid Build Coastguard Worker     }
504*4bdc9457SAndroid Build Coastguard Worker   }
505*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_dwconv_minmax_unipass_ukernel_function dwconv_minmax,xnn_init_f16_minmax_params_fn init_params)506*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_dwconv_minmax_unipass_ukernel_function dwconv_minmax, xnn_init_f16_minmax_params_fn init_params) const {
507*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
508*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
509*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
510*4bdc9457SAndroid Build Coastguard Worker 
511*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint16_t*> indirection((width() - 1) * step() + kr());
512*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + indirection.size() * channels());
513*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> kernel(channels() * kr());
514*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> bias(channels());
515*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_weights((kr() + 1) * packed_channels());
516*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
517*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output((width() - 1) * output_stride() + channels());
518*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(width() * channels());
519*4bdc9457SAndroid Build Coastguard Worker 
520*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
521*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
522*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
523*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
524*4bdc9457SAndroid Build Coastguard Worker       std::fill(zero.begin(), zero.end(), 0);
525*4bdc9457SAndroid Build Coastguard Worker       std::fill(output_ref.begin(), output_ref.end(), 0.0f);
526*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
527*4bdc9457SAndroid Build Coastguard Worker 
528*4bdc9457SAndroid Build Coastguard Worker       std::fill(packed_weights.begin(), packed_weights.end(), 0);
529*4bdc9457SAndroid Build Coastguard Worker       xnn_pack_f16_dwconv_ghw_w(
530*4bdc9457SAndroid Build Coastguard Worker         kr(), kr(), 1, channels(), cr(),
531*4bdc9457SAndroid Build Coastguard Worker         kernel.data(), bias.data(), packed_weights.data(),
532*4bdc9457SAndroid Build Coastguard Worker         0 /* extra bytes */, nullptr);
533*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
534*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
535*4bdc9457SAndroid Build Coastguard Worker       }
536*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
537*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
538*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < indirection.size(); i += kr()) {
539*4bdc9457SAndroid Build Coastguard Worker           indirection[i + zero_index()] = zero.data();
540*4bdc9457SAndroid Build Coastguard Worker         }
541*4bdc9457SAndroid Build Coastguard Worker       }
542*4bdc9457SAndroid Build Coastguard Worker 
543*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
544*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
545*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
546*4bdc9457SAndroid Build Coastguard Worker           float acc = fp16_ieee_to_fp32_value(bias[c]);
547*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < kr(); k++) {
548*4bdc9457SAndroid Build Coastguard Worker             if (indirection[x * step() + k] != zero.data()) {
549*4bdc9457SAndroid Build Coastguard Worker               acc += fp16_ieee_to_fp32_value(indirection[x * step() + k][c + input_offset()]) * fp16_ieee_to_fp32_value(kernel[c * kr() + k]);
550*4bdc9457SAndroid Build Coastguard Worker             }
551*4bdc9457SAndroid Build Coastguard Worker           }
552*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc;
553*4bdc9457SAndroid Build Coastguard Worker         }
554*4bdc9457SAndroid Build Coastguard Worker       }
555*4bdc9457SAndroid Build Coastguard Worker 
556*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
557*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
558*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
559*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
560*4bdc9457SAndroid Build Coastguard Worker       const float output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
561*4bdc9457SAndroid Build Coastguard Worker       const float output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
562*4bdc9457SAndroid Build Coastguard Worker 
563*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
564*4bdc9457SAndroid Build Coastguard Worker       xnn_f16_minmax_params params;
565*4bdc9457SAndroid Build Coastguard Worker       init_params(&params,
566*4bdc9457SAndroid Build Coastguard Worker         fp16_ieee_from_fp32_value(output_min),
567*4bdc9457SAndroid Build Coastguard Worker         fp16_ieee_from_fp32_value(output_max));
568*4bdc9457SAndroid Build Coastguard Worker 
569*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
570*4bdc9457SAndroid Build Coastguard Worker       for (float& output_val : output_ref) {
571*4bdc9457SAndroid Build Coastguard Worker         output_val = std::max(std::min(output_val, output_max), output_min);
572*4bdc9457SAndroid Build Coastguard Worker       }
573*4bdc9457SAndroid Build Coastguard Worker 
574*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
575*4bdc9457SAndroid Build Coastguard Worker       dwconv_minmax(
576*4bdc9457SAndroid Build Coastguard Worker         channels(), width(),
577*4bdc9457SAndroid Build Coastguard Worker         reinterpret_cast<const void**>(indirection.data()), packed_weights.data(), output.data(),
578*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
579*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint16_t),
580*4bdc9457SAndroid Build Coastguard Worker         input_offset() * sizeof(uint16_t), zero.data(),
581*4bdc9457SAndroid Build Coastguard Worker         &params);
582*4bdc9457SAndroid Build Coastguard Worker 
583*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
584*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
585*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
586*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_min)
587*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
588*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_max)
589*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
590*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(output_ref[x * channels() + c], fp16_ieee_to_fp32_value(output[x * output_stride() + c]), std::max(1.0e-4f, std::abs(output_ref[x * channels() + c]) * 1.0e-2f))
591*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
592*4bdc9457SAndroid Build Coastguard Worker         }
593*4bdc9457SAndroid Build Coastguard Worker       }
594*4bdc9457SAndroid Build Coastguard Worker     }
595*4bdc9457SAndroid Build Coastguard Worker   }
596*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_dwconv_unipass_ukernel_function dwconv)597*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_dwconv_unipass_ukernel_function dwconv) const {
598*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
599*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
600*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
601*4bdc9457SAndroid Build Coastguard Worker 
602*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirection((width() - 1) * step() + kr());
603*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + indirection.size() * channels());
604*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> kernel(channels() * kr());
605*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> bias(channels());
606*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> packed_weights((kr() + 1) * packed_channels());
607*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float));
608*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((width() - 1) * output_stride() + channels());
609*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(width() * channels());
610*4bdc9457SAndroid Build Coastguard Worker 
611*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
612*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
613*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
614*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
615*4bdc9457SAndroid Build Coastguard Worker       std::fill(zero.begin(), zero.end(), 0.0f);
616*4bdc9457SAndroid Build Coastguard Worker       std::fill(output_ref.begin(), output_ref.end(), nanf(""));
617*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
618*4bdc9457SAndroid Build Coastguard Worker 
619*4bdc9457SAndroid Build Coastguard Worker       std::fill(packed_weights.begin(), packed_weights.end(), 0.0f);
620*4bdc9457SAndroid Build Coastguard Worker       xnn_pack_f32_dwconv_ghw_w(
621*4bdc9457SAndroid Build Coastguard Worker         kr(), kr(), 1, channels(), cr(),
622*4bdc9457SAndroid Build Coastguard Worker         kernel.data(), bias.data(), packed_weights.data(),
623*4bdc9457SAndroid Build Coastguard Worker         0 /* extra bytes */, nullptr);
624*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
625*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
626*4bdc9457SAndroid Build Coastguard Worker       }
627*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
628*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
629*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < indirection.size(); i += kr()) {
630*4bdc9457SAndroid Build Coastguard Worker           indirection[i + zero_index()] = zero.data();
631*4bdc9457SAndroid Build Coastguard Worker         }
632*4bdc9457SAndroid Build Coastguard Worker       }
633*4bdc9457SAndroid Build Coastguard Worker 
634*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
635*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
636*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
637*4bdc9457SAndroid Build Coastguard Worker           float acc = bias[c];
638*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < kr(); k++) {
639*4bdc9457SAndroid Build Coastguard Worker             if (indirection[x * step() + k] != zero.data()) {
640*4bdc9457SAndroid Build Coastguard Worker               acc += indirection[x * step() + k][c + input_offset()] * kernel[c * kr() + k];
641*4bdc9457SAndroid Build Coastguard Worker             }
642*4bdc9457SAndroid Build Coastguard Worker           }
643*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc;
644*4bdc9457SAndroid Build Coastguard Worker         }
645*4bdc9457SAndroid Build Coastguard Worker       }
646*4bdc9457SAndroid Build Coastguard Worker 
647*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
648*4bdc9457SAndroid Build Coastguard Worker       dwconv(
649*4bdc9457SAndroid Build Coastguard Worker         channels(), width(),
650*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), packed_weights.data(), output.data(),
651*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
652*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float),
653*4bdc9457SAndroid Build Coastguard Worker         input_offset() * sizeof(float), zero.data(),
654*4bdc9457SAndroid Build Coastguard Worker         nullptr);
655*4bdc9457SAndroid Build Coastguard Worker 
656*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
657*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
658*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
659*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
660*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
661*4bdc9457SAndroid Build Coastguard Worker               output[x * output_stride() + c],
662*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[x * channels() + c]) * 1.0e-5)
663*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
664*4bdc9457SAndroid Build Coastguard Worker         }
665*4bdc9457SAndroid Build Coastguard Worker       }
666*4bdc9457SAndroid Build Coastguard Worker     }
667*4bdc9457SAndroid Build Coastguard Worker   }
668*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_dwconv_minmax_unipass_ukernel_function dwconv_minmax,xnn_init_f32_minmax_params_fn init_params)669*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_dwconv_minmax_unipass_ukernel_function dwconv_minmax, xnn_init_f32_minmax_params_fn init_params) const {
670*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
671*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
672*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
673*4bdc9457SAndroid Build Coastguard Worker 
674*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirection((width() - 1) * step() + kr());
675*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + indirection.size() * channels());
676*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> kernel(channels() * kr());
677*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> bias(channels());
678*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> packed_weights((kr() + 1) * packed_channels());
679*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float));
680*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((width() - 1) * output_stride() + channels());
681*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(width() * channels());
682*4bdc9457SAndroid Build Coastguard Worker 
683*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
684*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
685*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
686*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
687*4bdc9457SAndroid Build Coastguard Worker       std::fill(zero.begin(), zero.end(), 0.0f);
688*4bdc9457SAndroid Build Coastguard Worker       std::fill(output_ref.begin(), output_ref.end(), nanf(""));
689*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
690*4bdc9457SAndroid Build Coastguard Worker 
691*4bdc9457SAndroid Build Coastguard Worker       std::fill(packed_weights.begin(), packed_weights.end(), 0.0f);
692*4bdc9457SAndroid Build Coastguard Worker       xnn_pack_f32_dwconv_ghw_w(
693*4bdc9457SAndroid Build Coastguard Worker         kr(), kr(), 1, channels(), cr(),
694*4bdc9457SAndroid Build Coastguard Worker         kernel.data(), bias.data(), packed_weights.data(),
695*4bdc9457SAndroid Build Coastguard Worker         0 /* extra bytes */, nullptr);
696*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
697*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
698*4bdc9457SAndroid Build Coastguard Worker       }
699*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
700*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
701*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < indirection.size(); i += kr()) {
702*4bdc9457SAndroid Build Coastguard Worker           indirection[i + zero_index()] = zero.data();
703*4bdc9457SAndroid Build Coastguard Worker         }
704*4bdc9457SAndroid Build Coastguard Worker       }
705*4bdc9457SAndroid Build Coastguard Worker 
706*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
707*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
708*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
709*4bdc9457SAndroid Build Coastguard Worker           float acc = bias[c];
710*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < kr(); k++) {
711*4bdc9457SAndroid Build Coastguard Worker             if (indirection[x * step() + k] != zero.data()) {
712*4bdc9457SAndroid Build Coastguard Worker               acc += indirection[x * step() + k][c + input_offset()] * kernel[c * kr() + k];
713*4bdc9457SAndroid Build Coastguard Worker             }
714*4bdc9457SAndroid Build Coastguard Worker           }
715*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc;
716*4bdc9457SAndroid Build Coastguard Worker         }
717*4bdc9457SAndroid Build Coastguard Worker       }
718*4bdc9457SAndroid Build Coastguard Worker 
719*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
720*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
721*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
722*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
723*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin());
724*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
725*4bdc9457SAndroid Build Coastguard Worker 
726*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
727*4bdc9457SAndroid Build Coastguard Worker       xnn_f32_minmax_params params;
728*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, output_min, output_max);
729*4bdc9457SAndroid Build Coastguard Worker 
730*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
731*4bdc9457SAndroid Build Coastguard Worker       for (float& output_val : output_ref) {
732*4bdc9457SAndroid Build Coastguard Worker         output_val = std::max(std::min(output_val, output_max), output_min);
733*4bdc9457SAndroid Build Coastguard Worker       }
734*4bdc9457SAndroid Build Coastguard Worker 
735*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
736*4bdc9457SAndroid Build Coastguard Worker       dwconv_minmax(
737*4bdc9457SAndroid Build Coastguard Worker         channels(), width(),
738*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), packed_weights.data(), output.data(),
739*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
740*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float),
741*4bdc9457SAndroid Build Coastguard Worker         input_offset() * sizeof(float), zero.data(),
742*4bdc9457SAndroid Build Coastguard Worker         &params);
743*4bdc9457SAndroid Build Coastguard Worker 
744*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
745*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < width(); x++) {
746*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
747*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(output[x * output_stride() + c], output_min)
748*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
749*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(output[x * output_stride() + c], output_max)
750*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
751*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
752*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
753*4bdc9457SAndroid Build Coastguard Worker               output[x * output_stride() + c],
754*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[x * channels() + c]) * 1.0e-5)
755*4bdc9457SAndroid Build Coastguard Worker             << "x = " << x << ", channel = " << c;
756*4bdc9457SAndroid Build Coastguard Worker         }
757*4bdc9457SAndroid Build Coastguard Worker       }
758*4bdc9457SAndroid Build Coastguard Worker     }
759*4bdc9457SAndroid Build Coastguard Worker   }
760*4bdc9457SAndroid Build Coastguard Worker 
761*4bdc9457SAndroid Build Coastguard Worker  private:
762*4bdc9457SAndroid Build Coastguard Worker   uint32_t channels_{1};  // Default 1. NOTE(review): presumably backs channels(), the per-pixel channel count (accessor not visible here).
763*4bdc9457SAndroid Build Coastguard Worker   uint32_t cr_{1};  // Default 1. Not used in the visible code; presumably the micro-kernel's channel tile - TODO confirm.
764*4bdc9457SAndroid Build Coastguard Worker   uint32_t kr_{1};  // Default 1. Presumably backs kr(), the number of kernel taps accumulated per output pixel.
765*4bdc9457SAndroid Build Coastguard Worker   uint32_t width_{1};  // Default 1. Presumably backs width(), the number of output pixels the tests iterate over.
766*4bdc9457SAndroid Build Coastguard Worker   uint32_t step_{1};  // Default 1. Presumably backs step(), the stride (in indirection entries) between consecutive output pixels.
767*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_stride_{0};  // Default 0. Presumably backs output_stride(), the element stride between output pixels; the meaning of 0 is not visible here - TODO confirm.
768*4bdc9457SAndroid Build Coastguard Worker   uint8_t input_zero_point_{127};  // Default 127. Not used in the visible code; presumably for quantized test variants - TODO confirm.
769*4bdc9457SAndroid Build Coastguard Worker   uint8_t kernel_zero_point_{127};  // Default 127. Not used in the visible code; presumably for quantized test variants - TODO confirm.
770*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmin_{0};  // Default 0. Presumably backs qmin(), which positions the lower clamp bound on a 0..255 scale of the accumulated range.
771*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmax_{255};  // Default 255. Presumably backs qmax(), which positions the upper clamp bound on a 0..255 scale of the accumulated range.
772*4bdc9457SAndroid Build Coastguard Worker   size_t input_offset_{0};  // Default 0. Presumably backs input_offset(), an element offset subtracted from indirection pointers and passed to the kernel in bytes.
773*4bdc9457SAndroid Build Coastguard Worker   size_t zero_index_{SIZE_MAX};  // Default SIZE_MAX. Presumably backs zero_index(); SIZE_MAX means no indirection entry is redirected to the zero buffer.
774*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{3};  // Default 3. Usage not visible here; presumably the number of randomized test repetitions - TODO confirm.
775*4bdc9457SAndroid Build Coastguard Worker };
776