1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 14*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 15*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 16*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 17*4bdc9457SAndroid Build Coastguard Worker #include <limits> 18*4bdc9457SAndroid Build Coastguard Worker #include <random> 19*4bdc9457SAndroid Build Coastguard Worker #include <vector> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 24*4bdc9457SAndroid Build Coastguard Worker 25*4bdc9457SAndroid Build Coastguard Worker 26*4bdc9457SAndroid Build Coastguard Worker class ClampOperatorTester { 27*4bdc9457SAndroid Build Coastguard Worker public: channels(size_t channels)28*4bdc9457SAndroid Build Coastguard Worker inline ClampOperatorTester& channels(size_t channels) { 29*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 30*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 31*4bdc9457SAndroid Build Coastguard Worker return *this; 32*4bdc9457SAndroid Build Coastguard Worker } 33*4bdc9457SAndroid Build Coastguard Worker channels()34*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 35*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 36*4bdc9457SAndroid Build Coastguard Worker } 37*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)38*4bdc9457SAndroid Build Coastguard Worker inline ClampOperatorTester& input_stride(size_t input_stride) { 39*4bdc9457SAndroid Build Coastguard Worker assert(input_stride != 0); 40*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 41*4bdc9457SAndroid Build Coastguard Worker return *this; 42*4bdc9457SAndroid Build Coastguard Worker } 43*4bdc9457SAndroid Build Coastguard Worker input_stride()44*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 45*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 46*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 47*4bdc9457SAndroid Build Coastguard Worker } else { 48*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= this->channels_); 49*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 50*4bdc9457SAndroid Build Coastguard Worker } 51*4bdc9457SAndroid Build Coastguard Worker } 52*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)53*4bdc9457SAndroid Build Coastguard Worker inline ClampOperatorTester& output_stride(size_t output_stride) { 54*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0); 55*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 56*4bdc9457SAndroid Build Coastguard Worker return *this; 57*4bdc9457SAndroid Build Coastguard Worker } 58*4bdc9457SAndroid Build Coastguard Worker output_stride()59*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 60*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 61*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 62*4bdc9457SAndroid Build Coastguard Worker } else { 63*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= this->channels_); 64*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 65*4bdc9457SAndroid Build Coastguard Worker } 66*4bdc9457SAndroid Build Coastguard Worker } 67*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)68*4bdc9457SAndroid Build Coastguard Worker inline ClampOperatorTester& batch_size(size_t batch_size) { 69*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 70*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 71*4bdc9457SAndroid Build Coastguard Worker return *this; 72*4bdc9457SAndroid Build Coastguard Worker } 73*4bdc9457SAndroid Build Coastguard Worker batch_size()74*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 75*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 76*4bdc9457SAndroid Build Coastguard Worker } 77*4bdc9457SAndroid Build Coastguard Worker qmin(int16_t qmin)78*4bdc9457SAndroid Build Coastguard Worker inline ClampOperatorTester& qmin(int16_t qmin) { 79*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 80*4bdc9457SAndroid Build Coastguard Worker return *this; 81*4bdc9457SAndroid Build Coastguard Worker } 82*4bdc9457SAndroid Build Coastguard Worker qmin()83*4bdc9457SAndroid Build Coastguard Worker inline int16_t qmin() const { 84*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 85*4bdc9457SAndroid Build Coastguard Worker } 86*4bdc9457SAndroid Build Coastguard Worker qmax(int16_t qmax)87*4bdc9457SAndroid Build Coastguard Worker inline ClampOperatorTester& qmax(int16_t qmax) { 88*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 89*4bdc9457SAndroid Build Coastguard Worker return *this; 90*4bdc9457SAndroid Build Coastguard Worker } 91*4bdc9457SAndroid Build Coastguard Worker qmax()92*4bdc9457SAndroid Build Coastguard Worker inline int16_t qmax() const { 93*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 94*4bdc9457SAndroid Build Coastguard Worker } 95*4bdc9457SAndroid Build Coastguard Worker relu_activation(bool relu_activation)96*4bdc9457SAndroid Build Coastguard Worker inline ClampOperatorTester& relu_activation(bool relu_activation) { 97*4bdc9457SAndroid Build Coastguard Worker this->relu_activation_ = relu_activation; 98*4bdc9457SAndroid Build Coastguard Worker return *this; 99*4bdc9457SAndroid Build Coastguard Worker } 100*4bdc9457SAndroid Build Coastguard Worker relu_activation()101*4bdc9457SAndroid Build Coastguard Worker inline bool relu_activation() const { 102*4bdc9457SAndroid Build Coastguard Worker return this->relu_activation_; 103*4bdc9457SAndroid Build Coastguard Worker } 104*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)105*4bdc9457SAndroid Build Coastguard Worker inline ClampOperatorTester& iterations(size_t iterations) { 106*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 107*4bdc9457SAndroid Build Coastguard Worker return *this; 108*4bdc9457SAndroid Build Coastguard Worker } 109*4bdc9457SAndroid Build Coastguard Worker iterations()110*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 111*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker TestF16()114*4bdc9457SAndroid Build Coastguard Worker void TestF16() const { 115*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 116*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(relu_activation()); 117*4bdc9457SAndroid Build Coastguard Worker 118*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 119*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 120*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist( 121*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int16_t>::min(), std::numeric_limits<int16_t>::max()); 122*4bdc9457SAndroid Build Coastguard Worker 123*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 124*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 125*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 126*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 127*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 128*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 129*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 130*4bdc9457SAndroid Build Coastguard Worker 131*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 132*4bdc9457SAndroid Build Coastguard Worker const float output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(float(qmin()))); 133*4bdc9457SAndroid Build Coastguard Worker const float output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(float(qmax()))); 134*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 135*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 136*4bdc9457SAndroid Build Coastguard Worker const float x = fp16_ieee_to_fp32_value(input[i * input_stride() + c]); 137*4bdc9457SAndroid Build Coastguard Worker const float y = relu_activation() ? std::max(x, 0.f) : std::min(std::max(x, output_min), output_max); 138*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y; 139*4bdc9457SAndroid Build Coastguard Worker } 140*4bdc9457SAndroid Build Coastguard Worker } 141*4bdc9457SAndroid Build Coastguard Worker 142*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Clamp operator. 143*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 144*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t clamp_op = nullptr; 145*4bdc9457SAndroid Build Coastguard Worker 146*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_clamp_nc_f16( 147*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 148*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 149*4bdc9457SAndroid Build Coastguard Worker 0, &clamp_op); 150*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 151*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 152*4bdc9457SAndroid Build Coastguard Worker } 153*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 154*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, clamp_op); 155*4bdc9457SAndroid Build Coastguard Worker 156*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete clamp_op. 157*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_clamp_op(clamp_op, xnn_delete_operator); 158*4bdc9457SAndroid Build Coastguard Worker 159*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 160*4bdc9457SAndroid Build Coastguard Worker xnn_setup_clamp_nc_f16( 161*4bdc9457SAndroid Build Coastguard Worker clamp_op, 162*4bdc9457SAndroid Build Coastguard Worker batch_size(), 163*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 164*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 165*4bdc9457SAndroid Build Coastguard Worker 166*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 167*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(clamp_op, nullptr /* thread pool */)); 168*4bdc9457SAndroid Build Coastguard Worker 169*4bdc9457SAndroid Build Coastguard Worker // Verify results. 170*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 171*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 172*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max) 173*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 174*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min) 175*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 176*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_ref[i * channels() + c], std::max(1.0e-4f, std::abs(output_ref[i * channels() + c]) * 1.0e-2f)) 177*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels() 178*4bdc9457SAndroid Build Coastguard Worker << ", min " << output_min << ", max " << output_max; 179*4bdc9457SAndroid Build Coastguard Worker } 180*4bdc9457SAndroid Build Coastguard Worker } 181*4bdc9457SAndroid Build Coastguard Worker } 182*4bdc9457SAndroid Build Coastguard Worker } 183*4bdc9457SAndroid Build Coastguard Worker TestF32()184*4bdc9457SAndroid Build Coastguard Worker void TestF32() const { 185*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 186*4bdc9457SAndroid Build Coastguard Worker 187*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 188*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 189*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist( 190*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int16_t>::min(), std::numeric_limits<int16_t>::max()); 191*4bdc9457SAndroid Build Coastguard Worker 192*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 193*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 194*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 195*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 196*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 197*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 198*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 199*4bdc9457SAndroid Build Coastguard Worker 200*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 201*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 202*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 203*4bdc9457SAndroid Build Coastguard Worker const float x = input[i * input_stride() + c]; 204*4bdc9457SAndroid Build Coastguard Worker const float y = relu_activation() ? std::max(x, 0.f) : 205*4bdc9457SAndroid Build Coastguard Worker std::min(std::max(x, float(qmin())), float(qmax())); 206*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y; 207*4bdc9457SAndroid Build Coastguard Worker } 208*4bdc9457SAndroid Build Coastguard Worker } 209*4bdc9457SAndroid Build Coastguard Worker 210*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Clamp operator. 211*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 212*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t clamp_op = nullptr; 213*4bdc9457SAndroid Build Coastguard Worker 214*4bdc9457SAndroid Build Coastguard Worker const float output_min = relu_activation() ? 0.0f : float(qmin()); 215*4bdc9457SAndroid Build Coastguard Worker const float output_max = relu_activation() ? std::numeric_limits<float>::infinity() : float(qmax()); 216*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 217*4bdc9457SAndroid Build Coastguard Worker xnn_create_clamp_nc_f32( 218*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 219*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 220*4bdc9457SAndroid Build Coastguard Worker 0, &clamp_op)); 221*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, clamp_op); 222*4bdc9457SAndroid Build Coastguard Worker 223*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete clamp_op. 224*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_clamp_op(clamp_op, xnn_delete_operator); 225*4bdc9457SAndroid Build Coastguard Worker 226*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 227*4bdc9457SAndroid Build Coastguard Worker xnn_setup_clamp_nc_f32( 228*4bdc9457SAndroid Build Coastguard Worker clamp_op, 229*4bdc9457SAndroid Build Coastguard Worker batch_size(), 230*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 231*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 232*4bdc9457SAndroid Build Coastguard Worker 233*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 234*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(clamp_op, nullptr /* thread pool */)); 235*4bdc9457SAndroid Build Coastguard Worker 236*4bdc9457SAndroid Build Coastguard Worker // Verify results. 237*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 238*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 239*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[i * output_stride() + c], output_max) 240*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 241*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[i * output_stride() + c], output_min) 242*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 243*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 244*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels() 245*4bdc9457SAndroid Build Coastguard Worker << ", min " << output_min << ", max " << output_max; 246*4bdc9457SAndroid Build Coastguard Worker } 247*4bdc9457SAndroid Build Coastguard Worker } 248*4bdc9457SAndroid Build Coastguard Worker } 249*4bdc9457SAndroid Build Coastguard Worker } 250*4bdc9457SAndroid Build Coastguard Worker TestS8()251*4bdc9457SAndroid Build Coastguard Worker void TestS8() const { 252*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(qmin(), std::numeric_limits<int8_t>::min()); 253*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(qmax(), std::numeric_limits<int8_t>::max()); 254*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 255*4bdc9457SAndroid Build Coastguard Worker 256*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 257*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 258*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 259*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 260*4bdc9457SAndroid Build Coastguard Worker 261*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + 262*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 263*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels()); 264*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output_ref(batch_size() * channels()); 265*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 266*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 267*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 268*4bdc9457SAndroid Build Coastguard Worker 269*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 270*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 271*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 272*4bdc9457SAndroid Build Coastguard Worker const int8_t x = input[i * input_stride() + c]; 273*4bdc9457SAndroid Build Coastguard Worker const int8_t y = std::min(std::max(x, int8_t(qmin())), int8_t(qmax())); 274*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y; 275*4bdc9457SAndroid Build Coastguard Worker } 276*4bdc9457SAndroid Build Coastguard Worker } 277*4bdc9457SAndroid Build Coastguard Worker 278*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Clamp operator. 279*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 280*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t clamp_op = nullptr; 281*4bdc9457SAndroid Build Coastguard Worker 282*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 283*4bdc9457SAndroid Build Coastguard Worker xnn_create_clamp_nc_s8( 284*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 285*4bdc9457SAndroid Build Coastguard Worker int8_t(qmin()), int8_t(qmax()), 286*4bdc9457SAndroid Build Coastguard Worker 0, &clamp_op)); 287*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, clamp_op); 288*4bdc9457SAndroid Build Coastguard Worker 289*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete clamp_op. 290*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_clamp_op(clamp_op, xnn_delete_operator); 291*4bdc9457SAndroid Build Coastguard Worker 292*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 293*4bdc9457SAndroid Build Coastguard Worker xnn_setup_clamp_nc_s8( 294*4bdc9457SAndroid Build Coastguard Worker clamp_op, 295*4bdc9457SAndroid Build Coastguard Worker batch_size(), 296*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 297*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 298*4bdc9457SAndroid Build Coastguard Worker 299*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 300*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(clamp_op, nullptr /* thread pool */)); 301*4bdc9457SAndroid Build Coastguard Worker 302*4bdc9457SAndroid Build Coastguard Worker // Verify results . 303*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 304*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 305*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int16_t(output[i * output_stride() + c]), qmax()) 306*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 307*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int16_t(output[i * output_stride() + c]), qmin()) 308*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 309*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int16_t(output[i * output_stride() + c]), int16_t(output_ref[i * channels() + c])) 310*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels() 311*4bdc9457SAndroid Build Coastguard Worker << ", min " << qmin() << ", max " << qmax(); 312*4bdc9457SAndroid Build Coastguard Worker } 313*4bdc9457SAndroid Build Coastguard Worker } 314*4bdc9457SAndroid Build Coastguard Worker } 315*4bdc9457SAndroid Build Coastguard Worker } 316*4bdc9457SAndroid Build Coastguard Worker TestU8()317*4bdc9457SAndroid Build Coastguard Worker void TestU8() const { 318*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(qmin(), std::numeric_limits<uint8_t>::min()); 319*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(qmax(), std::numeric_limits<uint8_t>::max()); 320*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 321*4bdc9457SAndroid Build Coastguard Worker 322*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 323*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 324*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 325*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 326*4bdc9457SAndroid Build Coastguard Worker 327*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 328*4bdc9457SAndroid Build Coastguard Worker (batch_size() - 1) * input_stride() + channels()); 329*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels()); 330*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output_ref(batch_size() * channels()); 331*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 332*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 333*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 334*4bdc9457SAndroid Build Coastguard Worker 335*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 336*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 337*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 338*4bdc9457SAndroid Build Coastguard Worker const uint8_t x = input[i * input_stride() + c]; 339*4bdc9457SAndroid Build Coastguard Worker const uint8_t y = std::min(std::max(x, uint8_t(qmin())), uint8_t(qmax())); 340*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + c] = y; 341*4bdc9457SAndroid Build Coastguard Worker } 342*4bdc9457SAndroid Build Coastguard Worker } 343*4bdc9457SAndroid Build Coastguard Worker 344*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Clamp operator. 345*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 346*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t clamp_op = nullptr; 347*4bdc9457SAndroid Build Coastguard Worker 348*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 349*4bdc9457SAndroid Build Coastguard Worker xnn_create_clamp_nc_u8( 350*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 351*4bdc9457SAndroid Build Coastguard Worker uint8_t(qmin()), uint8_t(qmax()), 352*4bdc9457SAndroid Build Coastguard Worker 0, &clamp_op)); 353*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, clamp_op); 354*4bdc9457SAndroid Build Coastguard Worker 355*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete clamp_op. 356*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_clamp_op(clamp_op, xnn_delete_operator); 357*4bdc9457SAndroid Build Coastguard Worker 358*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 359*4bdc9457SAndroid Build Coastguard Worker xnn_setup_clamp_nc_u8( 360*4bdc9457SAndroid Build Coastguard Worker clamp_op, 361*4bdc9457SAndroid Build Coastguard Worker batch_size(), 362*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 363*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 364*4bdc9457SAndroid Build Coastguard Worker 365*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 366*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(clamp_op, nullptr /* thread pool */)); 367*4bdc9457SAndroid Build Coastguard Worker 368*4bdc9457SAndroid Build Coastguard Worker // Verify results . 369*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 370*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 371*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int16_t(output[i * output_stride() + c]), qmax()) 372*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 373*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int16_t(output[i * output_stride() + c]), qmin()) 374*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 375*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int16_t(output[i * output_stride() + c]), int16_t(output_ref[i * channels() + c])) 376*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels() 377*4bdc9457SAndroid Build Coastguard Worker << ", min " << qmin() << ", max " << qmax(); 378*4bdc9457SAndroid Build Coastguard Worker } 379*4bdc9457SAndroid Build Coastguard Worker } 380*4bdc9457SAndroid Build Coastguard Worker } 381*4bdc9457SAndroid Build Coastguard Worker } 382*4bdc9457SAndroid Build Coastguard Worker 383*4bdc9457SAndroid Build Coastguard Worker private: 384*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 385*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 386*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 387*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 388*4bdc9457SAndroid Build Coastguard Worker int16_t qmin_{std::numeric_limits<int16_t>::min()}; 389*4bdc9457SAndroid Build Coastguard Worker int16_t qmax_{std::numeric_limits<int16_t>::max()}; 390*4bdc9457SAndroid Build Coastguard Worker bool relu_activation_{false}; 391*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 392*4bdc9457SAndroid Build Coastguard Worker }; 393