1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 15*4bdc9457SAndroid Build Coastguard Worker #include <limits> 16*4bdc9457SAndroid Build Coastguard Worker #include <random> 17*4bdc9457SAndroid Build Coastguard Worker #include <vector> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h> 21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/pack.h> 22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h> 24*4bdc9457SAndroid Build Coastguard Worker 25*4bdc9457SAndroid Build Coastguard Worker 26*4bdc9457SAndroid Build Coastguard Worker class ConvHWCMicrokernelTester { 27*4bdc9457SAndroid Build Coastguard Worker public: 28*4bdc9457SAndroid Build Coastguard Worker enum class Variant { 29*4bdc9457SAndroid Build Coastguard Worker Native, 30*4bdc9457SAndroid Build Coastguard Worker Scalar, 31*4bdc9457SAndroid Build Coastguard Worker }; 32*4bdc9457SAndroid Build Coastguard Worker output_channels_tile(uint32_t output_channels_tile)33*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& output_channels_tile(uint32_t output_channels_tile) { 34*4bdc9457SAndroid Build Coastguard Worker this->output_channels_tile_ = output_channels_tile; 35*4bdc9457SAndroid Build Coastguard Worker return *this; 36*4bdc9457SAndroid Build Coastguard Worker } 37*4bdc9457SAndroid Build Coastguard Worker output_channels_tile()38*4bdc9457SAndroid Build Coastguard Worker inline uint32_t output_channels_tile() const { 39*4bdc9457SAndroid Build Coastguard Worker return this->output_channels_tile_; 40*4bdc9457SAndroid Build Coastguard Worker } 41*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding)42*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& padding(uint32_t padding) { 43*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding; 44*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding; 45*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding; 46*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding; 47*4bdc9457SAndroid Build Coastguard Worker return *this; 48*4bdc9457SAndroid Build Coastguard Worker } 49*4bdc9457SAndroid Build Coastguard Worker padding_height(uint32_t padding_height)50*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& padding_height(uint32_t padding_height) { 51*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 52*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 53*4bdc9457SAndroid Build Coastguard Worker return *this; 54*4bdc9457SAndroid Build Coastguard Worker } 55*4bdc9457SAndroid Build Coastguard Worker padding_width(uint32_t padding_width)56*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& padding_width(uint32_t padding_width) { 57*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 58*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 59*4bdc9457SAndroid Build Coastguard Worker return *this; 60*4bdc9457SAndroid Build Coastguard Worker } 61*4bdc9457SAndroid Build Coastguard Worker padding_top(uint32_t padding_top)62*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& padding_top(uint32_t padding_top) { 63*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_top; 64*4bdc9457SAndroid Build Coastguard Worker return *this; 65*4bdc9457SAndroid Build Coastguard Worker } 66*4bdc9457SAndroid Build Coastguard Worker padding_top()67*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_top() const { 68*4bdc9457SAndroid Build Coastguard Worker return this->padding_top_; 69*4bdc9457SAndroid Build Coastguard Worker } 70*4bdc9457SAndroid Build Coastguard Worker padding_right(uint32_t padding_right)71*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& padding_right(uint32_t padding_right) { 72*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_right; 73*4bdc9457SAndroid Build Coastguard Worker return *this; 74*4bdc9457SAndroid Build Coastguard Worker } 75*4bdc9457SAndroid Build Coastguard Worker padding_right()76*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_right() const { 77*4bdc9457SAndroid Build Coastguard Worker return this->padding_right_; 78*4bdc9457SAndroid Build Coastguard Worker } 79*4bdc9457SAndroid Build Coastguard Worker padding_bottom(uint32_t padding_bottom)80*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& padding_bottom(uint32_t padding_bottom) { 81*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_bottom; 82*4bdc9457SAndroid Build Coastguard Worker return *this; 83*4bdc9457SAndroid Build Coastguard Worker } 84*4bdc9457SAndroid Build Coastguard Worker padding_bottom()85*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_bottom() const { 86*4bdc9457SAndroid Build Coastguard Worker return this->padding_bottom_; 87*4bdc9457SAndroid Build Coastguard Worker } 88*4bdc9457SAndroid Build Coastguard Worker padding_left(uint32_t padding_left)89*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& padding_left(uint32_t padding_left) { 90*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_left; 91*4bdc9457SAndroid Build Coastguard Worker return *this; 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker padding_left()94*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_left() const { 95*4bdc9457SAndroid Build Coastguard Worker return this->padding_left_; 96*4bdc9457SAndroid Build Coastguard Worker } 97*4bdc9457SAndroid Build Coastguard Worker input_size(uint32_t input_height,uint32_t input_width)98*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& input_size(uint32_t input_height, uint32_t input_width) { 99*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 100*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 101*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 102*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 103*4bdc9457SAndroid Build Coastguard Worker return *this; 104*4bdc9457SAndroid Build Coastguard Worker } 105*4bdc9457SAndroid Build Coastguard Worker input_height(uint32_t input_height)106*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& input_height(uint32_t input_height) { 107*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 108*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 109*4bdc9457SAndroid Build Coastguard Worker return *this; 110*4bdc9457SAndroid Build Coastguard Worker } 111*4bdc9457SAndroid Build Coastguard Worker input_height()112*4bdc9457SAndroid Build Coastguard Worker inline uint32_t input_height() const { 113*4bdc9457SAndroid Build Coastguard Worker return this->input_height_; 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker input_width(uint32_t input_width)116*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& input_width(uint32_t input_width) { 117*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 118*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 119*4bdc9457SAndroid Build Coastguard Worker return *this; 120*4bdc9457SAndroid Build Coastguard Worker } 121*4bdc9457SAndroid Build Coastguard Worker input_width()122*4bdc9457SAndroid Build Coastguard Worker inline uint32_t input_width() const { 123*4bdc9457SAndroid Build Coastguard Worker return this->input_width_; 124*4bdc9457SAndroid Build Coastguard Worker } 125*4bdc9457SAndroid Build Coastguard Worker input_channels(size_t input_channels)126*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& input_channels(size_t input_channels) { 127*4bdc9457SAndroid Build Coastguard Worker assert(input_channels >= 1); 128*4bdc9457SAndroid Build Coastguard Worker this->input_channels_ = input_channels; 129*4bdc9457SAndroid Build Coastguard Worker return *this; 130*4bdc9457SAndroid Build Coastguard Worker } 131*4bdc9457SAndroid Build Coastguard Worker input_channels()132*4bdc9457SAndroid Build Coastguard Worker inline size_t input_channels() const { 133*4bdc9457SAndroid Build Coastguard Worker return this->input_channels_; 134*4bdc9457SAndroid Build Coastguard Worker } 135*4bdc9457SAndroid Build Coastguard Worker output_channels(size_t output_channels)136*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& output_channels(size_t output_channels) { 137*4bdc9457SAndroid Build Coastguard Worker assert(output_channels >= 1); 138*4bdc9457SAndroid Build Coastguard Worker this->output_channels_ = output_channels; 139*4bdc9457SAndroid Build Coastguard Worker return *this; 140*4bdc9457SAndroid Build Coastguard Worker } 141*4bdc9457SAndroid Build Coastguard Worker output_channels()142*4bdc9457SAndroid Build Coastguard Worker inline size_t output_channels() const { 143*4bdc9457SAndroid Build Coastguard Worker return this->output_channels_; 144*4bdc9457SAndroid Build Coastguard Worker } 145*4bdc9457SAndroid Build Coastguard Worker packed_output_channels()146*4bdc9457SAndroid Build Coastguard Worker inline size_t packed_output_channels() const { 147*4bdc9457SAndroid Build Coastguard Worker return output_channels() % output_channels_tile() == 0 ? output_channels() : output_channels() / output_channels_tile() * output_channels_tile() + output_channels_tile(); 148*4bdc9457SAndroid Build Coastguard Worker } 149*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)150*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& batch_size(size_t batch_size) { 151*4bdc9457SAndroid Build Coastguard Worker assert(batch_size >= 1); 152*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 153*4bdc9457SAndroid Build Coastguard Worker return *this; 154*4bdc9457SAndroid Build Coastguard Worker } 155*4bdc9457SAndroid Build Coastguard Worker batch_size()156*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 157*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 158*4bdc9457SAndroid Build Coastguard Worker } 159*4bdc9457SAndroid Build Coastguard Worker kernel_size(uint32_t kernel_size)160*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& kernel_size(uint32_t kernel_size) { 161*4bdc9457SAndroid Build Coastguard Worker assert(kernel_size >= 1); 162*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_size; 163*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_size; 164*4bdc9457SAndroid Build Coastguard Worker return *this; 165*4bdc9457SAndroid Build Coastguard Worker } 166*4bdc9457SAndroid Build Coastguard Worker kernel_height(uint32_t kernel_height)167*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& kernel_height(uint32_t kernel_height) { 168*4bdc9457SAndroid Build Coastguard Worker assert(kernel_height >= 1); 169*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_height; 170*4bdc9457SAndroid Build Coastguard Worker return *this; 171*4bdc9457SAndroid Build Coastguard Worker } 172*4bdc9457SAndroid Build Coastguard Worker kernel_height()173*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_height() const { 174*4bdc9457SAndroid Build Coastguard Worker return this->kernel_height_; 175*4bdc9457SAndroid Build Coastguard Worker } 176*4bdc9457SAndroid Build Coastguard Worker kernel_width(uint32_t kernel_width)177*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& kernel_width(uint32_t kernel_width) { 178*4bdc9457SAndroid Build Coastguard Worker assert(kernel_width >= 1); 179*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_width; 180*4bdc9457SAndroid Build Coastguard Worker return *this; 181*4bdc9457SAndroid Build Coastguard Worker } 182*4bdc9457SAndroid Build Coastguard Worker kernel_width()183*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_width() const { 184*4bdc9457SAndroid Build Coastguard Worker return this->kernel_width_; 185*4bdc9457SAndroid Build Coastguard Worker } 186*4bdc9457SAndroid Build Coastguard Worker subsampling(uint32_t subsampling)187*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& subsampling(uint32_t subsampling) { 188*4bdc9457SAndroid Build Coastguard Worker assert(subsampling >= 1); 189*4bdc9457SAndroid Build Coastguard Worker this->subsampling_height_ = subsampling; 190*4bdc9457SAndroid Build Coastguard Worker this->subsampling_width_ = subsampling; 191*4bdc9457SAndroid Build Coastguard Worker return *this; 192*4bdc9457SAndroid Build Coastguard Worker } 193*4bdc9457SAndroid Build Coastguard Worker subsampling_height(uint32_t subsampling_height)194*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& subsampling_height(uint32_t subsampling_height) { 195*4bdc9457SAndroid Build Coastguard Worker assert(subsampling_height >= 1); 196*4bdc9457SAndroid Build Coastguard Worker this->subsampling_height_ = subsampling_height; 197*4bdc9457SAndroid Build Coastguard Worker return *this; 198*4bdc9457SAndroid Build Coastguard Worker } 199*4bdc9457SAndroid Build Coastguard Worker subsampling_height()200*4bdc9457SAndroid Build Coastguard Worker inline uint32_t subsampling_height() const { 201*4bdc9457SAndroid Build Coastguard Worker return this->subsampling_height_; 202*4bdc9457SAndroid Build Coastguard Worker } 203*4bdc9457SAndroid Build Coastguard Worker subsampling_width(uint32_t subsampling_width)204*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& subsampling_width(uint32_t subsampling_width) { 205*4bdc9457SAndroid Build Coastguard Worker assert(subsampling_width >= 1); 206*4bdc9457SAndroid Build Coastguard Worker this->subsampling_width_ = subsampling_width; 207*4bdc9457SAndroid Build Coastguard Worker return *this; 208*4bdc9457SAndroid Build Coastguard Worker } 209*4bdc9457SAndroid Build Coastguard Worker subsampling_width()210*4bdc9457SAndroid Build Coastguard Worker inline uint32_t subsampling_width() const { 211*4bdc9457SAndroid Build Coastguard Worker return this->subsampling_width_; 212*4bdc9457SAndroid Build Coastguard Worker } 213*4bdc9457SAndroid Build Coastguard Worker output_y_start(uint32_t output_y_start)214*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& output_y_start(uint32_t output_y_start) { 215*4bdc9457SAndroid Build Coastguard Worker this->output_y_start_ = output_y_start; 216*4bdc9457SAndroid Build Coastguard Worker return *this; 217*4bdc9457SAndroid Build Coastguard Worker } 218*4bdc9457SAndroid Build Coastguard Worker output_y_start()219*4bdc9457SAndroid Build Coastguard Worker inline uint32_t output_y_start() const { 220*4bdc9457SAndroid Build Coastguard Worker return this->output_y_start_; 221*4bdc9457SAndroid Build Coastguard Worker } 222*4bdc9457SAndroid Build Coastguard Worker output_y_end(uint32_t output_y_end)223*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& output_y_end(uint32_t output_y_end) { 224*4bdc9457SAndroid Build Coastguard Worker this->output_y_end_ = output_y_end; 225*4bdc9457SAndroid Build Coastguard Worker return *this; 226*4bdc9457SAndroid Build Coastguard Worker } 227*4bdc9457SAndroid Build Coastguard Worker output_y_end()228*4bdc9457SAndroid Build Coastguard Worker inline uint32_t output_y_end() const { 229*4bdc9457SAndroid Build Coastguard Worker if (this->output_y_end_ == std::numeric_limits<uint32_t>::max()) { 230*4bdc9457SAndroid Build Coastguard Worker return output_height(); 231*4bdc9457SAndroid Build Coastguard Worker } else { 232*4bdc9457SAndroid Build Coastguard Worker return this->output_y_end_; 233*4bdc9457SAndroid Build Coastguard Worker } 234*4bdc9457SAndroid Build Coastguard Worker } 235*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride()236*4bdc9457SAndroid Build Coastguard Worker inline size_t input_pixel_stride() const { 237*4bdc9457SAndroid Build Coastguard Worker return input_channels(); 238*4bdc9457SAndroid Build Coastguard Worker } 239*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride()240*4bdc9457SAndroid Build Coastguard Worker inline size_t output_pixel_stride() const { 241*4bdc9457SAndroid Build Coastguard Worker return output_channels(); 242*4bdc9457SAndroid Build Coastguard Worker } 243*4bdc9457SAndroid Build Coastguard Worker output_height()244*4bdc9457SAndroid Build Coastguard Worker inline size_t output_height() const { 245*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_height = padding_top() + input_height() + padding_bottom(); 246*4bdc9457SAndroid Build Coastguard Worker return (std::max<size_t>(padded_input_height + subsampling_height(), kernel_height()) - kernel_height()) 247*4bdc9457SAndroid Build Coastguard Worker / subsampling_height(); 248*4bdc9457SAndroid Build Coastguard Worker } 249*4bdc9457SAndroid Build Coastguard Worker output_width()250*4bdc9457SAndroid Build Coastguard Worker inline size_t output_width() const { 251*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_width = padding_left() + input_width() + padding_right(); 252*4bdc9457SAndroid Build Coastguard Worker return (std::max<size_t>(padded_input_width + subsampling_width(), kernel_width()) - kernel_width()) 253*4bdc9457SAndroid Build Coastguard Worker / subsampling_width(); 254*4bdc9457SAndroid Build Coastguard Worker } 255*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)256*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& qmin(uint8_t qmin) { 257*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 258*4bdc9457SAndroid Build Coastguard Worker return *this; 259*4bdc9457SAndroid Build Coastguard Worker } 260*4bdc9457SAndroid Build Coastguard Worker qmin()261*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 262*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 263*4bdc9457SAndroid Build Coastguard Worker } 264*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)265*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& qmax(uint8_t qmax) { 266*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 267*4bdc9457SAndroid Build Coastguard Worker return *this; 268*4bdc9457SAndroid Build Coastguard Worker } 269*4bdc9457SAndroid Build Coastguard Worker qmax()270*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 271*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 272*4bdc9457SAndroid Build Coastguard Worker } 273*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)274*4bdc9457SAndroid Build Coastguard Worker inline ConvHWCMicrokernelTester& iterations(size_t iterations) { 275*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 276*4bdc9457SAndroid Build Coastguard Worker return *this; 277*4bdc9457SAndroid Build Coastguard Worker } 278*4bdc9457SAndroid Build Coastguard Worker iterations()279*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 280*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 281*4bdc9457SAndroid Build Coastguard Worker } 282*4bdc9457SAndroid Build Coastguard Worker 283*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_conv_hwc_ukernel_function conv, Variant variant = Variant::Native) const { 284*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(output_y_start(), output_height()); 285*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output_y_end(), output_height()); 286*4bdc9457SAndroid Build Coastguard Worker ASSERT_GT(output_y_end(), output_y_start()); 287*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output_width(), 1); 288*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output_height(), 1); 289*4bdc9457SAndroid Build Coastguard Worker 290*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 291*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 292*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f); 293*4bdc9457SAndroid Build Coastguard Worker 294*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 295*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_pixel_stride() + input_channels())); 296*4bdc9457SAndroid Build Coastguard Worker std::vector<float> zero(XNN_EXTRA_BYTES / sizeof(float) + input_width() * input_channels()); 297*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel(output_channels() * kernel_height() * kernel_width() * input_channels()); 298*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(output_channels()); 299*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(batch_size() * ((output_height() * output_width() - 1) * output_pixel_stride() + output_channels())); 300*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * output_channels()); 301*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_weights((input_channels() * kernel_height() * kernel_width() + 1) * packed_output_channels()); 302*4bdc9457SAndroid Build Coastguard Worker 303*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 304*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 305*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); }); 306*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); 307*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 308*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_weights.begin(), packed_weights.end(), 0.0f); 309*4bdc9457SAndroid Build Coastguard Worker 310*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_dconv_oki_w( 311*4bdc9457SAndroid Build Coastguard Worker output_channels(), 312*4bdc9457SAndroid Build Coastguard Worker input_channels(), 313*4bdc9457SAndroid Build Coastguard Worker output_channels_tile(), 314*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 315*4bdc9457SAndroid Build Coastguard Worker kernel.data(), bias.data(), packed_weights.data(), nullptr); 316*4bdc9457SAndroid Build Coastguard Worker 317*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 318*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 319*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 320*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 321*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < output_channels(); oc++) { 322*4bdc9457SAndroid Build Coastguard Worker float acc = bias[oc]; 323*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 324*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky - padding_top(); 325*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 326*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 327*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx - padding_left(); 328*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 329*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < input_channels(); ic++) { 330*4bdc9457SAndroid Build Coastguard Worker acc += 331*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + ic] * 332*4bdc9457SAndroid Build Coastguard Worker kernel[((oc * kernel_height() + ky) * kernel_width() + kx) * input_channels() + ic]; 333*4bdc9457SAndroid Build Coastguard Worker } 334*4bdc9457SAndroid Build Coastguard Worker } 335*4bdc9457SAndroid Build Coastguard Worker } 336*4bdc9457SAndroid Build Coastguard Worker } 337*4bdc9457SAndroid Build Coastguard Worker } 338*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * output_channels() + oc] = acc; 339*4bdc9457SAndroid Build Coastguard Worker } 340*4bdc9457SAndroid Build Coastguard Worker } 341*4bdc9457SAndroid Build Coastguard Worker } 342*4bdc9457SAndroid Build Coastguard Worker } 343*4bdc9457SAndroid Build Coastguard Worker 344*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 345*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 346*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 347*4bdc9457SAndroid Build Coastguard Worker 348*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin()); 349*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax()); 350*4bdc9457SAndroid Build Coastguard Worker 351*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 352*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 353*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 354*4bdc9457SAndroid Build Coastguard Worker } 355*4bdc9457SAndroid Build Coastguard Worker 356*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 357*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params; 358*4bdc9457SAndroid Build Coastguard Worker switch (variant) { 359*4bdc9457SAndroid Build Coastguard Worker case Variant::Native: 360*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_minmax_params(¶ms, output_min, output_max); 361*4bdc9457SAndroid Build Coastguard Worker break; 362*4bdc9457SAndroid Build Coastguard Worker case Variant::Scalar: 363*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_minmax_scalar_params(¶ms, output_min, output_max); 364*4bdc9457SAndroid Build Coastguard Worker break; 365*4bdc9457SAndroid Build Coastguard Worker } 366*4bdc9457SAndroid Build Coastguard Worker 367*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 368*4bdc9457SAndroid Build Coastguard Worker conv( 369*4bdc9457SAndroid Build Coastguard Worker input_height(), input_width(), 370*4bdc9457SAndroid Build Coastguard Worker output_y_start(), output_y_end(), 371*4bdc9457SAndroid Build Coastguard Worker input.data(), zero.data(), packed_weights.data(), output.data(), 372*4bdc9457SAndroid Build Coastguard Worker padding_top(), output_channels(), 373*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride() * output_width() * sizeof(float), 374*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride() * sizeof(float), 375*4bdc9457SAndroid Build Coastguard Worker ¶ms); 376*4bdc9457SAndroid Build Coastguard Worker 377*4bdc9457SAndroid Build Coastguard Worker // Verify results. 378*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 379*4bdc9457SAndroid Build Coastguard Worker for (size_t y = output_y_start(); y < output_y_end(); y++) { 380*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 381*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < output_channels(); c++) { 382*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_min) 383*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), channel = " << c; 384*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_max) 385*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), channel = " << c; 386*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 387*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * output_channels() + c], 388*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], 389*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(output_ref[((i * output_height() + y) * output_width() + x) * output_channels() + c])) 390*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), channel = " << c; 391*4bdc9457SAndroid Build Coastguard Worker } 392*4bdc9457SAndroid Build Coastguard Worker } 393*4bdc9457SAndroid Build Coastguard Worker } 394*4bdc9457SAndroid Build Coastguard Worker } 395*4bdc9457SAndroid Build Coastguard Worker } 396*4bdc9457SAndroid Build Coastguard Worker } 397*4bdc9457SAndroid Build Coastguard Worker 398*4bdc9457SAndroid Build Coastguard Worker private: 399*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_top_{0}; 400*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_right_{0}; 401*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_bottom_{0}; 402*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_left_{0}; 403*4bdc9457SAndroid Build Coastguard Worker size_t input_height_{1}; 404*4bdc9457SAndroid Build Coastguard Worker size_t input_width_{1}; 405*4bdc9457SAndroid Build Coastguard Worker size_t input_channels_{1}; 406*4bdc9457SAndroid Build Coastguard Worker size_t output_channels_{1}; 407*4bdc9457SAndroid Build Coastguard Worker uint32_t output_channels_tile_{1}; 408*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 409*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_height_{1}; 410*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_width_{1}; 411*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_height_{1}; 412*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_width_{1}; 413*4bdc9457SAndroid Build Coastguard Worker uint32_t output_y_start_{0}; 414*4bdc9457SAndroid Build Coastguard Worker uint32_t output_y_end_{std::numeric_limits<uint32_t>::max()}; 415*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 416*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 417*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 418*4bdc9457SAndroid Build Coastguard Worker }; 419