1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 11*4bdc9457SAndroid Build Coastguard Worker 12*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 13*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 15*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 16*4bdc9457SAndroid Build Coastguard Worker #include <functional> 17*4bdc9457SAndroid Build Coastguard Worker #include <random> 18*4bdc9457SAndroid Build Coastguard Worker #include <vector> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h> 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker class PReLUOperatorTester { 25*4bdc9457SAndroid Build Coastguard Worker public: 26*4bdc9457SAndroid Build Coastguard Worker enum class WeightsType { 27*4bdc9457SAndroid Build Coastguard Worker Default, 28*4bdc9457SAndroid Build Coastguard Worker FP32, 29*4bdc9457SAndroid Build Coastguard Worker }; 30*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)31*4bdc9457SAndroid Build Coastguard Worker inline PReLUOperatorTester& batch_size(size_t batch_size) { 32*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 33*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 34*4bdc9457SAndroid Build Coastguard Worker return *this; 35*4bdc9457SAndroid Build Coastguard Worker } 36*4bdc9457SAndroid Build Coastguard Worker batch_size()37*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 38*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 39*4bdc9457SAndroid Build Coastguard Worker } 40*4bdc9457SAndroid Build Coastguard Worker channels(size_t channels)41*4bdc9457SAndroid Build Coastguard Worker inline PReLUOperatorTester& channels(size_t channels) { 42*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 43*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 44*4bdc9457SAndroid Build Coastguard Worker return *this; 45*4bdc9457SAndroid Build Coastguard Worker } 46*4bdc9457SAndroid Build Coastguard Worker channels()47*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 48*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 49*4bdc9457SAndroid Build Coastguard Worker } 50*4bdc9457SAndroid Build Coastguard Worker x_stride(size_t x_stride)51*4bdc9457SAndroid Build Coastguard Worker inline PReLUOperatorTester& x_stride(size_t x_stride) { 52*4bdc9457SAndroid Build Coastguard Worker assert(x_stride != 0); 53*4bdc9457SAndroid Build Coastguard Worker this->x_stride_ = x_stride; 54*4bdc9457SAndroid Build Coastguard Worker return *this; 55*4bdc9457SAndroid Build Coastguard Worker } 56*4bdc9457SAndroid Build Coastguard Worker x_stride()57*4bdc9457SAndroid Build Coastguard Worker inline size_t x_stride() const { 58*4bdc9457SAndroid Build Coastguard Worker if (this->x_stride_ == 0) { 59*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 60*4bdc9457SAndroid Build Coastguard Worker } else { 61*4bdc9457SAndroid Build Coastguard Worker assert(this->x_stride_ >= this->channels_); 62*4bdc9457SAndroid Build Coastguard Worker return this->x_stride_; 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker } 65*4bdc9457SAndroid Build Coastguard Worker y_stride(size_t y_stride)66*4bdc9457SAndroid Build Coastguard Worker inline PReLUOperatorTester& y_stride(size_t y_stride) { 67*4bdc9457SAndroid Build Coastguard Worker assert(y_stride != 0); 68*4bdc9457SAndroid Build Coastguard Worker this->y_stride_ = y_stride; 69*4bdc9457SAndroid Build Coastguard Worker return *this; 70*4bdc9457SAndroid Build Coastguard Worker } 71*4bdc9457SAndroid Build Coastguard Worker y_stride()72*4bdc9457SAndroid Build Coastguard Worker inline size_t y_stride() const { 73*4bdc9457SAndroid Build Coastguard Worker if (this->y_stride_ == 0) { 74*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 75*4bdc9457SAndroid Build Coastguard Worker } else { 76*4bdc9457SAndroid Build Coastguard Worker assert(this->y_stride_ >= this->channels_); 77*4bdc9457SAndroid Build Coastguard Worker return this->y_stride_; 78*4bdc9457SAndroid Build Coastguard Worker } 79*4bdc9457SAndroid Build Coastguard Worker } 80*4bdc9457SAndroid Build Coastguard Worker weights_type(WeightsType weights_type)81*4bdc9457SAndroid Build Coastguard Worker inline PReLUOperatorTester& weights_type(WeightsType weights_type) { 82*4bdc9457SAndroid Build Coastguard Worker this->weights_type_ = weights_type; 83*4bdc9457SAndroid Build Coastguard Worker return *this; 84*4bdc9457SAndroid Build Coastguard Worker } 85*4bdc9457SAndroid Build Coastguard Worker weights_type()86*4bdc9457SAndroid Build Coastguard Worker inline WeightsType weights_type() const { 87*4bdc9457SAndroid Build Coastguard Worker return this->weights_type_; 88*4bdc9457SAndroid Build Coastguard Worker } 89*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)90*4bdc9457SAndroid Build Coastguard Worker inline PReLUOperatorTester& iterations(size_t iterations) { 91*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 92*4bdc9457SAndroid Build Coastguard Worker return *this; 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker iterations()95*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 96*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 97*4bdc9457SAndroid Build Coastguard Worker } 98*4bdc9457SAndroid Build Coastguard Worker use_weights_cache(bool use_weights_cache)99*4bdc9457SAndroid Build Coastguard Worker inline PReLUOperatorTester& use_weights_cache(bool use_weights_cache) { 100*4bdc9457SAndroid Build Coastguard Worker this->use_weights_cache_ = use_weights_cache; 101*4bdc9457SAndroid Build Coastguard Worker return *this; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker use_weights_cache()104*4bdc9457SAndroid Build Coastguard Worker inline bool use_weights_cache() const { 105*4bdc9457SAndroid Build Coastguard Worker return this->use_weights_cache_; 106*4bdc9457SAndroid Build Coastguard Worker } 107*4bdc9457SAndroid Build Coastguard Worker TestF16()108*4bdc9457SAndroid Build Coastguard Worker void TestF16() const { 109*4bdc9457SAndroid Build Coastguard Worker switch (weights_type()) { 110*4bdc9457SAndroid Build Coastguard Worker case WeightsType::Default: 111*4bdc9457SAndroid Build Coastguard Worker break; 112*4bdc9457SAndroid Build Coastguard Worker case WeightsType::FP32: 113*4bdc9457SAndroid Build Coastguard Worker break; 114*4bdc9457SAndroid Build Coastguard Worker default: 115*4bdc9457SAndroid Build Coastguard Worker GTEST_FAIL() << "unexpected weights type"; 116*4bdc9457SAndroid Build Coastguard Worker } 117*4bdc9457SAndroid Build Coastguard Worker 118*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 119*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 120*4bdc9457SAndroid Build Coastguard Worker auto f32irng = std::uniform_real_distribution<float>(-1.0f, 1.0f); 121*4bdc9457SAndroid Build Coastguard Worker auto f32wrng = std::uniform_real_distribution<float>(0.25f, 0.75f); 122*4bdc9457SAndroid Build Coastguard Worker 123*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 124*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> w(channels()); 125*4bdc9457SAndroid Build Coastguard Worker std::vector<float> w_as_float(channels()); 126*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 127*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size() * channels()); 128*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 129*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&] { return fp16_ieee_from_fp32_value(f32irng(rng)); }); 130*4bdc9457SAndroid Build Coastguard Worker std::generate(w.begin(), w.end(), [&] { return fp16_ieee_from_fp32_value(f32wrng(rng)); }); 131*4bdc9457SAndroid Build Coastguard Worker std::transform(w.cbegin(), w.cend(), w_as_float.begin(), fp16_ieee_to_fp32_value); 132*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 133*4bdc9457SAndroid Build Coastguard Worker 134*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 135*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 136*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 137*4bdc9457SAndroid Build Coastguard Worker const float x_value = fp16_ieee_to_fp32_value(x[i * x_stride() + c]); 138*4bdc9457SAndroid Build Coastguard Worker const float w_value = w_as_float[c]; 139*4bdc9457SAndroid Build Coastguard Worker y_ref[i * channels() + c] = std::signbit(x_value) ? x_value * w_value : x_value; 140*4bdc9457SAndroid Build Coastguard Worker } 141*4bdc9457SAndroid Build Coastguard Worker } 142*4bdc9457SAndroid Build Coastguard Worker 143*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy PReLU operator. 144*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 145*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t prelu_op = nullptr; 146*4bdc9457SAndroid Build Coastguard Worker 147*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 148*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 149*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 150*4bdc9457SAndroid Build Coastguard Worker }; 151*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 152*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 153*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 154*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 155*4bdc9457SAndroid Build Coastguard Worker } 156*4bdc9457SAndroid Build Coastguard Worker 157*4bdc9457SAndroid Build Coastguard Worker const void* negative_slope_data = w.data(); 158*4bdc9457SAndroid Build Coastguard Worker if (weights_type() == WeightsType::FP32) { 159*4bdc9457SAndroid Build Coastguard Worker negative_slope_data = w_as_float.data(); 160*4bdc9457SAndroid Build Coastguard Worker } 161*4bdc9457SAndroid Build Coastguard Worker uint32_t flags = 0; 162*4bdc9457SAndroid Build Coastguard Worker if (weights_type() == WeightsType::FP32) { 163*4bdc9457SAndroid Build Coastguard Worker flags |= XNN_FLAG_FP32_STATIC_WEIGHTS; 164*4bdc9457SAndroid Build Coastguard Worker } 165*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 166*4bdc9457SAndroid Build Coastguard Worker xnn_create_prelu_nc_f16( 167*4bdc9457SAndroid Build Coastguard Worker channels(), x_stride(), y_stride(), 168*4bdc9457SAndroid Build Coastguard Worker negative_slope_data, 169*4bdc9457SAndroid Build Coastguard Worker flags, &caches, &prelu_op)); 170*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, prelu_op); 171*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 172*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 173*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 174*4bdc9457SAndroid Build Coastguard Worker } 175*4bdc9457SAndroid Build Coastguard Worker 176*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete prelu_op. 177*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator); 178*4bdc9457SAndroid Build Coastguard Worker 179*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 180*4bdc9457SAndroid Build Coastguard Worker xnn_setup_prelu_nc_f16( 181*4bdc9457SAndroid Build Coastguard Worker prelu_op, 182*4bdc9457SAndroid Build Coastguard Worker batch_size(), 183*4bdc9457SAndroid Build Coastguard Worker x.data(), y.data(), 184*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 185*4bdc9457SAndroid Build Coastguard Worker 186*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 187*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(prelu_op, nullptr /* thread pool */)); 188*4bdc9457SAndroid Build Coastguard Worker 189*4bdc9457SAndroid Build Coastguard Worker VerifyF16(y, y_ref); 190*4bdc9457SAndroid Build Coastguard Worker 191*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 192*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t prelu_op2 = nullptr; 193*4bdc9457SAndroid Build Coastguard Worker const size_t old_weights_cache_size = weights_cache.cache.weights.size; 194*4bdc9457SAndroid Build Coastguard Worker 195*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 196*4bdc9457SAndroid Build Coastguard Worker xnn_create_prelu_nc_f16( 197*4bdc9457SAndroid Build Coastguard Worker channels(), x_stride(), y_stride(), 198*4bdc9457SAndroid Build Coastguard Worker negative_slope_data, 199*4bdc9457SAndroid Build Coastguard Worker flags, &caches, &prelu_op2)); 200*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, prelu_op2); 201*4bdc9457SAndroid Build Coastguard Worker 202*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete prelu_op2. 203*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op2, xnn_delete_operator); 204*4bdc9457SAndroid Build Coastguard Worker 205*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y2(y.size(), UINT16_C(0x7E00) /* NaN */); 206*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 207*4bdc9457SAndroid Build Coastguard Worker xnn_setup_prelu_nc_f16( 208*4bdc9457SAndroid Build Coastguard Worker prelu_op2, 209*4bdc9457SAndroid Build Coastguard Worker batch_size(), 210*4bdc9457SAndroid Build Coastguard Worker x.data(), y2.data(), 211*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 212*4bdc9457SAndroid Build Coastguard Worker 213*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 214*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(prelu_op2, nullptr /* thread pool */)); 215*4bdc9457SAndroid Build Coastguard Worker 216*4bdc9457SAndroid Build Coastguard Worker VerifyF16(y2, y_ref); 217*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 218*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 219*4bdc9457SAndroid Build Coastguard Worker } 220*4bdc9457SAndroid Build Coastguard Worker } 221*4bdc9457SAndroid Build Coastguard Worker } 222*4bdc9457SAndroid Build Coastguard Worker VerifyF16(const std::vector<uint16_t> & y,const std::vector<float> & y_ref)223*4bdc9457SAndroid Build Coastguard Worker void VerifyF16(const std::vector<uint16_t>& y, const std::vector<float>& y_ref) const { 224*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 225*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 226*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 227*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(y[i * y_stride() + c]), 228*4bdc9457SAndroid Build Coastguard Worker y_ref[i * channels() + c], 229*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(y_ref[i * channels() + c]) * 1.0e-3f)) 230*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 231*4bdc9457SAndroid Build Coastguard Worker } 232*4bdc9457SAndroid Build Coastguard Worker } 233*4bdc9457SAndroid Build Coastguard Worker } 234*4bdc9457SAndroid Build Coastguard Worker TestF32()235*4bdc9457SAndroid Build Coastguard Worker void TestF32() const { 236*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 237*4bdc9457SAndroid Build Coastguard Worker 238*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 239*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 240*4bdc9457SAndroid Build Coastguard Worker auto f32irng = std::uniform_real_distribution<float>(-1.0f, 1.0f); 241*4bdc9457SAndroid Build Coastguard Worker auto f32wrng = std::uniform_real_distribution<float>(0.25f, 0.75f); 242*4bdc9457SAndroid Build Coastguard Worker 243*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 244*4bdc9457SAndroid Build Coastguard Worker std::vector<float> w(channels()); 245*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 246*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size() * channels()); 247*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 248*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&] { return f32irng(rng);} ); 249*4bdc9457SAndroid Build Coastguard Worker std::generate(w.begin(), w.end(), [&] { return f32wrng(rng);} ); 250*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 251*4bdc9457SAndroid Build Coastguard Worker 252*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 253*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 254*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 255*4bdc9457SAndroid Build Coastguard Worker y_ref[i * channels() + c] = std::signbit(x[i * x_stride() + c]) ? x[i * x_stride() + c] * w[c] : x[i * x_stride() + c]; 256*4bdc9457SAndroid Build Coastguard Worker } 257*4bdc9457SAndroid Build Coastguard Worker } 258*4bdc9457SAndroid Build Coastguard Worker 259*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy PReLU operator. 260*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 261*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t prelu_op = nullptr; 262*4bdc9457SAndroid Build Coastguard Worker 263*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 264*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 265*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 266*4bdc9457SAndroid Build Coastguard Worker }; 267*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 268*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 269*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 270*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 271*4bdc9457SAndroid Build Coastguard Worker } 272*4bdc9457SAndroid Build Coastguard Worker 273*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 274*4bdc9457SAndroid Build Coastguard Worker xnn_create_prelu_nc_f32( 275*4bdc9457SAndroid Build Coastguard Worker channels(), x_stride(), y_stride(), 276*4bdc9457SAndroid Build Coastguard Worker w.data(), 277*4bdc9457SAndroid Build Coastguard Worker 0, &caches, &prelu_op)); 278*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, prelu_op); 279*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 280*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 281*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 282*4bdc9457SAndroid Build Coastguard Worker } 283*4bdc9457SAndroid Build Coastguard Worker 284*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete prelu_op. 285*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator); 286*4bdc9457SAndroid Build Coastguard Worker 287*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 288*4bdc9457SAndroid Build Coastguard Worker xnn_setup_prelu_nc_f32( 289*4bdc9457SAndroid Build Coastguard Worker prelu_op, 290*4bdc9457SAndroid Build Coastguard Worker batch_size(), 291*4bdc9457SAndroid Build Coastguard Worker x.data(), y.data(), 292*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 293*4bdc9457SAndroid Build Coastguard Worker 294*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 295*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(prelu_op, nullptr /* thread pool */)); 296*4bdc9457SAndroid Build Coastguard Worker 297*4bdc9457SAndroid Build Coastguard Worker VerifyF32(y, y_ref); 298*4bdc9457SAndroid Build Coastguard Worker 299*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 300*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t prelu_op2 = nullptr; 301*4bdc9457SAndroid Build Coastguard Worker const size_t old_weights_cache_size = weights_cache.cache.weights.size; 302*4bdc9457SAndroid Build Coastguard Worker 303*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 304*4bdc9457SAndroid Build Coastguard Worker xnn_create_prelu_nc_f32( 305*4bdc9457SAndroid Build Coastguard Worker channels(), x_stride(), y_stride(), 306*4bdc9457SAndroid Build Coastguard Worker w.data(), 307*4bdc9457SAndroid Build Coastguard Worker 0, &caches, &prelu_op2)); 308*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, prelu_op2); 309*4bdc9457SAndroid Build Coastguard Worker 310*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete prelu_op2. 311*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op2, xnn_delete_operator); 312*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y2(y.size(), nanf("")); 313*4bdc9457SAndroid Build Coastguard Worker 314*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 315*4bdc9457SAndroid Build Coastguard Worker xnn_setup_prelu_nc_f32( 316*4bdc9457SAndroid Build Coastguard Worker prelu_op2, 317*4bdc9457SAndroid Build Coastguard Worker batch_size(), 318*4bdc9457SAndroid Build Coastguard Worker x.data(), y2.data(), 319*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 320*4bdc9457SAndroid Build Coastguard Worker 321*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 322*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(prelu_op2, nullptr /* thread pool */)); 323*4bdc9457SAndroid Build Coastguard Worker 324*4bdc9457SAndroid Build Coastguard Worker VerifyF32(y, y_ref); 325*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 326*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 327*4bdc9457SAndroid Build Coastguard Worker } 328*4bdc9457SAndroid Build Coastguard Worker } 329*4bdc9457SAndroid Build Coastguard Worker } 330*4bdc9457SAndroid Build Coastguard Worker VerifyF32(const std::vector<float> & y,const std::vector<float> & y_ref)331*4bdc9457SAndroid Build Coastguard Worker void VerifyF32(const std::vector<float>& y, const std::vector<float>& y_ref) const { 332*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 333*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 334*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 335*4bdc9457SAndroid Build Coastguard Worker y[i * y_stride() + c], 336*4bdc9457SAndroid Build Coastguard Worker y_ref[i * channels() + c], 337*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-6f, std::abs(y_ref[i * channels() + c]) * 1.0e-6f)) 338*4bdc9457SAndroid Build Coastguard Worker << "at position " << i << " / " << batch_size() << ", channel " << c << " / " << channels(); 339*4bdc9457SAndroid Build Coastguard Worker } 340*4bdc9457SAndroid Build Coastguard Worker } 341*4bdc9457SAndroid Build Coastguard Worker } 342*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(const xnn_weights_cache & weights_cache,size_t old_size)343*4bdc9457SAndroid Build Coastguard Worker void VerifyWeightsCache(const xnn_weights_cache& weights_cache, size_t old_size) const { 344*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_cache.cache.hits, 1); 345*4bdc9457SAndroid Build Coastguard Worker // Ensure that we did not write more weights to the cache because it was a cache hit. 346*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(old_size, weights_cache.cache.weights.size); 347*4bdc9457SAndroid Build Coastguard Worker }; 348*4bdc9457SAndroid Build Coastguard Worker 349*4bdc9457SAndroid Build Coastguard Worker private: 350*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 351*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 352*4bdc9457SAndroid Build Coastguard Worker size_t x_stride_{0}; 353*4bdc9457SAndroid Build Coastguard Worker size_t y_stride_{0}; 354*4bdc9457SAndroid Build Coastguard Worker WeightsType weights_type_{WeightsType::Default}; 355*4bdc9457SAndroid Build Coastguard Worker bool use_weights_cache_{false}; 356*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 357*4bdc9457SAndroid Build Coastguard Worker }; 358