xref: /aosp_15_r20/external/XNNPACK/test/avgpool-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #pragma once
10*4bdc9457SAndroid Build Coastguard Worker 
11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
14*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
15*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
18*4bdc9457SAndroid Build Coastguard Worker #include <limits>
19*4bdc9457SAndroid Build Coastguard Worker #include <random>
20*4bdc9457SAndroid Build Coastguard Worker #include <vector>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h>
26*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
27*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
28*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/requantization.h>
29*4bdc9457SAndroid Build Coastguard Worker 
30*4bdc9457SAndroid Build Coastguard Worker 
31*4bdc9457SAndroid Build Coastguard Worker class AvgPoolMicrokernelTester {
32*4bdc9457SAndroid Build Coastguard Worker  public:
output_pixels(size_t output_pixels)33*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& output_pixels(size_t output_pixels) {
34*4bdc9457SAndroid Build Coastguard Worker     assert(output_pixels != 0);
35*4bdc9457SAndroid Build Coastguard Worker     this->output_pixels_ = output_pixels;
36*4bdc9457SAndroid Build Coastguard Worker     return *this;
37*4bdc9457SAndroid Build Coastguard Worker   }
38*4bdc9457SAndroid Build Coastguard Worker 
output_pixels()39*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_pixels() const {
40*4bdc9457SAndroid Build Coastguard Worker     return this->output_pixels_;
41*4bdc9457SAndroid Build Coastguard Worker   }
42*4bdc9457SAndroid Build Coastguard Worker 
step(size_t step)43*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& step(size_t step) {
44*4bdc9457SAndroid Build Coastguard Worker     assert(step != 0);
45*4bdc9457SAndroid Build Coastguard Worker     this->step_ = step;
46*4bdc9457SAndroid Build Coastguard Worker     return *this;
47*4bdc9457SAndroid Build Coastguard Worker   }
48*4bdc9457SAndroid Build Coastguard Worker 
step()49*4bdc9457SAndroid Build Coastguard Worker   inline size_t step() const {
50*4bdc9457SAndroid Build Coastguard Worker     return this->step_;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
input_offset(size_t input_offset)53*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& input_offset(size_t input_offset) {
54*4bdc9457SAndroid Build Coastguard Worker     assert(input_offset != 0);
55*4bdc9457SAndroid Build Coastguard Worker     this->input_offset_ = input_offset;
56*4bdc9457SAndroid Build Coastguard Worker     return *this;
57*4bdc9457SAndroid Build Coastguard Worker   }
58*4bdc9457SAndroid Build Coastguard Worker 
input_offset()59*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_offset() const {
60*4bdc9457SAndroid Build Coastguard Worker     return this->input_offset_;
61*4bdc9457SAndroid Build Coastguard Worker   }
62*4bdc9457SAndroid Build Coastguard Worker 
zero_index(size_t zero_index)63*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& zero_index(size_t zero_index) {
64*4bdc9457SAndroid Build Coastguard Worker     this->zero_index_ = zero_index;
65*4bdc9457SAndroid Build Coastguard Worker     return *this;
66*4bdc9457SAndroid Build Coastguard Worker   }
67*4bdc9457SAndroid Build Coastguard Worker 
zero_index()68*4bdc9457SAndroid Build Coastguard Worker   inline size_t zero_index() const {
69*4bdc9457SAndroid Build Coastguard Worker     return this->zero_index_;
70*4bdc9457SAndroid Build Coastguard Worker   }
71*4bdc9457SAndroid Build Coastguard Worker 
pooling_elements(size_t pooling_elements)72*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& pooling_elements(size_t pooling_elements) {
73*4bdc9457SAndroid Build Coastguard Worker     assert(pooling_elements != 0);
74*4bdc9457SAndroid Build Coastguard Worker     this->pooling_elements_ = pooling_elements;
75*4bdc9457SAndroid Build Coastguard Worker     return *this;
76*4bdc9457SAndroid Build Coastguard Worker   }
77*4bdc9457SAndroid Build Coastguard Worker 
pooling_elements()78*4bdc9457SAndroid Build Coastguard Worker   inline size_t pooling_elements() const {
79*4bdc9457SAndroid Build Coastguard Worker     return this->pooling_elements_;
80*4bdc9457SAndroid Build Coastguard Worker   }
81*4bdc9457SAndroid Build Coastguard Worker 
packed_pooling_elements()82*4bdc9457SAndroid Build Coastguard Worker   inline size_t packed_pooling_elements() const {
83*4bdc9457SAndroid Build Coastguard Worker     if (pooling_elements() <= primary_pooling_tile()) {
84*4bdc9457SAndroid Build Coastguard Worker       return primary_pooling_tile();
85*4bdc9457SAndroid Build Coastguard Worker     } else {
86*4bdc9457SAndroid Build Coastguard Worker       return (pooling_elements() - primary_pooling_tile()) % incremental_pooling_tile() == 0 ? pooling_elements() : ((pooling_elements() - primary_pooling_tile()) / incremental_pooling_tile() + 1) * incremental_pooling_tile() + primary_pooling_tile();
87*4bdc9457SAndroid Build Coastguard Worker     }
88*4bdc9457SAndroid Build Coastguard Worker   }
89*4bdc9457SAndroid Build Coastguard Worker 
90*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& pooling_tile(size_t primary_tile, size_t incremental_tile = 0) {
91*4bdc9457SAndroid Build Coastguard Worker     assert(primary_tile != 0);
92*4bdc9457SAndroid Build Coastguard Worker     this->primary_pooling_tile_ = primary_tile;
93*4bdc9457SAndroid Build Coastguard Worker     this->incremental_pooling_tile_ = incremental_tile;
94*4bdc9457SAndroid Build Coastguard Worker     return *this;
95*4bdc9457SAndroid Build Coastguard Worker   }
96*4bdc9457SAndroid Build Coastguard Worker 
primary_pooling_tile(size_t primary_pooling_tile)97*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& primary_pooling_tile(size_t primary_pooling_tile) {
98*4bdc9457SAndroid Build Coastguard Worker     assert(primary_pooling_tile != 0);
99*4bdc9457SAndroid Build Coastguard Worker     this->primary_pooling_tile_ = primary_pooling_tile;
100*4bdc9457SAndroid Build Coastguard Worker     return *this;
101*4bdc9457SAndroid Build Coastguard Worker   }
102*4bdc9457SAndroid Build Coastguard Worker 
primary_pooling_tile()103*4bdc9457SAndroid Build Coastguard Worker   inline size_t primary_pooling_tile() const {
104*4bdc9457SAndroid Build Coastguard Worker     return this->primary_pooling_tile_;
105*4bdc9457SAndroid Build Coastguard Worker   }
106*4bdc9457SAndroid Build Coastguard Worker 
incremental_pooling_tile(size_t incremental_pooling_tile)107*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& incremental_pooling_tile(size_t incremental_pooling_tile) {
108*4bdc9457SAndroid Build Coastguard Worker     assert(incremental_pooling_tile != 0);
109*4bdc9457SAndroid Build Coastguard Worker     this->incremental_pooling_tile_ = incremental_pooling_tile;
110*4bdc9457SAndroid Build Coastguard Worker     return *this;
111*4bdc9457SAndroid Build Coastguard Worker   }
112*4bdc9457SAndroid Build Coastguard Worker 
incremental_pooling_tile()113*4bdc9457SAndroid Build Coastguard Worker   inline size_t incremental_pooling_tile() const {
114*4bdc9457SAndroid Build Coastguard Worker     return this->incremental_pooling_tile_;
115*4bdc9457SAndroid Build Coastguard Worker   }
116*4bdc9457SAndroid Build Coastguard Worker 
channels(size_t channels)117*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& channels(size_t channels) {
118*4bdc9457SAndroid Build Coastguard Worker     assert(channels != 0);
119*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
120*4bdc9457SAndroid Build Coastguard Worker     return *this;
121*4bdc9457SAndroid Build Coastguard Worker   }
122*4bdc9457SAndroid Build Coastguard Worker 
channels()123*4bdc9457SAndroid Build Coastguard Worker   inline size_t channels() const {
124*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
125*4bdc9457SAndroid Build Coastguard Worker   }
126*4bdc9457SAndroid Build Coastguard Worker 
output_stride(size_t output_stride)127*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& output_stride(size_t output_stride) {
128*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride != 0);
129*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
130*4bdc9457SAndroid Build Coastguard Worker     return *this;
131*4bdc9457SAndroid Build Coastguard Worker   }
132*4bdc9457SAndroid Build Coastguard Worker 
output_stride()133*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_stride() const {
134*4bdc9457SAndroid Build Coastguard Worker     if (this->output_stride_ == 0) {
135*4bdc9457SAndroid Build Coastguard Worker       return channels();
136*4bdc9457SAndroid Build Coastguard Worker     } else {
137*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_stride_ >= channels());
138*4bdc9457SAndroid Build Coastguard Worker       return this->output_stride_;
139*4bdc9457SAndroid Build Coastguard Worker     }
140*4bdc9457SAndroid Build Coastguard Worker   }
141*4bdc9457SAndroid Build Coastguard Worker 
input_scale(float input_scale)142*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& input_scale(float input_scale) {
143*4bdc9457SAndroid Build Coastguard Worker     assert(input_scale > 0.0f);
144*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(input_scale));
145*4bdc9457SAndroid Build Coastguard Worker     this->input_scale_ = input_scale;
146*4bdc9457SAndroid Build Coastguard Worker     return *this;
147*4bdc9457SAndroid Build Coastguard Worker   }
148*4bdc9457SAndroid Build Coastguard Worker 
input_scale()149*4bdc9457SAndroid Build Coastguard Worker   inline float input_scale() const {
150*4bdc9457SAndroid Build Coastguard Worker     return this->input_scale_;
151*4bdc9457SAndroid Build Coastguard Worker   }
152*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point(uint8_t input_zero_point)153*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& input_zero_point(uint8_t input_zero_point) {
154*4bdc9457SAndroid Build Coastguard Worker     this->input_zero_point_ = input_zero_point;
155*4bdc9457SAndroid Build Coastguard Worker     return *this;
156*4bdc9457SAndroid Build Coastguard Worker   }
157*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point()158*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t input_zero_point() const {
159*4bdc9457SAndroid Build Coastguard Worker     return this->input_zero_point_;
160*4bdc9457SAndroid Build Coastguard Worker   }
161*4bdc9457SAndroid Build Coastguard Worker 
output_scale(float output_scale)162*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& output_scale(float output_scale) {
163*4bdc9457SAndroid Build Coastguard Worker     assert(output_scale > 0.0f);
164*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(output_scale));
165*4bdc9457SAndroid Build Coastguard Worker     this->output_scale_ = output_scale;
166*4bdc9457SAndroid Build Coastguard Worker     return *this;
167*4bdc9457SAndroid Build Coastguard Worker   }
168*4bdc9457SAndroid Build Coastguard Worker 
output_scale()169*4bdc9457SAndroid Build Coastguard Worker   inline float output_scale() const {
170*4bdc9457SAndroid Build Coastguard Worker     return this->output_scale_;
171*4bdc9457SAndroid Build Coastguard Worker   }
172*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point(uint8_t output_zero_point)173*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& output_zero_point(uint8_t output_zero_point) {
174*4bdc9457SAndroid Build Coastguard Worker     this->output_zero_point_ = output_zero_point;
175*4bdc9457SAndroid Build Coastguard Worker     return *this;
176*4bdc9457SAndroid Build Coastguard Worker   }
177*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point()178*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t output_zero_point() const {
179*4bdc9457SAndroid Build Coastguard Worker     return this->output_zero_point_;
180*4bdc9457SAndroid Build Coastguard Worker   }
181*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)182*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& qmin(uint8_t qmin) {
183*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
184*4bdc9457SAndroid Build Coastguard Worker     return *this;
185*4bdc9457SAndroid Build Coastguard Worker   }
186*4bdc9457SAndroid Build Coastguard Worker 
qmin()187*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
188*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
189*4bdc9457SAndroid Build Coastguard Worker   }
190*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)191*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& qmax(uint8_t qmax) {
192*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
193*4bdc9457SAndroid Build Coastguard Worker     return *this;
194*4bdc9457SAndroid Build Coastguard Worker   }
195*4bdc9457SAndroid Build Coastguard Worker 
qmax()196*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
197*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
198*4bdc9457SAndroid Build Coastguard Worker   }
199*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)200*4bdc9457SAndroid Build Coastguard Worker   inline AvgPoolMicrokernelTester& iterations(size_t iterations) {
201*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
202*4bdc9457SAndroid Build Coastguard Worker     return *this;
203*4bdc9457SAndroid Build Coastguard Worker   }
204*4bdc9457SAndroid Build Coastguard Worker 
iterations()205*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
206*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
207*4bdc9457SAndroid Build Coastguard Worker   }
208*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_avgpool_minmax_unipass_ukernel_function avgpool_minmax,xnn_init_f16_scaleminmax_params_fn init_params)209*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_avgpool_minmax_unipass_ukernel_function avgpool_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const {
210*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
211*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
212*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
213*4bdc9457SAndroid Build Coastguard Worker 
214*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint16_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
215*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
216*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
217*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
218*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output((output_pixels() - 1) * output_stride() + channels());
219*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
220*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
221*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
222*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), UINT16_C(0x7E00) /* NaN */);
223*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(uint16_t), input.end(), UINT16_C(0x7E00) /* NaN */);
224*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
225*4bdc9457SAndroid Build Coastguard Worker 
226*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
227*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
228*4bdc9457SAndroid Build Coastguard Worker       }
229*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
230*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
231*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
232*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
233*4bdc9457SAndroid Build Coastguard Worker       }
234*4bdc9457SAndroid Build Coastguard Worker 
235*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
236*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
237*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
238*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
239*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
240*4bdc9457SAndroid Build Coastguard Worker             const uint16_t* row = indirect_input[x * step() + p];
241*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
242*4bdc9457SAndroid Build Coastguard Worker               acc += fp16_ieee_to_fp32_value(row[c + input_offset()]);
243*4bdc9457SAndroid Build Coastguard Worker             }
244*4bdc9457SAndroid Build Coastguard Worker           }
245*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc / float(pooling_elements());
246*4bdc9457SAndroid Build Coastguard Worker         }
247*4bdc9457SAndroid Build Coastguard Worker       }
248*4bdc9457SAndroid Build Coastguard Worker 
249*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
250*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
251*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
252*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
253*4bdc9457SAndroid Build Coastguard Worker       float output_min_as_float = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
254*4bdc9457SAndroid Build Coastguard Worker       float output_max_as_float = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
255*4bdc9457SAndroid Build Coastguard Worker       const uint16_t output_min_as_half = fp16_ieee_from_fp32_value(output_min_as_float);
256*4bdc9457SAndroid Build Coastguard Worker       const uint16_t output_max_as_half = fp16_ieee_from_fp32_value(output_max_as_float);
257*4bdc9457SAndroid Build Coastguard Worker       output_min_as_float = fp16_ieee_to_fp32_value(output_min_as_half);
258*4bdc9457SAndroid Build Coastguard Worker       output_max_as_float = fp16_ieee_to_fp32_value(output_max_as_half);
259*4bdc9457SAndroid Build Coastguard Worker 
260*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
261*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
262*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max(std::min(output_value, output_max_as_float), output_min_as_float);
263*4bdc9457SAndroid Build Coastguard Worker       }
264*4bdc9457SAndroid Build Coastguard Worker 
265*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
266*4bdc9457SAndroid Build Coastguard Worker       xnn_f16_scaleminmax_params params;
267*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, fp16_ieee_from_fp32_value(1.0f / float(pooling_elements())), output_min_as_half, output_max_as_half);
268*4bdc9457SAndroid Build Coastguard Worker 
269*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
270*4bdc9457SAndroid Build Coastguard Worker       avgpool_minmax(output_pixels(), pooling_elements(), channels(),
271*4bdc9457SAndroid Build Coastguard Worker         reinterpret_cast<const void**>(indirect_input.data()), input_offset() * sizeof(uint16_t), zero.data(),
272*4bdc9457SAndroid Build Coastguard Worker         output.data(),
273*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
274*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint16_t),
275*4bdc9457SAndroid Build Coastguard Worker         &params);
276*4bdc9457SAndroid Build Coastguard Worker 
277*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
278*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
279*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
280*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_min_as_float)
281*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
282*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
283*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
284*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_max_as_float)
285*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
286*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
287*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
288*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
289*4bdc9457SAndroid Build Coastguard Worker               fp16_ieee_to_fp32_value(output[x * output_stride() + c]),
290*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
291*4bdc9457SAndroid Build Coastguard Worker               std::max(1.0e-4f, std::abs(output_ref[x * channels() + c]) * 3.0e-3f))
292*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
293*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
294*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
295*4bdc9457SAndroid Build Coastguard Worker         }
296*4bdc9457SAndroid Build Coastguard Worker       }
297*4bdc9457SAndroid Build Coastguard Worker     }
298*4bdc9457SAndroid Build Coastguard Worker   }
299*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_avgpool_minmax_multipass_ukernel_function avgpool_minmax,xnn_init_f16_scaleminmax_params_fn init_params)300*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_avgpool_minmax_multipass_ukernel_function avgpool_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const {
301*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
302*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
303*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
304*4bdc9457SAndroid Build Coastguard Worker 
305*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint16_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
306*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
307*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
308*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
309*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output((output_pixels() - 1) * output_stride() + channels());
310*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
311*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> buffer(XNN_EXTRA_BYTES / sizeof(uint16_t) + channels());
312*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
313*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
314*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), UINT16_C(0x7E00) /* NaN */);
315*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(uint16_t), input.end(), UINT16_C(0x7E00) /* NaN */);
316*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
317*4bdc9457SAndroid Build Coastguard Worker 
318*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
319*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
320*4bdc9457SAndroid Build Coastguard Worker       }
321*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
322*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
323*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
324*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
325*4bdc9457SAndroid Build Coastguard Worker       }
326*4bdc9457SAndroid Build Coastguard Worker 
327*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
328*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
329*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
330*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
331*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
332*4bdc9457SAndroid Build Coastguard Worker             const uint16_t* row = indirect_input[x * step() + p];
333*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
334*4bdc9457SAndroid Build Coastguard Worker               acc += fp16_ieee_to_fp32_value(row[c + input_offset()]);
335*4bdc9457SAndroid Build Coastguard Worker             }
336*4bdc9457SAndroid Build Coastguard Worker           }
337*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc / float(pooling_elements());
338*4bdc9457SAndroid Build Coastguard Worker         }
339*4bdc9457SAndroid Build Coastguard Worker       }
340*4bdc9457SAndroid Build Coastguard Worker 
341*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
342*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
343*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
344*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
345*4bdc9457SAndroid Build Coastguard Worker       float output_min_as_float = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
346*4bdc9457SAndroid Build Coastguard Worker       float output_max_as_float = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
347*4bdc9457SAndroid Build Coastguard Worker       const uint16_t output_min_as_half = fp16_ieee_from_fp32_value(output_min_as_float);
348*4bdc9457SAndroid Build Coastguard Worker       const uint16_t output_max_as_half = fp16_ieee_from_fp32_value(output_max_as_float);
349*4bdc9457SAndroid Build Coastguard Worker       output_min_as_float = fp16_ieee_to_fp32_value(output_min_as_half);
350*4bdc9457SAndroid Build Coastguard Worker       output_max_as_float = fp16_ieee_to_fp32_value(output_max_as_half);
351*4bdc9457SAndroid Build Coastguard Worker 
352*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
353*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
354*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max(std::min(output_value, output_max_as_float), output_min_as_float);
355*4bdc9457SAndroid Build Coastguard Worker       }
356*4bdc9457SAndroid Build Coastguard Worker 
357*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
358*4bdc9457SAndroid Build Coastguard Worker       xnn_f16_scaleminmax_params params;
359*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, fp16_ieee_from_fp32_value(1.0f / float(pooling_elements())), output_min_as_half, output_max_as_half);
360*4bdc9457SAndroid Build Coastguard Worker 
361*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
362*4bdc9457SAndroid Build Coastguard Worker       avgpool_minmax(output_pixels(), pooling_elements(), channels(),
363*4bdc9457SAndroid Build Coastguard Worker         reinterpret_cast<const void**>(indirect_input.data()), input_offset() * sizeof(uint16_t), zero.data(),
364*4bdc9457SAndroid Build Coastguard Worker         buffer.data(), output.data(),
365*4bdc9457SAndroid Build Coastguard Worker         (step() - (packed_pooling_elements() - incremental_pooling_tile())) * sizeof(void*),
366*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint16_t),
367*4bdc9457SAndroid Build Coastguard Worker         &params);
368*4bdc9457SAndroid Build Coastguard Worker 
369*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
370*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
371*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
372*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_min_as_float)
373*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
374*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
375*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
376*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_max_as_float)
377*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
378*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
379*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
380*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
381*4bdc9457SAndroid Build Coastguard Worker               fp16_ieee_to_fp32_value(output[x * output_stride() + c]),
382*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
383*4bdc9457SAndroid Build Coastguard Worker               std::max(1.0e-4f, std::abs(output_ref[x * channels() + c]) * 3.0e-3f))
384*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
385*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
386*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
387*4bdc9457SAndroid Build Coastguard Worker         }
388*4bdc9457SAndroid Build Coastguard Worker       }
389*4bdc9457SAndroid Build Coastguard Worker     }
390*4bdc9457SAndroid Build Coastguard Worker   }
391*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_avgpool_minmax_unipass_ukernel_function avgpool_minmax,xnn_init_f32_scaleminmax_params_fn init_params)392*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_avgpool_minmax_unipass_ukernel_function avgpool_minmax, xnn_init_f32_scaleminmax_params_fn init_params) const {
393*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
394*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
395*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
396*4bdc9457SAndroid Build Coastguard Worker 
397*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
398*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
399*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
400*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float));
401*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((output_pixels() - 1) * output_stride() + channels());
402*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
403*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
404*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
405*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), std::nanf(""));
406*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(float), input.end(), std::nanf(""));
407*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
408*4bdc9457SAndroid Build Coastguard Worker 
409*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
410*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
411*4bdc9457SAndroid Build Coastguard Worker       }
412*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
413*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
414*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
415*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
416*4bdc9457SAndroid Build Coastguard Worker       }
417*4bdc9457SAndroid Build Coastguard Worker 
418*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
419*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
420*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
421*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
422*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
423*4bdc9457SAndroid Build Coastguard Worker             const float* row = indirect_input[x * step() + p];
424*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
425*4bdc9457SAndroid Build Coastguard Worker               acc += row[c + input_offset()];
426*4bdc9457SAndroid Build Coastguard Worker             }
427*4bdc9457SAndroid Build Coastguard Worker           }
428*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc / float(pooling_elements());
429*4bdc9457SAndroid Build Coastguard Worker         }
430*4bdc9457SAndroid Build Coastguard Worker       }
431*4bdc9457SAndroid Build Coastguard Worker 
432*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
433*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
434*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
435*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
436*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
437*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
438*4bdc9457SAndroid Build Coastguard Worker 
439*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
440*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
441*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max(std::min(output_value, output_max), output_min);
442*4bdc9457SAndroid Build Coastguard Worker       }
443*4bdc9457SAndroid Build Coastguard Worker 
444*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
445*4bdc9457SAndroid Build Coastguard Worker       xnn_f32_scaleminmax_params params;
446*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, 1.0f / float(pooling_elements()), output_min, output_max);
447*4bdc9457SAndroid Build Coastguard Worker 
448*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
449*4bdc9457SAndroid Build Coastguard Worker       avgpool_minmax(output_pixels(), pooling_elements(), channels(),
450*4bdc9457SAndroid Build Coastguard Worker         indirect_input.data(), input_offset() * sizeof(float), zero.data(),
451*4bdc9457SAndroid Build Coastguard Worker         output.data(),
452*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
453*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float),
454*4bdc9457SAndroid Build Coastguard Worker         &params);
455*4bdc9457SAndroid Build Coastguard Worker 
456*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
457*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
458*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
459*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(output[x * output_stride() + c], output_min)
460*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
461*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
462*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
463*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(output[x * output_stride() + c], output_max)
464*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
465*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
466*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
467*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
468*4bdc9457SAndroid Build Coastguard Worker               output[x * output_stride() + c],
469*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
470*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[x * channels() + c]) * 1.0e-6f)
471*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
472*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
473*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
474*4bdc9457SAndroid Build Coastguard Worker         }
475*4bdc9457SAndroid Build Coastguard Worker       }
476*4bdc9457SAndroid Build Coastguard Worker     }
477*4bdc9457SAndroid Build Coastguard Worker   }
478*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_avgpool_minmax_multipass_ukernel_function avgpool_minmax,xnn_init_f32_scaleminmax_params_fn init_params)479*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_avgpool_minmax_multipass_ukernel_function avgpool_minmax, xnn_init_f32_scaleminmax_params_fn init_params) const {
480*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
481*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
482*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
483*4bdc9457SAndroid Build Coastguard Worker 
484*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
485*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
486*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
487*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float));
488*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((output_pixels() - 1) * output_stride() + channels());
489*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
490*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> buffer(XNN_EXTRA_BYTES / sizeof(float) + channels());
491*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
492*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
493*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), std::nanf(""));
494*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(float), input.end(), std::nanf(""));
495*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
496*4bdc9457SAndroid Build Coastguard Worker 
497*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
498*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
499*4bdc9457SAndroid Build Coastguard Worker       }
500*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
501*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
502*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
503*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
504*4bdc9457SAndroid Build Coastguard Worker       }
505*4bdc9457SAndroid Build Coastguard Worker 
506*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
507*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
508*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
509*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
510*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
511*4bdc9457SAndroid Build Coastguard Worker             const float* row = indirect_input[x * step() + p];
512*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
513*4bdc9457SAndroid Build Coastguard Worker               acc += row[c + input_offset()];
514*4bdc9457SAndroid Build Coastguard Worker             }
515*4bdc9457SAndroid Build Coastguard Worker           }
516*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc / float(pooling_elements());
517*4bdc9457SAndroid Build Coastguard Worker         }
518*4bdc9457SAndroid Build Coastguard Worker       }
519*4bdc9457SAndroid Build Coastguard Worker 
520*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
521*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
522*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
523*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
524*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
525*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
526*4bdc9457SAndroid Build Coastguard Worker 
527*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
528*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
529*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max(std::min(output_value, output_max), output_min);
530*4bdc9457SAndroid Build Coastguard Worker       }
531*4bdc9457SAndroid Build Coastguard Worker 
532*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
533*4bdc9457SAndroid Build Coastguard Worker       xnn_f32_scaleminmax_params params;
534*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, 1.0f / float(pooling_elements()), output_min, output_max);
535*4bdc9457SAndroid Build Coastguard Worker 
536*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
537*4bdc9457SAndroid Build Coastguard Worker       avgpool_minmax(output_pixels(), pooling_elements(), channels(),
538*4bdc9457SAndroid Build Coastguard Worker         indirect_input.data(), input_offset() * sizeof(float), zero.data(),
539*4bdc9457SAndroid Build Coastguard Worker         buffer.data(), output.data(),
540*4bdc9457SAndroid Build Coastguard Worker         (step() - (packed_pooling_elements() - incremental_pooling_tile())) * sizeof(void*),
541*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float),
542*4bdc9457SAndroid Build Coastguard Worker         &params);
543*4bdc9457SAndroid Build Coastguard Worker 
544*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
545*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
546*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
547*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(output[x * output_stride() + c], output_min)
548*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
549*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
550*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
551*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(output[x * output_stride() + c], output_max)
552*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
553*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
554*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
555*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
556*4bdc9457SAndroid Build Coastguard Worker               output[x * output_stride() + c],
557*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
558*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[x * channels() + c]) * 1.0e-6f)
559*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
560*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
561*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
562*4bdc9457SAndroid Build Coastguard Worker         }
563*4bdc9457SAndroid Build Coastguard Worker       }
564*4bdc9457SAndroid Build Coastguard Worker     }
565*4bdc9457SAndroid Build Coastguard Worker   }
566*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qu8_avgpool_minmax_unipass_ukernel_function avgpool_minmax,xnn_init_qu8_avgpool_minmax_params_fn init_params)567*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_qu8_avgpool_minmax_unipass_ukernel_function avgpool_minmax, xnn_init_qu8_avgpool_minmax_params_fn init_params) const {
568*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
569*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
570*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
571*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
572*4bdc9457SAndroid Build Coastguard Worker 
573*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint8_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
574*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
575*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
576*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
577*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output((output_pixels() - 1) * output_stride() + channels());
578*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(output_pixels() * channels());
579*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_real(output_pixels() * channels());
580*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulator(output_pixels() * channels());
581*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
582*4bdc9457SAndroid Build Coastguard Worker       do {
583*4bdc9457SAndroid Build Coastguard Worker         std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
584*4bdc9457SAndroid Build Coastguard Worker       } while (input.size() > 1 && *std::max_element(input.cbegin(), input.cend()) == *std::min_element(input.cbegin(), input.cend()));
585*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), UINT8_C(0xA5));
586*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(uint8_t), input.end(), UINT8_C(0xA5));
587*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
588*4bdc9457SAndroid Build Coastguard Worker 
589*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
590*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
591*4bdc9457SAndroid Build Coastguard Worker       }
592*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
593*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
594*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
595*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
596*4bdc9457SAndroid Build Coastguard Worker       }
597*4bdc9457SAndroid Build Coastguard Worker 
598*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
599*4bdc9457SAndroid Build Coastguard Worker       xnn_qu8_avgpool_minmax_params params;
600*4bdc9457SAndroid Build Coastguard Worker       init_params(
601*4bdc9457SAndroid Build Coastguard Worker         &params,
602*4bdc9457SAndroid Build Coastguard Worker         -int32_t(input_zero_point()) * int32_t(pooling_elements()),
603*4bdc9457SAndroid Build Coastguard Worker         input_scale() / (output_scale() * float(pooling_elements())),
604*4bdc9457SAndroid Build Coastguard Worker         output_zero_point(), qmin(), qmax());
605*4bdc9457SAndroid Build Coastguard Worker 
606*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
607*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
608*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
609*4bdc9457SAndroid Build Coastguard Worker           int32_t acc = 0;
610*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
611*4bdc9457SAndroid Build Coastguard Worker             const uint8_t* row = indirect_input[x * step() + p];
612*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
613*4bdc9457SAndroid Build Coastguard Worker               acc += int32_t(row[c + input_offset()]);
614*4bdc9457SAndroid Build Coastguard Worker             }
615*4bdc9457SAndroid Build Coastguard Worker             acc -= int32_t(input_zero_point());
616*4bdc9457SAndroid Build Coastguard Worker           }
617*4bdc9457SAndroid Build Coastguard Worker           accumulator[x * channels() + c] = acc;
618*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = xnn_qu8_requantize_rndna(
619*4bdc9457SAndroid Build Coastguard Worker             acc, input_scale() / (output_scale() * float(pooling_elements())), output_zero_point(), qmin(), qmax());
620*4bdc9457SAndroid Build Coastguard Worker           const float scaled_acc =
621*4bdc9457SAndroid Build Coastguard Worker             float(acc) * input_scale() / (output_scale() * float(pooling_elements())) + float(output_zero_point());
622*4bdc9457SAndroid Build Coastguard Worker           output_real[x * channels() + c] = std::min(std::max(scaled_acc, float(qmin())), float(qmax()));
623*4bdc9457SAndroid Build Coastguard Worker         }
624*4bdc9457SAndroid Build Coastguard Worker       }
625*4bdc9457SAndroid Build Coastguard Worker 
626*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
627*4bdc9457SAndroid Build Coastguard Worker       avgpool_minmax(output_pixels(), pooling_elements(), channels(),
628*4bdc9457SAndroid Build Coastguard Worker         indirect_input.data(), input_offset() * sizeof(uint8_t), zero.data(),
629*4bdc9457SAndroid Build Coastguard Worker         output.data(),
630*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
631*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint8_t),
632*4bdc9457SAndroid Build Coastguard Worker         &params);
633*4bdc9457SAndroid Build Coastguard Worker 
634*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
635*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
636*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
637*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(uint32_t(output[x * output_stride() + c]), uint32_t(qmin()))
638*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
639*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
640*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
641*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(uint32_t(output[x * output_stride() + c]), uint32_t(qmax()))
642*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
643*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
644*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
645*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(float(int32_t(output[x * output_stride() + c])), output_real[x * channels() + c], 0.5f)
646*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
647*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
648*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset() << ", accumulator = " << accumulator[x * channels() + c];
649*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(uint32_t(output_ref[x * channels() + c]), uint32_t(output[x * output_stride() + c]))
650*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
651*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
652*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset() << ", accumulator = " << accumulator[x * channels() + c];
653*4bdc9457SAndroid Build Coastguard Worker         }
654*4bdc9457SAndroid Build Coastguard Worker       }
655*4bdc9457SAndroid Build Coastguard Worker     }
656*4bdc9457SAndroid Build Coastguard Worker   }
657*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qu8_avgpool_minmax_multipass_ukernel_function avgpool_minmax,xnn_init_qu8_avgpool_minmax_params_fn init_params)658*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_qu8_avgpool_minmax_multipass_ukernel_function avgpool_minmax, xnn_init_qu8_avgpool_minmax_params_fn init_params) const {
659*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
660*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
661*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
662*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
663*4bdc9457SAndroid Build Coastguard Worker 
664*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint8_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
665*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
666*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
667*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
668*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output((output_pixels() - 1) * output_stride() + channels());
669*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(output_pixels() * channels());
670*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_real(output_pixels() * channels());
671*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulator(output_pixels() * channels());
672*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t, AlignedAllocator<int32_t, 64>> buffer(XNN_EXTRA_BYTES / sizeof(uint8_t) + channels());
673*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
674*4bdc9457SAndroid Build Coastguard Worker       do {
675*4bdc9457SAndroid Build Coastguard Worker         std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
676*4bdc9457SAndroid Build Coastguard Worker       } while (input.size() > 1 && *std::max_element(input.cbegin(), input.cend()) == *std::min_element(input.cbegin(), input.cend()));
677*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), UINT8_C(0xA5));
678*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(uint8_t), input.end(), UINT8_C(0xA5));
679*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
680*4bdc9457SAndroid Build Coastguard Worker 
681*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
682*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
683*4bdc9457SAndroid Build Coastguard Worker       }
684*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
685*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
686*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
687*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
688*4bdc9457SAndroid Build Coastguard Worker       }
689*4bdc9457SAndroid Build Coastguard Worker 
690*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
691*4bdc9457SAndroid Build Coastguard Worker       xnn_qu8_avgpool_minmax_params params;
692*4bdc9457SAndroid Build Coastguard Worker       init_params(
693*4bdc9457SAndroid Build Coastguard Worker         &params,
694*4bdc9457SAndroid Build Coastguard Worker         -int32_t(input_zero_point()) * int32_t(pooling_elements()),
695*4bdc9457SAndroid Build Coastguard Worker         input_scale() / (output_scale() * float(pooling_elements())),
696*4bdc9457SAndroid Build Coastguard Worker         output_zero_point(), qmin(), qmax());
697*4bdc9457SAndroid Build Coastguard Worker 
698*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
699*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
700*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
701*4bdc9457SAndroid Build Coastguard Worker           int32_t acc = 0;
702*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
703*4bdc9457SAndroid Build Coastguard Worker             const uint8_t* row = indirect_input[x * step() + p];
704*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
705*4bdc9457SAndroid Build Coastguard Worker               acc += int32_t(row[c + input_offset()]);
706*4bdc9457SAndroid Build Coastguard Worker             }
707*4bdc9457SAndroid Build Coastguard Worker             acc -= int32_t(input_zero_point());
708*4bdc9457SAndroid Build Coastguard Worker           }
709*4bdc9457SAndroid Build Coastguard Worker           accumulator[x * channels() + c] = acc;
710*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = xnn_qu8_requantize_rndna(
711*4bdc9457SAndroid Build Coastguard Worker             acc, input_scale() / (output_scale() * float(pooling_elements())), output_zero_point(), qmin(), qmax());
712*4bdc9457SAndroid Build Coastguard Worker           const float scaled_acc =
713*4bdc9457SAndroid Build Coastguard Worker             float(acc) * input_scale() / (output_scale() * float(pooling_elements())) + float(output_zero_point());
714*4bdc9457SAndroid Build Coastguard Worker           output_real[x * channels() + c] = std::min(std::max(scaled_acc, float(qmin())), float(qmax()));
715*4bdc9457SAndroid Build Coastguard Worker         }
716*4bdc9457SAndroid Build Coastguard Worker       }
717*4bdc9457SAndroid Build Coastguard Worker 
718*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
719*4bdc9457SAndroid Build Coastguard Worker       avgpool_minmax(output_pixels(), pooling_elements(), channels(),
720*4bdc9457SAndroid Build Coastguard Worker         indirect_input.data(), input_offset() * sizeof(uint8_t), zero.data(),
721*4bdc9457SAndroid Build Coastguard Worker         buffer.data(), output.data(),
722*4bdc9457SAndroid Build Coastguard Worker         (step() - (packed_pooling_elements() - incremental_pooling_tile())) * sizeof(void*),
723*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint8_t),
724*4bdc9457SAndroid Build Coastguard Worker         &params);
725*4bdc9457SAndroid Build Coastguard Worker 
726*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
727*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
728*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
729*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(uint32_t(output[x * output_stride() + c]), uint32_t(qmin()))
730*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
731*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
732*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
733*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(uint32_t(output[x * output_stride() + c]), uint32_t(qmax()))
734*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
735*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
736*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
737*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(float(int32_t(output[x * output_stride() + c])), output_real[x * channels() + c], 0.5f)
738*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
739*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
740*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset() << ", accumulator = " << accumulator[x * channels() + c];
741*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(uint32_t(output_ref[x * channels() + c]), uint32_t(output[x * output_stride() + c]))
742*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
743*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
744*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset() << ", accumulator = " << accumulator[x * channels() + c];
745*4bdc9457SAndroid Build Coastguard Worker         }
746*4bdc9457SAndroid Build Coastguard Worker       }
747*4bdc9457SAndroid Build Coastguard Worker     }
748*4bdc9457SAndroid Build Coastguard Worker   }
749*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_pavgpool_minmax_unipass_ukernel_function pavgpool_minmax,xnn_init_f16_minmax_params_fn init_params)750*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_pavgpool_minmax_unipass_ukernel_function pavgpool_minmax, xnn_init_f16_minmax_params_fn init_params) const {
751*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
752*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
753*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
754*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> m32dist(0.1f, 0.5f);
755*4bdc9457SAndroid Build Coastguard Worker 
756*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint16_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
757*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
758*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
759*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
760*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> multiplier(output_pixels());
761*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output((output_pixels() - 1) * output_stride() + channels());
762*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
763*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
764*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
765*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), UINT16_C(0x7E00) /* NaN */);
766*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(uint16_t), input.end(), UINT16_C(0x7E00) /* NaN */);
767*4bdc9457SAndroid Build Coastguard Worker       std::generate(multiplier.begin(), multiplier.end(), [&]() { return fp16_ieee_from_fp32_value(m32dist(rng)); });
768*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
769*4bdc9457SAndroid Build Coastguard Worker 
770*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
771*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
772*4bdc9457SAndroid Build Coastguard Worker       }
773*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
774*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
775*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
776*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
777*4bdc9457SAndroid Build Coastguard Worker       }
778*4bdc9457SAndroid Build Coastguard Worker 
779*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
780*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
781*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
782*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
783*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
784*4bdc9457SAndroid Build Coastguard Worker             const uint16_t* row = indirect_input[x * step() + p];
785*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
786*4bdc9457SAndroid Build Coastguard Worker               acc += fp16_ieee_to_fp32_value(row[c + input_offset()]);
787*4bdc9457SAndroid Build Coastguard Worker             }
788*4bdc9457SAndroid Build Coastguard Worker           }
789*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc * fp16_ieee_to_fp32_value(multiplier[x]);
790*4bdc9457SAndroid Build Coastguard Worker         }
791*4bdc9457SAndroid Build Coastguard Worker       }
792*4bdc9457SAndroid Build Coastguard Worker 
793*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
794*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
795*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
796*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
797*4bdc9457SAndroid Build Coastguard Worker       float output_min_as_float = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
798*4bdc9457SAndroid Build Coastguard Worker       float output_max_as_float = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
799*4bdc9457SAndroid Build Coastguard Worker       const uint16_t output_min_as_half = fp16_ieee_from_fp32_value(output_min_as_float);
800*4bdc9457SAndroid Build Coastguard Worker       const uint16_t output_max_as_half = fp16_ieee_from_fp32_value(output_max_as_float);
801*4bdc9457SAndroid Build Coastguard Worker       output_min_as_float = fp16_ieee_to_fp32_value(output_min_as_half);
802*4bdc9457SAndroid Build Coastguard Worker       output_max_as_float = fp16_ieee_to_fp32_value(output_max_as_half);
803*4bdc9457SAndroid Build Coastguard Worker 
804*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
805*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
806*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max(std::min(output_value, output_max_as_float), output_min_as_float);
807*4bdc9457SAndroid Build Coastguard Worker       }
808*4bdc9457SAndroid Build Coastguard Worker 
809*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
810*4bdc9457SAndroid Build Coastguard Worker       xnn_f16_minmax_params params;
811*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, output_min_as_half, output_max_as_half);
812*4bdc9457SAndroid Build Coastguard Worker 
813*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
814*4bdc9457SAndroid Build Coastguard Worker       pavgpool_minmax(output_pixels(), pooling_elements(), channels(),
815*4bdc9457SAndroid Build Coastguard Worker         reinterpret_cast<const void**>(indirect_input.data()), input_offset() * sizeof(uint16_t), zero.data(),
816*4bdc9457SAndroid Build Coastguard Worker         multiplier.data(), output.data(),
817*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
818*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint16_t),
819*4bdc9457SAndroid Build Coastguard Worker         &params);
820*4bdc9457SAndroid Build Coastguard Worker 
821*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
822*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
823*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
824*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_min_as_float)
825*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
826*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
827*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
828*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_max_as_float)
829*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
830*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
831*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
832*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
833*4bdc9457SAndroid Build Coastguard Worker               fp16_ieee_to_fp32_value(output[x * output_stride() + c]),
834*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
835*4bdc9457SAndroid Build Coastguard Worker               std::max(1.0e-4f, std::abs(output_ref[x * channels() + c]) * 3.0e-3f))
836*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
837*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
838*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
839*4bdc9457SAndroid Build Coastguard Worker         }
840*4bdc9457SAndroid Build Coastguard Worker       }
841*4bdc9457SAndroid Build Coastguard Worker     }
842*4bdc9457SAndroid Build Coastguard Worker   }
843*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_pavgpool_minmax_multipass_ukernel_function pavgpool_minmax,xnn_init_f16_minmax_params_fn init_params)844*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_pavgpool_minmax_multipass_ukernel_function pavgpool_minmax, xnn_init_f16_minmax_params_fn init_params) const {
845*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
846*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
847*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
848*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> m32dist(0.1f, 0.5f);
849*4bdc9457SAndroid Build Coastguard Worker 
850*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint16_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
851*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
852*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
853*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
854*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> multiplier(output_pixels());
855*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output((output_pixels() - 1) * output_stride() + channels());
856*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
857*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> buffer(XNN_EXTRA_BYTES / sizeof(uint16_t) + channels());
858*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
859*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
860*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), UINT16_C(0x7E00) /* NaN */);
861*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(uint16_t), input.end(), UINT16_C(0x7E00) /* NaN */);
862*4bdc9457SAndroid Build Coastguard Worker       std::generate(multiplier.begin(), multiplier.end(), [&]() { return fp16_ieee_from_fp32_value(m32dist(rng)); });
863*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
864*4bdc9457SAndroid Build Coastguard Worker 
865*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
866*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
867*4bdc9457SAndroid Build Coastguard Worker       }
868*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
869*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
870*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
871*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
872*4bdc9457SAndroid Build Coastguard Worker       }
873*4bdc9457SAndroid Build Coastguard Worker 
874*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
875*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
876*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
877*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
878*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
879*4bdc9457SAndroid Build Coastguard Worker             const uint16_t* row = indirect_input[x * step() + p];
880*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
881*4bdc9457SAndroid Build Coastguard Worker               acc += fp16_ieee_to_fp32_value(row[c + input_offset()]);
882*4bdc9457SAndroid Build Coastguard Worker             }
883*4bdc9457SAndroid Build Coastguard Worker           }
884*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc * fp16_ieee_to_fp32_value(multiplier[x]);
885*4bdc9457SAndroid Build Coastguard Worker         }
886*4bdc9457SAndroid Build Coastguard Worker       }
887*4bdc9457SAndroid Build Coastguard Worker 
888*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
889*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
890*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
891*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
892*4bdc9457SAndroid Build Coastguard Worker       float output_min_as_float = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
893*4bdc9457SAndroid Build Coastguard Worker       float output_max_as_float = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
894*4bdc9457SAndroid Build Coastguard Worker       const uint16_t output_min_as_half = fp16_ieee_from_fp32_value(output_min_as_float);
895*4bdc9457SAndroid Build Coastguard Worker       const uint16_t output_max_as_half = fp16_ieee_from_fp32_value(output_max_as_float);
896*4bdc9457SAndroid Build Coastguard Worker       output_min_as_float = fp16_ieee_to_fp32_value(output_min_as_half);
897*4bdc9457SAndroid Build Coastguard Worker       output_max_as_float = fp16_ieee_to_fp32_value(output_max_as_half);
898*4bdc9457SAndroid Build Coastguard Worker 
899*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
900*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
901*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max(std::min(output_value, output_max_as_float), output_min_as_float);
902*4bdc9457SAndroid Build Coastguard Worker       }
903*4bdc9457SAndroid Build Coastguard Worker 
904*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
905*4bdc9457SAndroid Build Coastguard Worker       xnn_f16_minmax_params params;
906*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, output_min_as_half, output_max_as_half);
907*4bdc9457SAndroid Build Coastguard Worker 
908*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
909*4bdc9457SAndroid Build Coastguard Worker       pavgpool_minmax(output_pixels(), pooling_elements(), channels(),
910*4bdc9457SAndroid Build Coastguard Worker         reinterpret_cast<const void**>(indirect_input.data()), input_offset() * sizeof(uint16_t), zero.data(),
911*4bdc9457SAndroid Build Coastguard Worker         multiplier.data(), buffer.data(), output.data(),
912*4bdc9457SAndroid Build Coastguard Worker         (step() - (packed_pooling_elements() - incremental_pooling_tile())) * sizeof(void*),
913*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint16_t),
914*4bdc9457SAndroid Build Coastguard Worker         &params);
915*4bdc9457SAndroid Build Coastguard Worker 
916*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
917*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
918*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
919*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_min_as_float)
920*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
921*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
922*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
923*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_max_as_float)
924*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
925*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
926*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
927*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
928*4bdc9457SAndroid Build Coastguard Worker               fp16_ieee_to_fp32_value(output[x * output_stride() + c]),
929*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
930*4bdc9457SAndroid Build Coastguard Worker               std::max(1.0e-4f, std::abs(output_ref[x * channels() + c]) * 3.0e-3f))
931*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
932*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
933*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
934*4bdc9457SAndroid Build Coastguard Worker         }
935*4bdc9457SAndroid Build Coastguard Worker       }
936*4bdc9457SAndroid Build Coastguard Worker     }
937*4bdc9457SAndroid Build Coastguard Worker   }
938*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_pavgpool_minmax_unipass_ukernel_function pavgpool_minmax,xnn_init_f32_minmax_params_fn init_params)939*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_pavgpool_minmax_unipass_ukernel_function pavgpool_minmax, xnn_init_f32_minmax_params_fn init_params) const {
940*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
941*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
942*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
943*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> m32dist(0.1f, 0.5f);
944*4bdc9457SAndroid Build Coastguard Worker 
945*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
946*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
947*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
948*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float));
949*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> multiplier(output_pixels());
950*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((output_pixels() - 1) * output_stride() + channels());
951*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
952*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
953*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
954*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), std::nanf(""));
955*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(float), input.end(), std::nanf(""));
956*4bdc9457SAndroid Build Coastguard Worker       std::generate(multiplier.begin(), multiplier.end(), [&]() { return m32dist(rng); });
957*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
958*4bdc9457SAndroid Build Coastguard Worker 
959*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
960*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
961*4bdc9457SAndroid Build Coastguard Worker       }
962*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
963*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
964*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
965*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
966*4bdc9457SAndroid Build Coastguard Worker       }
967*4bdc9457SAndroid Build Coastguard Worker 
968*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
969*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
970*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
971*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
972*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
973*4bdc9457SAndroid Build Coastguard Worker             const float* row = indirect_input[x * step() + p];
974*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
975*4bdc9457SAndroid Build Coastguard Worker               acc += row[c + input_offset()];
976*4bdc9457SAndroid Build Coastguard Worker             }
977*4bdc9457SAndroid Build Coastguard Worker           }
978*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc * multiplier[x];
979*4bdc9457SAndroid Build Coastguard Worker         }
980*4bdc9457SAndroid Build Coastguard Worker       }
981*4bdc9457SAndroid Build Coastguard Worker 
982*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
983*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
984*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
985*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
986*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
987*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
988*4bdc9457SAndroid Build Coastguard Worker 
989*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
990*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
991*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max(std::min(output_value, output_max), output_min);
992*4bdc9457SAndroid Build Coastguard Worker       }
993*4bdc9457SAndroid Build Coastguard Worker 
994*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
995*4bdc9457SAndroid Build Coastguard Worker       xnn_f32_minmax_params params;
996*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, output_min, output_max);
997*4bdc9457SAndroid Build Coastguard Worker 
998*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
999*4bdc9457SAndroid Build Coastguard Worker       pavgpool_minmax(output_pixels(), pooling_elements(), channels(),
1000*4bdc9457SAndroid Build Coastguard Worker         indirect_input.data(), input_offset() * sizeof(float), zero.data(),
1001*4bdc9457SAndroid Build Coastguard Worker         multiplier.data(), output.data(),
1002*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
1003*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float),
1004*4bdc9457SAndroid Build Coastguard Worker         &params);
1005*4bdc9457SAndroid Build Coastguard Worker 
1006*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
1007*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
1008*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
1009*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(output[x * output_stride() + c], output_min)
1010*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
1011*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
1012*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
1013*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(output[x * output_stride() + c], output_max)
1014*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
1015*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
1016*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
1017*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
1018*4bdc9457SAndroid Build Coastguard Worker               output[x * output_stride() + c],
1019*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
1020*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[x * channels() + c]) * 1.0e-6f)
1021*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
1022*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
1023*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
1024*4bdc9457SAndroid Build Coastguard Worker         }
1025*4bdc9457SAndroid Build Coastguard Worker       }
1026*4bdc9457SAndroid Build Coastguard Worker     }
1027*4bdc9457SAndroid Build Coastguard Worker   }
1028*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_pavgpool_minmax_multipass_ukernel_function pavgpool_minmax,xnn_init_f32_minmax_params_fn init_params)1029*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_pavgpool_minmax_multipass_ukernel_function pavgpool_minmax, xnn_init_f32_minmax_params_fn init_params) const {
1030*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
1031*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
1032*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
1033*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> m32dist(0.1f, 0.5f);
1034*4bdc9457SAndroid Build Coastguard Worker 
1035*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
1036*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
1037*4bdc9457SAndroid Build Coastguard Worker       input_offset() + indirect_input.size() * channels());
1038*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float));
1039*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> multiplier(output_pixels());
1040*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((output_pixels() - 1) * output_stride() + channels());
1041*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
1042*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> buffer(XNN_EXTRA_BYTES / sizeof(float) + channels());
1043*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1044*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
1045*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.begin(), input.begin() + input_offset(), std::nanf(""));
1046*4bdc9457SAndroid Build Coastguard Worker       std::fill(input.end() - XNN_EXTRA_BYTES / sizeof(float), input.end(), std::nanf(""));
1047*4bdc9457SAndroid Build Coastguard Worker       std::generate(multiplier.begin(), multiplier.end(), [&]() { return m32dist(rng); });
1048*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
1049*4bdc9457SAndroid Build Coastguard Worker 
1050*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
1051*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels();
1052*4bdc9457SAndroid Build Coastguard Worker       }
1053*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
1054*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
1055*4bdc9457SAndroid Build Coastguard Worker       if (zero_index() != SIZE_MAX) {
1056*4bdc9457SAndroid Build Coastguard Worker         indirect_input[zero_index()] = zero.data();
1057*4bdc9457SAndroid Build Coastguard Worker       }
1058*4bdc9457SAndroid Build Coastguard Worker 
1059*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
1060*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
1061*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
1062*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
1063*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
1064*4bdc9457SAndroid Build Coastguard Worker             const float* row = indirect_input[x * step() + p];
1065*4bdc9457SAndroid Build Coastguard Worker             if (row != zero.data()) {
1066*4bdc9457SAndroid Build Coastguard Worker               acc += row[c + input_offset()];
1067*4bdc9457SAndroid Build Coastguard Worker             }
1068*4bdc9457SAndroid Build Coastguard Worker           }
1069*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = acc * multiplier[x];
1070*4bdc9457SAndroid Build Coastguard Worker         }
1071*4bdc9457SAndroid Build Coastguard Worker       }
1072*4bdc9457SAndroid Build Coastguard Worker 
1073*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
1074*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
1075*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
1076*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
1077*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
1078*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
1079*4bdc9457SAndroid Build Coastguard Worker 
1080*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
1081*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
1082*4bdc9457SAndroid Build Coastguard Worker         output_value = std::max(std::min(output_value, output_max), output_min);
1083*4bdc9457SAndroid Build Coastguard Worker       }
1084*4bdc9457SAndroid Build Coastguard Worker 
1085*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
1086*4bdc9457SAndroid Build Coastguard Worker       xnn_f32_minmax_params params;
1087*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, output_min, output_max);
1088*4bdc9457SAndroid Build Coastguard Worker 
1089*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
1090*4bdc9457SAndroid Build Coastguard Worker       pavgpool_minmax(output_pixels(), pooling_elements(), channels(),
1091*4bdc9457SAndroid Build Coastguard Worker         indirect_input.data(), input_offset() * sizeof(float), zero.data(),
1092*4bdc9457SAndroid Build Coastguard Worker         multiplier.data(), buffer.data(), output.data(),
1093*4bdc9457SAndroid Build Coastguard Worker         (step() - (packed_pooling_elements() - incremental_pooling_tile())) * sizeof(void*),
1094*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float),
1095*4bdc9457SAndroid Build Coastguard Worker         &params);
1096*4bdc9457SAndroid Build Coastguard Worker 
1097*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
1098*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
1099*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
1100*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(output[x * output_stride() + c], output_min)
1101*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
1102*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
1103*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
1104*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(output[x * output_stride() + c], output_max)
1105*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
1106*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
1107*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
1108*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
1109*4bdc9457SAndroid Build Coastguard Worker               output[x * output_stride() + c],
1110*4bdc9457SAndroid Build Coastguard Worker               output_ref[x * channels() + c],
1111*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[x * channels() + c]) * 1.0e-6f)
1112*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
1113*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
1114*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
1115*4bdc9457SAndroid Build Coastguard Worker         }
1116*4bdc9457SAndroid Build Coastguard Worker       }
1117*4bdc9457SAndroid Build Coastguard Worker     }
1118*4bdc9457SAndroid Build Coastguard Worker   }
1119*4bdc9457SAndroid Build Coastguard Worker 
1120*4bdc9457SAndroid Build Coastguard Worker  private:
1121*4bdc9457SAndroid Build Coastguard Worker   size_t output_pixels_{1};
1122*4bdc9457SAndroid Build Coastguard Worker   size_t pooling_elements_{1};
1123*4bdc9457SAndroid Build Coastguard Worker   size_t channels_{1};
1124*4bdc9457SAndroid Build Coastguard Worker   size_t input_offset_{0};
1125*4bdc9457SAndroid Build Coastguard Worker   size_t zero_index_{SIZE_MAX};
1126*4bdc9457SAndroid Build Coastguard Worker   size_t step_{1};
1127*4bdc9457SAndroid Build Coastguard Worker   size_t primary_pooling_tile_{1};
1128*4bdc9457SAndroid Build Coastguard Worker   size_t incremental_pooling_tile_{1};
1129*4bdc9457SAndroid Build Coastguard Worker   size_t output_stride_{0};
1130*4bdc9457SAndroid Build Coastguard Worker   float input_scale_{1.25f};
1131*4bdc9457SAndroid Build Coastguard Worker   float output_scale_{0.75f};
1132*4bdc9457SAndroid Build Coastguard Worker   uint8_t input_zero_point_{121};
1133*4bdc9457SAndroid Build Coastguard Worker   uint8_t output_zero_point_{133};
1134*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmin_{0};
1135*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmax_{255};
1136*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{3};
1137*4bdc9457SAndroid Build Coastguard Worker };
1138