1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 14*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 15*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 18*4bdc9457SAndroid Build Coastguard Worker #include <random> 19*4bdc9457SAndroid Build Coastguard Worker #include <vector> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h> 25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h> 26*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/pack.h> 27*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 28*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h> 29*4bdc9457SAndroid Build Coastguard Worker 30*4bdc9457SAndroid Build Coastguard Worker 31*4bdc9457SAndroid Build Coastguard Worker class DWConv2DMicrokernelTester { 32*4bdc9457SAndroid Build Coastguard Worker public: 33*4bdc9457SAndroid Build Coastguard Worker enum class Variant { 34*4bdc9457SAndroid Build Coastguard Worker Native, 35*4bdc9457SAndroid Build Coastguard Worker Scalar, 36*4bdc9457SAndroid Build Coastguard Worker }; 37*4bdc9457SAndroid Build Coastguard Worker padding_left(uint32_t padding_left)38*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& padding_left(uint32_t padding_left) { 39*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_left; 40*4bdc9457SAndroid Build Coastguard Worker return *this; 41*4bdc9457SAndroid Build Coastguard Worker } 42*4bdc9457SAndroid Build Coastguard Worker padding_left()43*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_left() const { 44*4bdc9457SAndroid Build Coastguard Worker return this->padding_left_; 45*4bdc9457SAndroid Build Coastguard Worker } 46*4bdc9457SAndroid Build Coastguard Worker padding_right(uint32_t padding_right)47*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& padding_right(uint32_t padding_right) { 48*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_right; 49*4bdc9457SAndroid Build Coastguard Worker return *this; 50*4bdc9457SAndroid Build Coastguard Worker } 51*4bdc9457SAndroid Build Coastguard Worker padding_right()52*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_right() const { 53*4bdc9457SAndroid Build Coastguard Worker return this->padding_right_; 54*4bdc9457SAndroid Build Coastguard Worker } 55*4bdc9457SAndroid Build Coastguard Worker padding_top(uint32_t padding_top)56*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& padding_top(uint32_t padding_top) { 57*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_top; 58*4bdc9457SAndroid Build Coastguard Worker return *this; 59*4bdc9457SAndroid Build Coastguard Worker } 60*4bdc9457SAndroid Build Coastguard Worker padding_top()61*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_top() const { 62*4bdc9457SAndroid Build Coastguard Worker return this->padding_top_; 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker 65*4bdc9457SAndroid Build Coastguard Worker padding_bottom(uint32_t padding_bottom)66*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& padding_bottom(uint32_t padding_bottom) { 67*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_bottom; 68*4bdc9457SAndroid Build Coastguard Worker return *this; 69*4bdc9457SAndroid Build Coastguard Worker } padding_bottom()70*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_bottom() const { 71*4bdc9457SAndroid Build Coastguard Worker return this->padding_bottom_; 72*4bdc9457SAndroid Build Coastguard Worker } 73*4bdc9457SAndroid Build Coastguard Worker input_height(uint32_t input_height)74*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& input_height(uint32_t input_height) { 75*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 76*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 77*4bdc9457SAndroid Build Coastguard Worker return *this; 78*4bdc9457SAndroid Build Coastguard Worker } 79*4bdc9457SAndroid Build Coastguard Worker input_height()80*4bdc9457SAndroid Build Coastguard Worker inline uint32_t input_height() const { 81*4bdc9457SAndroid Build Coastguard Worker return this->input_height_; 82*4bdc9457SAndroid Build Coastguard Worker } 83*4bdc9457SAndroid Build Coastguard Worker input_width(uint32_t input_width)84*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& input_width(uint32_t input_width) { 85*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 86*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 87*4bdc9457SAndroid Build Coastguard Worker return *this; 88*4bdc9457SAndroid Build Coastguard Worker } 89*4bdc9457SAndroid Build Coastguard Worker input_width()90*4bdc9457SAndroid Build Coastguard Worker inline uint32_t input_width() const { 91*4bdc9457SAndroid Build Coastguard Worker return this->input_width_; 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker subsampling(uint32_t subsampling)94*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& subsampling(uint32_t subsampling) { 95*4bdc9457SAndroid Build Coastguard Worker assert(subsampling >= 1); 96*4bdc9457SAndroid Build Coastguard Worker this->subsampling_ = subsampling; 97*4bdc9457SAndroid Build Coastguard Worker return *this; 98*4bdc9457SAndroid Build Coastguard Worker } 99*4bdc9457SAndroid Build Coastguard Worker subsampling()100*4bdc9457SAndroid Build Coastguard Worker inline uint32_t subsampling() const { 101*4bdc9457SAndroid Build Coastguard Worker return this->subsampling_; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker kernel_height(uint32_t kernel_height)104*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& kernel_height(uint32_t kernel_height) { 105*4bdc9457SAndroid Build Coastguard Worker assert(kernel_height != 0); 106*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_height; 107*4bdc9457SAndroid Build Coastguard Worker return *this; 108*4bdc9457SAndroid Build Coastguard Worker } 109*4bdc9457SAndroid Build Coastguard Worker kernel_height()110*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_height() const { 111*4bdc9457SAndroid Build Coastguard Worker return this->kernel_height_; 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker kernel_width(uint32_t kernel_width)114*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& kernel_width(uint32_t kernel_width) { 115*4bdc9457SAndroid Build Coastguard Worker assert(kernel_width != 0); 116*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_width; 117*4bdc9457SAndroid Build Coastguard Worker return *this; 118*4bdc9457SAndroid Build Coastguard Worker } 119*4bdc9457SAndroid Build Coastguard Worker kernel_width()120*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_width() const { 121*4bdc9457SAndroid Build Coastguard Worker return this->kernel_width_; 122*4bdc9457SAndroid Build Coastguard Worker } 123*4bdc9457SAndroid Build Coastguard Worker kernel_size()124*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_size() const { 125*4bdc9457SAndroid Build Coastguard Worker return kernel_height() * kernel_width(); 126*4bdc9457SAndroid Build Coastguard Worker } 127*4bdc9457SAndroid Build Coastguard Worker output_height()128*4bdc9457SAndroid Build Coastguard Worker inline uint32_t output_height() const { 129*4bdc9457SAndroid Build Coastguard Worker const uint32_t padded_input_height = padding_top() + input_height() + padding_bottom(); 130*4bdc9457SAndroid Build Coastguard Worker if (padded_input_height <= kernel_height()) { 131*4bdc9457SAndroid Build Coastguard Worker return 1; 132*4bdc9457SAndroid Build Coastguard Worker } else { 133*4bdc9457SAndroid Build Coastguard Worker return (padded_input_height - kernel_height()) / subsampling() + 1; 134*4bdc9457SAndroid Build Coastguard Worker } 135*4bdc9457SAndroid Build Coastguard Worker } 136*4bdc9457SAndroid Build Coastguard Worker output_width()137*4bdc9457SAndroid Build Coastguard Worker inline uint32_t output_width() const { 138*4bdc9457SAndroid Build Coastguard Worker const uint32_t padded_input_width = padding_left() + input_width() + padding_right(); 139*4bdc9457SAndroid Build Coastguard Worker if (padded_input_width <= kernel_width()) { 140*4bdc9457SAndroid Build Coastguard Worker return 1; 141*4bdc9457SAndroid Build Coastguard Worker } else { 142*4bdc9457SAndroid Build Coastguard Worker return (padded_input_width - kernel_width()) / subsampling() + 1; 143*4bdc9457SAndroid Build Coastguard Worker } 144*4bdc9457SAndroid Build Coastguard Worker } 145*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)146*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& qmin(uint8_t qmin) { 147*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 148*4bdc9457SAndroid Build Coastguard Worker return *this; 149*4bdc9457SAndroid Build Coastguard Worker } 150*4bdc9457SAndroid Build Coastguard Worker qmin()151*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 152*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 153*4bdc9457SAndroid Build Coastguard Worker } 154*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)155*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& qmax(uint8_t qmax) { 156*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 157*4bdc9457SAndroid Build Coastguard Worker return *this; 158*4bdc9457SAndroid Build Coastguard Worker } 159*4bdc9457SAndroid Build Coastguard Worker qmax()160*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 161*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 162*4bdc9457SAndroid Build Coastguard Worker } 163*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)164*4bdc9457SAndroid Build Coastguard Worker inline DWConv2DMicrokernelTester& iterations(size_t iterations) { 165*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 166*4bdc9457SAndroid Build Coastguard Worker return *this; 167*4bdc9457SAndroid Build Coastguard Worker } 168*4bdc9457SAndroid Build Coastguard Worker iterations()169*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 170*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 171*4bdc9457SAndroid Build Coastguard Worker } 172*4bdc9457SAndroid Build Coastguard Worker 173*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_dwconv2d_chw_ukernel_function dwconv, Variant variant = Variant::Native) const { 174*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 175*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 176*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 177*4bdc9457SAndroid Build Coastguard Worker 178*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> input(input_height() * input_width() + 2 * XNN_EXTRA_BYTES); 179*4bdc9457SAndroid Build Coastguard Worker std::vector<float> zero(input_width() + 2 * XNN_EXTRA_BYTES); 180*4bdc9457SAndroid Build Coastguard Worker std::vector<float> packed_weights(kernel_size() + 1); 181*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> output(output_height() * output_width()); 182*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(output_height() * output_width()); 183*4bdc9457SAndroid Build Coastguard Worker 184*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 185*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 186*4bdc9457SAndroid Build Coastguard Worker std::generate(packed_weights.begin(), packed_weights.end(), [&]() { return f32dist(rng); }); 187*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 188*4bdc9457SAndroid Build Coastguard Worker 189*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 190*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 191*4bdc9457SAndroid Build Coastguard Worker float acc = packed_weights[0]; 192*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 193*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling() + ky - padding_top(); 194*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 195*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling() + kx - padding_left(); 196*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 197*4bdc9457SAndroid Build Coastguard Worker const float input_val = input[iy * input_width() + ix]; 198*4bdc9457SAndroid Build Coastguard Worker const float kernel_val = packed_weights[1 + ky * kernel_width() + kx]; 199*4bdc9457SAndroid Build Coastguard Worker acc += input_val * kernel_val; 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker } 202*4bdc9457SAndroid Build Coastguard Worker } 203*4bdc9457SAndroid Build Coastguard Worker output_ref[oy * output_width() + ox] = acc; 204*4bdc9457SAndroid Build Coastguard Worker } 205*4bdc9457SAndroid Build Coastguard Worker } 206*4bdc9457SAndroid Build Coastguard Worker 207*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 208*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 209*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 210*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 211*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin()); 212*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 213*4bdc9457SAndroid Build Coastguard Worker 214*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 215*4bdc9457SAndroid Build Coastguard Worker xnn_f32_chw_params chw_params; 216*4bdc9457SAndroid Build Coastguard Worker switch (variant) { 217*4bdc9457SAndroid Build Coastguard Worker case Variant::Native: 218*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_chw_params(&chw_params, input_width(), output_min, output_max); 219*4bdc9457SAndroid Build Coastguard Worker break; 220*4bdc9457SAndroid Build Coastguard Worker case Variant::Scalar: 221*4bdc9457SAndroid Build Coastguard Worker xnn_init_scalar_f32_chw_params(&chw_params, input_width(), output_min, output_max); 222*4bdc9457SAndroid Build Coastguard Worker break; 223*4bdc9457SAndroid Build Coastguard Worker } 224*4bdc9457SAndroid Build Coastguard Worker 225*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 226*4bdc9457SAndroid Build Coastguard Worker for (float& output_val : output_ref) { 227*4bdc9457SAndroid Build Coastguard Worker output_val = std::max(std::min(output_val, output_max), output_min); 228*4bdc9457SAndroid Build Coastguard Worker } 229*4bdc9457SAndroid Build Coastguard Worker 230*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 231*4bdc9457SAndroid Build Coastguard Worker dwconv( 232*4bdc9457SAndroid Build Coastguard Worker input_height(), input_width() * sizeof(float), 233*4bdc9457SAndroid Build Coastguard Worker input.data(), packed_weights.data(), zero.data(), output.data(), 234*4bdc9457SAndroid Build Coastguard Worker padding_top(), 235*4bdc9457SAndroid Build Coastguard Worker &chw_params); 236*4bdc9457SAndroid Build Coastguard Worker 237*4bdc9457SAndroid Build Coastguard Worker // Verify results. 238*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 239*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 240*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 241*4bdc9457SAndroid Build Coastguard Worker output_ref[y * output_width() + x], 242*4bdc9457SAndroid Build Coastguard Worker output[y * output_width() + x], 243*4bdc9457SAndroid Build Coastguard Worker std::abs(output_ref[y * output_width() + x]) * 1.0e-5) 244*4bdc9457SAndroid Build Coastguard Worker << "x = " << x << ", y = " << y; 245*4bdc9457SAndroid Build Coastguard Worker } 246*4bdc9457SAndroid Build Coastguard Worker } 247*4bdc9457SAndroid Build Coastguard Worker } 248*4bdc9457SAndroid Build Coastguard Worker } 249*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_dwconv2d_chw_ukernel_function dwconv,xnn_init_f16_chw_params_fn init_params)250*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_dwconv2d_chw_ukernel_function dwconv, xnn_init_f16_chw_params_fn init_params) const { 251*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 252*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 253*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 254*4bdc9457SAndroid Build Coastguard Worker 255*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> input(input_height() * input_width() + 2 * XNN_EXTRA_BYTES); 256*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> zero(input_width() + 2 * XNN_EXTRA_BYTES); 257*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> packed_weights(kernel_size() + 1); 258*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> output(output_height() * output_width()); 259*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(output_height() * output_width()); 260*4bdc9457SAndroid Build Coastguard Worker 261*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 262*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 263*4bdc9457SAndroid Build Coastguard Worker std::generate(packed_weights.begin(), packed_weights.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 264*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 265*4bdc9457SAndroid Build Coastguard Worker 266*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 267*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 268*4bdc9457SAndroid Build Coastguard Worker float acc = fp16_ieee_to_fp32_value(packed_weights[0]); 269*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 270*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling() + ky - padding_top(); 271*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 272*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling() + kx - padding_left(); 273*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 274*4bdc9457SAndroid Build Coastguard Worker const float input_val = fp16_ieee_to_fp32_value(input[iy * input_width() + ix]); 275*4bdc9457SAndroid Build Coastguard Worker const float kernel_val = fp16_ieee_to_fp32_value(packed_weights[1 + ky * kernel_width() + kx]); 276*4bdc9457SAndroid Build Coastguard Worker acc += input_val * kernel_val; 277*4bdc9457SAndroid Build Coastguard Worker } 278*4bdc9457SAndroid Build Coastguard Worker } 279*4bdc9457SAndroid Build Coastguard Worker } 280*4bdc9457SAndroid Build Coastguard Worker output_ref[oy * output_width() + ox] = acc; 281*4bdc9457SAndroid Build Coastguard Worker } 282*4bdc9457SAndroid Build Coastguard Worker } 283*4bdc9457SAndroid Build Coastguard Worker 284*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 285*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 286*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 287*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 288*4bdc9457SAndroid Build Coastguard Worker const float output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin()))); 289*4bdc9457SAndroid Build Coastguard Worker const float output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax()))); 290*4bdc9457SAndroid Build Coastguard Worker 291*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 292*4bdc9457SAndroid Build Coastguard Worker xnn_f16_chw_params chw_params; 293*4bdc9457SAndroid Build Coastguard Worker init_params(&chw_params, input_width(), 294*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(output_min), 295*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(output_max)); 296*4bdc9457SAndroid Build Coastguard Worker 297*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 298*4bdc9457SAndroid Build Coastguard Worker for (float& output_val : output_ref) { 299*4bdc9457SAndroid Build Coastguard Worker output_val = std::max(std::min(output_val, output_max), output_min); 300*4bdc9457SAndroid Build Coastguard Worker } 301*4bdc9457SAndroid Build Coastguard Worker 302*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 303*4bdc9457SAndroid Build Coastguard Worker dwconv( 304*4bdc9457SAndroid Build Coastguard Worker input_height(), input_width() * sizeof(uint16_t), 305*4bdc9457SAndroid Build Coastguard Worker input.data(), packed_weights.data(), zero.data(), output.data(), 306*4bdc9457SAndroid Build Coastguard Worker padding_top(), 307*4bdc9457SAndroid Build Coastguard Worker &chw_params); 308*4bdc9457SAndroid Build Coastguard Worker 309*4bdc9457SAndroid Build Coastguard Worker // Verify results. 310*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 311*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 312*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 313*4bdc9457SAndroid Build Coastguard Worker output_ref[y * output_width() + x], 314*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[y * output_width() + x]), 315*4bdc9457SAndroid Build Coastguard Worker std::abs(output_ref[y * output_width() + x]) * 1.0e-2f) 316*4bdc9457SAndroid Build Coastguard Worker << "x = " << x << ", y = " << y; 317*4bdc9457SAndroid Build Coastguard Worker } 318*4bdc9457SAndroid Build Coastguard Worker } 319*4bdc9457SAndroid Build Coastguard Worker } 320*4bdc9457SAndroid Build Coastguard Worker } 321*4bdc9457SAndroid Build Coastguard Worker 322*4bdc9457SAndroid Build Coastguard Worker private: 323*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_left_{0}; 324*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_right_{0}; 325*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_top_{0}; 326*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_bottom_{0}; 327*4bdc9457SAndroid Build Coastguard Worker uint32_t input_height_{1}; 328*4bdc9457SAndroid Build Coastguard Worker uint32_t input_width_{1}; 329*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_{1}; 330*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_height_{1}; 331*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_width_{1}; 332*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 333*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 334*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 335*4bdc9457SAndroid Build Coastguard Worker }; 336