1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 12*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 15*4bdc9457SAndroid Build Coastguard Worker #include <functional> 16*4bdc9457SAndroid Build Coastguard Worker #include <random> 17*4bdc9457SAndroid Build Coastguard Worker #include <vector> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker class ResizeBilinearOperatorTester { 25*4bdc9457SAndroid Build Coastguard Worker public: input_size(size_t input_height,size_t input_width)26*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& input_size(size_t input_height, size_t input_width) { 27*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 28*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 29*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 30*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 31*4bdc9457SAndroid Build Coastguard Worker return *this; 32*4bdc9457SAndroid Build Coastguard Worker } 33*4bdc9457SAndroid Build Coastguard Worker input_height(size_t input_height)34*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& input_height(size_t input_height) { 35*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 36*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 37*4bdc9457SAndroid Build Coastguard Worker return *this; 38*4bdc9457SAndroid Build Coastguard Worker } 39*4bdc9457SAndroid Build Coastguard Worker input_height()40*4bdc9457SAndroid Build Coastguard Worker inline size_t input_height() const { 41*4bdc9457SAndroid Build Coastguard Worker return this->input_height_; 42*4bdc9457SAndroid Build Coastguard Worker } 43*4bdc9457SAndroid Build Coastguard Worker input_width(size_t input_width)44*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& input_width(size_t input_width) { 45*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 46*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 47*4bdc9457SAndroid Build Coastguard Worker return *this; 48*4bdc9457SAndroid Build Coastguard Worker } 49*4bdc9457SAndroid Build Coastguard Worker input_width()50*4bdc9457SAndroid Build Coastguard Worker inline size_t input_width() const { 51*4bdc9457SAndroid Build Coastguard Worker return this->input_width_; 52*4bdc9457SAndroid Build Coastguard Worker } 53*4bdc9457SAndroid Build Coastguard Worker output_size(size_t output_height,size_t output_width)54*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& output_size(size_t output_height, size_t output_width) { 55*4bdc9457SAndroid Build Coastguard Worker assert(output_height >= 1); 56*4bdc9457SAndroid Build Coastguard Worker assert(output_width >= 1); 57*4bdc9457SAndroid Build Coastguard Worker this->output_height_ = output_height; 58*4bdc9457SAndroid Build Coastguard Worker this->output_width_ = output_width; 59*4bdc9457SAndroid Build Coastguard Worker return *this; 60*4bdc9457SAndroid Build Coastguard Worker } 61*4bdc9457SAndroid Build Coastguard Worker output_height(size_t output_height)62*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& output_height(size_t output_height) { 63*4bdc9457SAndroid Build Coastguard Worker assert(output_height >= 1); 64*4bdc9457SAndroid Build Coastguard Worker this->output_height_ = output_height; 65*4bdc9457SAndroid Build Coastguard Worker return *this; 66*4bdc9457SAndroid Build Coastguard Worker } 67*4bdc9457SAndroid Build Coastguard Worker output_height()68*4bdc9457SAndroid Build Coastguard Worker inline size_t output_height() const { 69*4bdc9457SAndroid Build Coastguard Worker return this->output_height_; 70*4bdc9457SAndroid Build Coastguard Worker } 71*4bdc9457SAndroid Build Coastguard Worker output_width(size_t output_width)72*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& output_width(size_t output_width) { 73*4bdc9457SAndroid Build Coastguard Worker assert(output_width >= 1); 74*4bdc9457SAndroid Build Coastguard Worker this->output_width_ = output_width; 75*4bdc9457SAndroid Build Coastguard Worker return *this; 76*4bdc9457SAndroid Build Coastguard Worker } 77*4bdc9457SAndroid Build Coastguard Worker output_width()78*4bdc9457SAndroid Build Coastguard Worker inline size_t output_width() const { 79*4bdc9457SAndroid Build Coastguard Worker return this->output_width_; 80*4bdc9457SAndroid Build Coastguard Worker } 81*4bdc9457SAndroid Build Coastguard Worker height_scale()82*4bdc9457SAndroid Build Coastguard Worker inline float height_scale() const { 83*4bdc9457SAndroid Build Coastguard Worker if (align_corners() && output_height() > 1) { 84*4bdc9457SAndroid Build Coastguard Worker return float(input_height() - 1) / float(output_height() - 1); 85*4bdc9457SAndroid Build Coastguard Worker } else { 86*4bdc9457SAndroid Build Coastguard Worker return float(input_height()) / float(output_height()); 87*4bdc9457SAndroid Build Coastguard Worker } 88*4bdc9457SAndroid Build Coastguard Worker } 89*4bdc9457SAndroid Build Coastguard Worker width_scale()90*4bdc9457SAndroid Build Coastguard Worker inline float width_scale() const { 91*4bdc9457SAndroid Build Coastguard Worker if (align_corners() && output_width() > 1) { 92*4bdc9457SAndroid Build Coastguard Worker return float(input_width() - 1) / float(output_width() - 1); 93*4bdc9457SAndroid Build Coastguard Worker } else { 94*4bdc9457SAndroid Build Coastguard Worker return float(input_width()) / float(output_width()); 95*4bdc9457SAndroid Build Coastguard Worker } 96*4bdc9457SAndroid Build Coastguard Worker } 97*4bdc9457SAndroid Build Coastguard Worker channels(size_t channels)98*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& channels(size_t channels) { 99*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 100*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 101*4bdc9457SAndroid Build Coastguard Worker return *this; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker channels()104*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 105*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 106*4bdc9457SAndroid Build Coastguard Worker } 107*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)108*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& batch_size(size_t batch_size) { 109*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 110*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 111*4bdc9457SAndroid Build Coastguard Worker return *this; 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker batch_size()114*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 115*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 116*4bdc9457SAndroid Build Coastguard Worker } 117*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(size_t input_pixel_stride)118*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& input_pixel_stride(size_t input_pixel_stride) { 119*4bdc9457SAndroid Build Coastguard Worker assert(input_pixel_stride != 0); 120*4bdc9457SAndroid Build Coastguard Worker this->input_pixel_stride_ = input_pixel_stride; 121*4bdc9457SAndroid Build Coastguard Worker return *this; 122*4bdc9457SAndroid Build Coastguard Worker } 123*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride()124*4bdc9457SAndroid Build Coastguard Worker inline size_t input_pixel_stride() const { 125*4bdc9457SAndroid Build Coastguard Worker if (this->input_pixel_stride_ == 0) { 126*4bdc9457SAndroid Build Coastguard Worker return channels(); 127*4bdc9457SAndroid Build Coastguard Worker } else { 128*4bdc9457SAndroid Build Coastguard Worker assert(this->input_pixel_stride_ >= channels()); 129*4bdc9457SAndroid Build Coastguard Worker return this->input_pixel_stride_; 130*4bdc9457SAndroid Build Coastguard Worker } 131*4bdc9457SAndroid Build Coastguard Worker } 132*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride(size_t output_pixel_stride)133*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& output_pixel_stride(size_t output_pixel_stride) { 134*4bdc9457SAndroid Build Coastguard Worker assert(output_pixel_stride != 0); 135*4bdc9457SAndroid Build Coastguard Worker this->output_pixel_stride_ = output_pixel_stride; 136*4bdc9457SAndroid Build Coastguard Worker return *this; 137*4bdc9457SAndroid Build Coastguard Worker } 138*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride()139*4bdc9457SAndroid Build Coastguard Worker inline size_t output_pixel_stride() const { 140*4bdc9457SAndroid Build Coastguard Worker if (this->output_pixel_stride_ == 0) { 141*4bdc9457SAndroid Build Coastguard Worker return channels(); 142*4bdc9457SAndroid Build Coastguard Worker } else { 143*4bdc9457SAndroid Build Coastguard Worker assert(this->output_pixel_stride_ >= channels()); 144*4bdc9457SAndroid Build Coastguard Worker return this->output_pixel_stride_; 145*4bdc9457SAndroid Build Coastguard Worker } 146*4bdc9457SAndroid Build Coastguard Worker } 147*4bdc9457SAndroid Build Coastguard Worker next_input_size(uint32_t next_input_height,uint32_t next_input_width)148*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) { 149*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 150*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 151*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 152*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 153*4bdc9457SAndroid Build Coastguard Worker return *this; 154*4bdc9457SAndroid Build Coastguard Worker } 155*4bdc9457SAndroid Build Coastguard Worker next_input_height(uint32_t next_input_height)156*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& next_input_height(uint32_t next_input_height) { 157*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 158*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 159*4bdc9457SAndroid Build Coastguard Worker return *this; 160*4bdc9457SAndroid Build Coastguard Worker } 161*4bdc9457SAndroid Build Coastguard Worker next_input_height()162*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_height() const { 163*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_height_ == 0) { 164*4bdc9457SAndroid Build Coastguard Worker return input_height(); 165*4bdc9457SAndroid Build Coastguard Worker } else { 166*4bdc9457SAndroid Build Coastguard Worker return this->next_input_height_; 167*4bdc9457SAndroid Build Coastguard Worker } 168*4bdc9457SAndroid Build Coastguard Worker } 169*4bdc9457SAndroid Build Coastguard Worker next_input_width(uint32_t next_input_width)170*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& next_input_width(uint32_t next_input_width) { 171*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 172*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 173*4bdc9457SAndroid Build Coastguard Worker return *this; 174*4bdc9457SAndroid Build Coastguard Worker } 175*4bdc9457SAndroid Build Coastguard Worker next_input_width()176*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_width() const { 177*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_width_ == 0) { 178*4bdc9457SAndroid Build Coastguard Worker return input_width(); 179*4bdc9457SAndroid Build Coastguard Worker } else { 180*4bdc9457SAndroid Build Coastguard Worker return this->next_input_width_; 181*4bdc9457SAndroid Build Coastguard Worker } 182*4bdc9457SAndroid Build Coastguard Worker } 183*4bdc9457SAndroid Build Coastguard Worker next_batch_size(size_t next_batch_size)184*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& next_batch_size(size_t next_batch_size) { 185*4bdc9457SAndroid Build Coastguard Worker assert(next_batch_size >= 1); 186*4bdc9457SAndroid Build Coastguard Worker this->next_batch_size_ = next_batch_size; 187*4bdc9457SAndroid Build Coastguard Worker return *this; 188*4bdc9457SAndroid Build Coastguard Worker } 189*4bdc9457SAndroid Build Coastguard Worker next_batch_size()190*4bdc9457SAndroid Build Coastguard Worker inline size_t next_batch_size() const { 191*4bdc9457SAndroid Build Coastguard Worker if (this->next_batch_size_ == 0) { 192*4bdc9457SAndroid Build Coastguard Worker return batch_size(); 193*4bdc9457SAndroid Build Coastguard Worker } else { 194*4bdc9457SAndroid Build Coastguard Worker return this->next_batch_size_; 195*4bdc9457SAndroid Build Coastguard Worker } 196*4bdc9457SAndroid Build Coastguard Worker } 197*4bdc9457SAndroid Build Coastguard Worker align_corners(bool align_corners)198*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& align_corners(bool align_corners) { 199*4bdc9457SAndroid Build Coastguard Worker this->align_corners_ = align_corners; 200*4bdc9457SAndroid Build Coastguard Worker return *this; 201*4bdc9457SAndroid Build Coastguard Worker } 202*4bdc9457SAndroid Build Coastguard Worker align_corners()203*4bdc9457SAndroid Build Coastguard Worker inline bool align_corners() const { 204*4bdc9457SAndroid Build Coastguard Worker return this->align_corners_; 205*4bdc9457SAndroid Build Coastguard Worker } 206*4bdc9457SAndroid Build Coastguard Worker tf_legacy_mode(bool tf_legacy_mode)207*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& tf_legacy_mode(bool tf_legacy_mode) { 208*4bdc9457SAndroid Build Coastguard Worker this->tf_legacy_mode_ = tf_legacy_mode; 209*4bdc9457SAndroid Build Coastguard Worker return *this; 210*4bdc9457SAndroid Build Coastguard Worker } 211*4bdc9457SAndroid Build Coastguard Worker tf_legacy_mode()212*4bdc9457SAndroid Build Coastguard Worker inline bool tf_legacy_mode() const { 213*4bdc9457SAndroid Build Coastguard Worker return this->tf_legacy_mode_; 214*4bdc9457SAndroid Build Coastguard Worker } 215*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)216*4bdc9457SAndroid Build Coastguard Worker inline ResizeBilinearOperatorTester& iterations(size_t iterations) { 217*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 218*4bdc9457SAndroid Build Coastguard Worker return *this; 219*4bdc9457SAndroid Build Coastguard Worker } 220*4bdc9457SAndroid Build Coastguard Worker iterations()221*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 222*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 223*4bdc9457SAndroid Build Coastguard Worker } 224*4bdc9457SAndroid Build Coastguard Worker TestNHWCxF16()225*4bdc9457SAndroid Build Coastguard Worker void TestNHWCxF16() const { 226*4bdc9457SAndroid Build Coastguard Worker if (align_corners()) { 227*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(tf_legacy_mode()); 228*4bdc9457SAndroid Build Coastguard Worker } 229*4bdc9457SAndroid Build Coastguard Worker 230*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 231*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 232*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 233*4bdc9457SAndroid Build Coastguard Worker 234*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 235*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels()); 236*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 237*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 238*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 239*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 240*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 241*4bdc9457SAndroid Build Coastguard Worker 242*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 243*4bdc9457SAndroid Build Coastguard Worker const float offset = (tf_legacy_mode() || align_corners()) ? 0.0f : 0.5f; 244*4bdc9457SAndroid Build Coastguard Worker for (size_t batch_index = 0; batch_index < batch_size(); batch_index++) { 245*4bdc9457SAndroid Build Coastguard Worker for (size_t output_y = 0; output_y < output_height(); output_y++) { 246*4bdc9457SAndroid Build Coastguard Worker const float input_y = (float(output_y) + offset) * height_scale() - offset; 247*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_top = std::max<int64_t>(int64_t(std::floor(input_y)), 0); 248*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_bottom = std::min<int64_t>(int64_t(std::ceil(input_y)), input_height() - 1); 249*4bdc9457SAndroid Build Coastguard Worker const float y_alpha = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(input_y - std::floor(input_y))); 250*4bdc9457SAndroid Build Coastguard Worker for (size_t output_x = 0; output_x < output_width(); output_x++) { 251*4bdc9457SAndroid Build Coastguard Worker const float input_x = (float(output_x) + offset) * width_scale() - offset; 252*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_left = std::max<int64_t>(int64_t(std::floor(input_x)), 0); 253*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_right = std::min<int64_t>(int64_t(std::ceil(input_x)), input_width() - 1); 254*4bdc9457SAndroid Build Coastguard Worker const float x_alpha = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(input_x - std::floor(input_x))); 255*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 256*4bdc9457SAndroid Build Coastguard Worker output_ref[((batch_index * output_height() + output_y) * output_width() + output_x) * channels() + c] = 257*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((batch_index * input_height() + input_y_top) * input_width() + input_x_left) * input_pixel_stride() + c]) * (1.0f - y_alpha) * (1.0f - x_alpha) + 258*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((batch_index * input_height() + input_y_top) * input_width() + input_x_right) * input_pixel_stride() + c]) * (1.0f - y_alpha) * x_alpha + 259*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((batch_index * input_height() + input_y_bottom) * input_width() + input_x_left) * input_pixel_stride() + c]) * y_alpha * (1.0f - x_alpha) + 260*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((batch_index * input_height() + input_y_bottom) * input_width() + input_x_right) * input_pixel_stride() + c]) * y_alpha * x_alpha; 261*4bdc9457SAndroid Build Coastguard Worker } 262*4bdc9457SAndroid Build Coastguard Worker } 263*4bdc9457SAndroid Build Coastguard Worker } 264*4bdc9457SAndroid Build Coastguard Worker } 265*4bdc9457SAndroid Build Coastguard Worker 266*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Resize Bilinear operator. 267*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 268*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t resize_bilinear_op = nullptr; 269*4bdc9457SAndroid Build Coastguard Worker 270*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_resize_bilinear2d_nhwc_f16( 271*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 272*4bdc9457SAndroid Build Coastguard Worker (align_corners() ? XNN_FLAG_ALIGN_CORNERS : 0) | (tf_legacy_mode() ? XNN_FLAG_TENSORFLOW_LEGACY_MODE : 0), 273*4bdc9457SAndroid Build Coastguard Worker &resize_bilinear_op); 274*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 275*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 276*4bdc9457SAndroid Build Coastguard Worker } 277*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 278*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, resize_bilinear_op); 279*4bdc9457SAndroid Build Coastguard Worker 280*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete resize_bilinear_op. 281*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_resize_bilinear_op(resize_bilinear_op, xnn_delete_operator); 282*4bdc9457SAndroid Build Coastguard Worker 283*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 284*4bdc9457SAndroid Build Coastguard Worker xnn_setup_resize_bilinear2d_nhwc_f16( 285*4bdc9457SAndroid Build Coastguard Worker resize_bilinear_op, 286*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 287*4bdc9457SAndroid Build Coastguard Worker output_height(), output_width(), 288*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 289*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 290*4bdc9457SAndroid Build Coastguard Worker 291*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 292*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(resize_bilinear_op, nullptr /* thread pool */)); 293*4bdc9457SAndroid Build Coastguard Worker 294*4bdc9457SAndroid Build Coastguard Worker // Verify results. 295*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 296*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 297*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 298*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 299*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 300*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), 301*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 302*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-2f)) << 303*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 304*4bdc9457SAndroid Build Coastguard Worker } 305*4bdc9457SAndroid Build Coastguard Worker } 306*4bdc9457SAndroid Build Coastguard Worker } 307*4bdc9457SAndroid Build Coastguard Worker } 308*4bdc9457SAndroid Build Coastguard Worker } 309*4bdc9457SAndroid Build Coastguard Worker } 310*4bdc9457SAndroid Build Coastguard Worker TestNHWCxF32()311*4bdc9457SAndroid Build Coastguard Worker void TestNHWCxF32() const { 312*4bdc9457SAndroid Build Coastguard Worker if (align_corners()) { 313*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(tf_legacy_mode()); 314*4bdc9457SAndroid Build Coastguard Worker } 315*4bdc9457SAndroid Build Coastguard Worker 316*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 317*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 318*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 319*4bdc9457SAndroid Build Coastguard Worker 320*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 321*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 322*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 323*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 324*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 325*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 326*4bdc9457SAndroid Build Coastguard Worker 327*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 328*4bdc9457SAndroid Build Coastguard Worker const float offset = (tf_legacy_mode() || align_corners()) ? 0.0f : 0.5f; 329*4bdc9457SAndroid Build Coastguard Worker for (size_t batch_index = 0; batch_index < batch_size(); batch_index++) { 330*4bdc9457SAndroid Build Coastguard Worker for (size_t output_y = 0; output_y < output_height(); output_y++) { 331*4bdc9457SAndroid Build Coastguard Worker const float input_y = (float(output_y) + offset) * height_scale() - offset; 332*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_top = std::max<int64_t>(int64_t(std::floor(input_y)), 0); 333*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_bottom = std::min<int64_t>(int64_t(std::ceil(input_y)), input_height() - 1); 334*4bdc9457SAndroid Build Coastguard Worker const float y_alpha = input_y - std::floor(input_y); 335*4bdc9457SAndroid Build Coastguard Worker for (size_t output_x = 0; output_x < output_width(); output_x++) { 336*4bdc9457SAndroid Build Coastguard Worker const float input_x = (float(output_x) + offset) * width_scale() - offset; 337*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_left = std::max<int64_t>(int64_t(std::floor(input_x)), 0); 338*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_right = std::min<int64_t>(int64_t(std::ceil(input_x)), input_width() - 1); 339*4bdc9457SAndroid Build Coastguard Worker const float x_alpha = input_x - std::floor(input_x); 340*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 341*4bdc9457SAndroid Build Coastguard Worker output_ref[((batch_index * output_height() + output_y) * output_width() + output_x) * channels() + c] = 342*4bdc9457SAndroid Build Coastguard Worker input[((batch_index * input_height() + input_y_top) * input_width() + input_x_left) * input_pixel_stride() + c] * (1.0f - y_alpha) * (1.0f - x_alpha) + 343*4bdc9457SAndroid Build Coastguard Worker input[((batch_index * input_height() + input_y_top) * input_width() + input_x_right) * input_pixel_stride() + c] * (1.0f - y_alpha) * x_alpha + 344*4bdc9457SAndroid Build Coastguard Worker input[((batch_index * input_height() + input_y_bottom) * input_width() + input_x_left) * input_pixel_stride() + c] * y_alpha * (1.0f - x_alpha) + 345*4bdc9457SAndroid Build Coastguard Worker input[((batch_index * input_height() + input_y_bottom) * input_width() + input_x_right) * input_pixel_stride() + c] * y_alpha * x_alpha; 346*4bdc9457SAndroid Build Coastguard Worker } 347*4bdc9457SAndroid Build Coastguard Worker } 348*4bdc9457SAndroid Build Coastguard Worker } 349*4bdc9457SAndroid Build Coastguard Worker } 350*4bdc9457SAndroid Build Coastguard Worker 351*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Resize Bilinear operator. 352*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 353*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t resize_bilinear_op = nullptr; 354*4bdc9457SAndroid Build Coastguard Worker 355*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 356*4bdc9457SAndroid Build Coastguard Worker xnn_create_resize_bilinear2d_nhwc_f32( 357*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 358*4bdc9457SAndroid Build Coastguard Worker (align_corners() ? XNN_FLAG_ALIGN_CORNERS : 0) | (tf_legacy_mode() ? XNN_FLAG_TENSORFLOW_LEGACY_MODE : 0), 359*4bdc9457SAndroid Build Coastguard Worker &resize_bilinear_op)); 360*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, resize_bilinear_op); 361*4bdc9457SAndroid Build Coastguard Worker 362*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete resize_bilinear_op. 363*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_resize_bilinear_op(resize_bilinear_op, xnn_delete_operator); 364*4bdc9457SAndroid Build Coastguard Worker 365*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 366*4bdc9457SAndroid Build Coastguard Worker xnn_setup_resize_bilinear2d_nhwc_f32( 367*4bdc9457SAndroid Build Coastguard Worker resize_bilinear_op, 368*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 369*4bdc9457SAndroid Build Coastguard Worker output_height(), output_width(), 370*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 371*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 372*4bdc9457SAndroid Build Coastguard Worker 373*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 374*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(resize_bilinear_op, nullptr /* thread pool */)); 375*4bdc9457SAndroid Build Coastguard Worker 376*4bdc9457SAndroid Build Coastguard Worker // Verify results. 377*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 378*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 379*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 380*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 381*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], 382*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 383*4bdc9457SAndroid Build Coastguard Worker std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-5f) << 384*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 385*4bdc9457SAndroid Build Coastguard Worker } 386*4bdc9457SAndroid Build Coastguard Worker } 387*4bdc9457SAndroid Build Coastguard Worker } 388*4bdc9457SAndroid Build Coastguard Worker } 389*4bdc9457SAndroid Build Coastguard Worker } 390*4bdc9457SAndroid Build Coastguard Worker } 391*4bdc9457SAndroid Build Coastguard Worker TestNHWCxS8()392*4bdc9457SAndroid Build Coastguard Worker void TestNHWCxS8() const { 393*4bdc9457SAndroid Build Coastguard Worker if (align_corners()) { 394*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(tf_legacy_mode()); 395*4bdc9457SAndroid Build Coastguard Worker } 396*4bdc9457SAndroid Build Coastguard Worker 397*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 398*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 399*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 400*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 401*4bdc9457SAndroid Build Coastguard Worker 402*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 403*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 404*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 405*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 406*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 407*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 408*4bdc9457SAndroid Build Coastguard Worker 409*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 410*4bdc9457SAndroid Build Coastguard Worker const float offset = (tf_legacy_mode() || align_corners()) ? 0.0f : 0.5f; 411*4bdc9457SAndroid Build Coastguard Worker for (size_t batch_index = 0; batch_index < batch_size(); batch_index++) { 412*4bdc9457SAndroid Build Coastguard Worker for (size_t output_y = 0; output_y < output_height(); output_y++) { 413*4bdc9457SAndroid Build Coastguard Worker const float input_y = (float(output_y) + offset) * height_scale() - offset; 414*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_top = std::max<int64_t>(int64_t(std::floor(input_y)), 0); 415*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_bottom = std::min<int64_t>(int64_t(std::ceil(input_y)), input_height() - 1); 416*4bdc9457SAndroid Build Coastguard Worker const float y_alpha = input_y - std::floor(input_y); 417*4bdc9457SAndroid Build Coastguard Worker for (size_t output_x = 0; output_x < output_width(); output_x++) { 418*4bdc9457SAndroid Build Coastguard Worker const float input_x = (float(output_x) + offset) * width_scale() - offset; 419*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_left = std::max<int64_t>(int64_t(std::floor(input_x)), 0); 420*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_right = std::min<int64_t>(int64_t(std::ceil(input_x)), input_width() - 1); 421*4bdc9457SAndroid Build Coastguard Worker const float x_alpha = input_x - std::floor(input_x); 422*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 423*4bdc9457SAndroid Build Coastguard Worker output_ref[((batch_index * output_height() + output_y) * output_width() + output_x) * channels() + c] = 424*4bdc9457SAndroid Build Coastguard Worker float(int32_t(input[((batch_index * input_height() + input_y_top) * input_width() + input_x_left) * input_pixel_stride() + c])) * (1.0f - y_alpha) * (1.0f - x_alpha) + 425*4bdc9457SAndroid Build Coastguard Worker float(int32_t(input[((batch_index * input_height() + input_y_top) * input_width() + input_x_right) * input_pixel_stride() + c])) * (1.0f - y_alpha) * x_alpha + 426*4bdc9457SAndroid Build Coastguard Worker float(int32_t(input[((batch_index * input_height() + input_y_bottom) * input_width() + input_x_left) * input_pixel_stride() + c])) * y_alpha * (1.0f - x_alpha) + 427*4bdc9457SAndroid Build Coastguard Worker float(int32_t(input[((batch_index * input_height() + input_y_bottom) * input_width() + input_x_right) * input_pixel_stride() + c])) * y_alpha * x_alpha; 428*4bdc9457SAndroid Build Coastguard Worker } 429*4bdc9457SAndroid Build Coastguard Worker } 430*4bdc9457SAndroid Build Coastguard Worker } 431*4bdc9457SAndroid Build Coastguard Worker } 432*4bdc9457SAndroid Build Coastguard Worker 433*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Resize Bilinear operator. 434*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 435*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t resize_bilinear_op = nullptr; 436*4bdc9457SAndroid Build Coastguard Worker 437*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 438*4bdc9457SAndroid Build Coastguard Worker xnn_create_resize_bilinear2d_nhwc_s8( 439*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 440*4bdc9457SAndroid Build Coastguard Worker (align_corners() ? XNN_FLAG_ALIGN_CORNERS : 0) | (tf_legacy_mode() ? XNN_FLAG_TENSORFLOW_LEGACY_MODE : 0), 441*4bdc9457SAndroid Build Coastguard Worker &resize_bilinear_op)); 442*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, resize_bilinear_op); 443*4bdc9457SAndroid Build Coastguard Worker 444*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete resize_bilinear_op. 445*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_resize_bilinear_op(resize_bilinear_op, xnn_delete_operator); 446*4bdc9457SAndroid Build Coastguard Worker 447*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 448*4bdc9457SAndroid Build Coastguard Worker xnn_setup_resize_bilinear2d_nhwc_s8( 449*4bdc9457SAndroid Build Coastguard Worker resize_bilinear_op, 450*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 451*4bdc9457SAndroid Build Coastguard Worker output_height(), output_width(), 452*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 453*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 454*4bdc9457SAndroid Build Coastguard Worker 455*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 456*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(resize_bilinear_op, nullptr /* thread pool */)); 457*4bdc9457SAndroid Build Coastguard Worker 458*4bdc9457SAndroid Build Coastguard Worker // Verify results. 459*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 460*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 461*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 462*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 463*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 464*4bdc9457SAndroid Build Coastguard Worker float(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])), 465*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 466*4bdc9457SAndroid Build Coastguard Worker 0.6f) << 467*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 468*4bdc9457SAndroid Build Coastguard Worker } 469*4bdc9457SAndroid Build Coastguard Worker } 470*4bdc9457SAndroid Build Coastguard Worker } 471*4bdc9457SAndroid Build Coastguard Worker } 472*4bdc9457SAndroid Build Coastguard Worker } 473*4bdc9457SAndroid Build Coastguard Worker } 474*4bdc9457SAndroid Build Coastguard Worker TestNHWCxU8()475*4bdc9457SAndroid Build Coastguard Worker void TestNHWCxU8() const { 476*4bdc9457SAndroid Build Coastguard Worker if (align_corners()) { 477*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(tf_legacy_mode()); 478*4bdc9457SAndroid Build Coastguard Worker } 479*4bdc9457SAndroid Build Coastguard Worker 480*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 481*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 482*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 483*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 484*4bdc9457SAndroid Build Coastguard Worker 485*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 486*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 487*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 488*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 489*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 490*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 491*4bdc9457SAndroid Build Coastguard Worker 492*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 493*4bdc9457SAndroid Build Coastguard Worker const float offset = (tf_legacy_mode() || align_corners()) ? 0.0f : 0.5f; 494*4bdc9457SAndroid Build Coastguard Worker for (size_t batch_index = 0; batch_index < batch_size(); batch_index++) { 495*4bdc9457SAndroid Build Coastguard Worker for (size_t output_y = 0; output_y < output_height(); output_y++) { 496*4bdc9457SAndroid Build Coastguard Worker const float input_y = (float(output_y) + offset) * height_scale() - offset; 497*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_top = std::max<int64_t>(int64_t(std::floor(input_y)), 0); 498*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_bottom = std::min<int64_t>(int64_t(std::ceil(input_y)), input_height() - 1); 499*4bdc9457SAndroid Build Coastguard Worker const float y_alpha = input_y - std::floor(input_y); 500*4bdc9457SAndroid Build Coastguard Worker for (size_t output_x = 0; output_x < output_width(); output_x++) { 501*4bdc9457SAndroid Build Coastguard Worker const float input_x = (float(output_x) + offset) * width_scale() - offset; 502*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_left = std::max<int64_t>(int64_t(std::floor(input_x)), 0); 503*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_right = std::min<int64_t>(int64_t(std::ceil(input_x)), input_width() - 1); 504*4bdc9457SAndroid Build Coastguard Worker const float x_alpha = input_x - std::floor(input_x); 505*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 506*4bdc9457SAndroid Build Coastguard Worker output_ref[((batch_index * output_height() + output_y) * output_width() + output_x) * channels() + c] = 507*4bdc9457SAndroid Build Coastguard Worker float(int32_t(input[((batch_index * input_height() + input_y_top) * input_width() + input_x_left) * input_pixel_stride() + c])) * (1.0f - y_alpha) * (1.0f - x_alpha) + 508*4bdc9457SAndroid Build Coastguard Worker float(int32_t(input[((batch_index * input_height() + input_y_top) * input_width() + input_x_right) * input_pixel_stride() + c])) * (1.0f - y_alpha) * x_alpha + 509*4bdc9457SAndroid Build Coastguard Worker float(int32_t(input[((batch_index * input_height() + input_y_bottom) * input_width() + input_x_left) * input_pixel_stride() + c])) * y_alpha * (1.0f - x_alpha) + 510*4bdc9457SAndroid Build Coastguard Worker float(int32_t(input[((batch_index * input_height() + input_y_bottom) * input_width() + input_x_right) * input_pixel_stride() + c])) * y_alpha * x_alpha; 511*4bdc9457SAndroid Build Coastguard Worker } 512*4bdc9457SAndroid Build Coastguard Worker } 513*4bdc9457SAndroid Build Coastguard Worker } 514*4bdc9457SAndroid Build Coastguard Worker } 515*4bdc9457SAndroid Build Coastguard Worker 516*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Resize Bilinear operator. 517*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 518*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t resize_bilinear_op = nullptr; 519*4bdc9457SAndroid Build Coastguard Worker 520*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 521*4bdc9457SAndroid Build Coastguard Worker xnn_create_resize_bilinear2d_nhwc_u8( 522*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 523*4bdc9457SAndroid Build Coastguard Worker (align_corners() ? XNN_FLAG_ALIGN_CORNERS : 0) | (tf_legacy_mode() ? XNN_FLAG_TENSORFLOW_LEGACY_MODE : 0), 524*4bdc9457SAndroid Build Coastguard Worker &resize_bilinear_op)); 525*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, resize_bilinear_op); 526*4bdc9457SAndroid Build Coastguard Worker 527*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete resize_bilinear_op. 528*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_resize_bilinear_op(resize_bilinear_op, xnn_delete_operator); 529*4bdc9457SAndroid Build Coastguard Worker 530*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 531*4bdc9457SAndroid Build Coastguard Worker xnn_setup_resize_bilinear2d_nhwc_u8( 532*4bdc9457SAndroid Build Coastguard Worker resize_bilinear_op, 533*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 534*4bdc9457SAndroid Build Coastguard Worker output_height(), output_width(), 535*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 536*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 537*4bdc9457SAndroid Build Coastguard Worker 538*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 539*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(resize_bilinear_op, nullptr /* thread pool */)); 540*4bdc9457SAndroid Build Coastguard Worker 541*4bdc9457SAndroid Build Coastguard Worker // Verify results. 542*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 543*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 544*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 545*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 546*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 547*4bdc9457SAndroid Build Coastguard Worker float(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])), 548*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 549*4bdc9457SAndroid Build Coastguard Worker 0.6f) << 550*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 551*4bdc9457SAndroid Build Coastguard Worker } 552*4bdc9457SAndroid Build Coastguard Worker } 553*4bdc9457SAndroid Build Coastguard Worker } 554*4bdc9457SAndroid Build Coastguard Worker } 555*4bdc9457SAndroid Build Coastguard Worker } 556*4bdc9457SAndroid Build Coastguard Worker } 557*4bdc9457SAndroid Build Coastguard Worker TestNCHWxF32()558*4bdc9457SAndroid Build Coastguard Worker void TestNCHWxF32() const { 559*4bdc9457SAndroid Build Coastguard Worker if (align_corners()) { 560*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(tf_legacy_mode()); 561*4bdc9457SAndroid Build Coastguard Worker } 562*4bdc9457SAndroid Build Coastguard Worker 563*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 564*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 565*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 566*4bdc9457SAndroid Build Coastguard Worker 567*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 568*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 569*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 570*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 571*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 572*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 573*4bdc9457SAndroid Build Coastguard Worker 574*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 575*4bdc9457SAndroid Build Coastguard Worker const float offset = (tf_legacy_mode() || align_corners()) ? 0.0f : 0.5f; 576*4bdc9457SAndroid Build Coastguard Worker const int64_t input_num_pixels = input_height() * input_width(); 577*4bdc9457SAndroid Build Coastguard Worker const int64_t input_num_elements = input_num_pixels * input_pixel_stride(); 578*4bdc9457SAndroid Build Coastguard Worker const int64_t output_num_pixels = output_height() * output_width(); 579*4bdc9457SAndroid Build Coastguard Worker const int64_t output_num_elements = output_num_pixels * channels(); 580*4bdc9457SAndroid Build Coastguard Worker for (size_t batch_index = 0; batch_index < batch_size(); batch_index++) { 581*4bdc9457SAndroid Build Coastguard Worker for (size_t output_y = 0; output_y < output_height(); output_y++) { 582*4bdc9457SAndroid Build Coastguard Worker const float input_y = (float(output_y) + offset) * height_scale() - offset; 583*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_top = std::max<int64_t>(int64_t(std::floor(input_y)), 0); 584*4bdc9457SAndroid Build Coastguard Worker const int64_t input_y_bottom = std::min<int64_t>(int64_t(std::ceil(input_y)), input_height() - 1); 585*4bdc9457SAndroid Build Coastguard Worker const float y_alpha = input_y - std::floor(input_y); 586*4bdc9457SAndroid Build Coastguard Worker for (size_t output_x = 0; output_x < output_width(); output_x++) { 587*4bdc9457SAndroid Build Coastguard Worker const float input_x = (float(output_x) + offset) * width_scale() - offset; 588*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_left = std::max<int64_t>(int64_t(std::floor(input_x)), 0); 589*4bdc9457SAndroid Build Coastguard Worker const int64_t input_x_right = std::min<int64_t>(int64_t(std::ceil(input_x)), input_width() - 1); 590*4bdc9457SAndroid Build Coastguard Worker const float x_alpha = input_x - std::floor(input_x); 591*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 592*4bdc9457SAndroid Build Coastguard Worker output_ref[batch_index * output_num_elements + c * output_num_pixels + output_y * output_width() + output_x] = 593*4bdc9457SAndroid Build Coastguard Worker input[batch_index * input_num_elements + c * input_num_pixels + input_y_top * input_width() + input_x_left] * (1.0f - y_alpha) * (1.0f - x_alpha) + 594*4bdc9457SAndroid Build Coastguard Worker input[batch_index * input_num_elements + c * input_num_pixels + input_y_top * input_width() + input_x_right] * (1.0f - y_alpha) * x_alpha + 595*4bdc9457SAndroid Build Coastguard Worker input[batch_index * input_num_elements + c * input_num_pixels + input_y_bottom * input_width() + input_x_left] * y_alpha * (1.0f - x_alpha) + 596*4bdc9457SAndroid Build Coastguard Worker input[batch_index * input_num_elements + c * input_num_pixels + input_y_bottom * input_width() + input_x_right] * y_alpha * x_alpha; 597*4bdc9457SAndroid Build Coastguard Worker } 598*4bdc9457SAndroid Build Coastguard Worker } 599*4bdc9457SAndroid Build Coastguard Worker } 600*4bdc9457SAndroid Build Coastguard Worker } 601*4bdc9457SAndroid Build Coastguard Worker 602*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Resize Bilinear operator. 603*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 604*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t resize_bilinear_op = nullptr; 605*4bdc9457SAndroid Build Coastguard Worker 606*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 607*4bdc9457SAndroid Build Coastguard Worker xnn_create_resize_bilinear2d_nchw_f32( 608*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 609*4bdc9457SAndroid Build Coastguard Worker (align_corners() ? XNN_FLAG_ALIGN_CORNERS : 0) | (tf_legacy_mode() ? XNN_FLAG_TENSORFLOW_LEGACY_MODE : 0), 610*4bdc9457SAndroid Build Coastguard Worker &resize_bilinear_op)); 611*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, resize_bilinear_op); 612*4bdc9457SAndroid Build Coastguard Worker 613*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete resize_bilinear_op. 614*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_resize_bilinear_op(resize_bilinear_op, xnn_delete_operator); 615*4bdc9457SAndroid Build Coastguard Worker 616*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 617*4bdc9457SAndroid Build Coastguard Worker xnn_setup_resize_bilinear2d_nchw_f32( 618*4bdc9457SAndroid Build Coastguard Worker resize_bilinear_op, 619*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 620*4bdc9457SAndroid Build Coastguard Worker output_height(), output_width(), 621*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 622*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 623*4bdc9457SAndroid Build Coastguard Worker 624*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 625*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(resize_bilinear_op, nullptr /* thread pool */)); 626*4bdc9457SAndroid Build Coastguard Worker 627*4bdc9457SAndroid Build Coastguard Worker // Verify results. 628*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 629*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 630*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 631*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 632*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[i * output_num_elements + c * output_num_pixels + y * output_width() + x], 633*4bdc9457SAndroid Build Coastguard Worker output_ref[i * output_num_elements + c * output_num_pixels + y * output_width() + x], 634*4bdc9457SAndroid Build Coastguard Worker 1.0e-6f) << 635*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 636*4bdc9457SAndroid Build Coastguard Worker } 637*4bdc9457SAndroid Build Coastguard Worker } 638*4bdc9457SAndroid Build Coastguard Worker } 639*4bdc9457SAndroid Build Coastguard Worker } 640*4bdc9457SAndroid Build Coastguard Worker } 641*4bdc9457SAndroid Build Coastguard Worker } 642*4bdc9457SAndroid Build Coastguard Worker 643*4bdc9457SAndroid Build Coastguard Worker private: 644*4bdc9457SAndroid Build Coastguard Worker size_t input_height_{1}; 645*4bdc9457SAndroid Build Coastguard Worker size_t input_width_{1}; 646*4bdc9457SAndroid Build Coastguard Worker size_t output_height_{1}; 647*4bdc9457SAndroid Build Coastguard Worker size_t output_width_{1}; 648*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 649*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 650*4bdc9457SAndroid Build Coastguard Worker size_t input_pixel_stride_{0}; 651*4bdc9457SAndroid Build Coastguard Worker size_t output_pixel_stride_{0}; 652*4bdc9457SAndroid Build Coastguard Worker size_t next_input_height_{0}; 653*4bdc9457SAndroid Build Coastguard Worker size_t next_input_width_{0}; 654*4bdc9457SAndroid Build Coastguard Worker size_t next_batch_size_{0}; 655*4bdc9457SAndroid Build Coastguard Worker bool align_corners_{false}; 656*4bdc9457SAndroid Build Coastguard Worker bool tf_legacy_mode_{false}; 657*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 658*4bdc9457SAndroid Build Coastguard Worker }; 659