1 // Copyright 2020 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <gtest/gtest.h> 9 10 #include <algorithm> 11 #include <cassert> 12 #include <cstddef> 13 #include <cstdlib> 14 #include <random> 15 #include <vector> 16 17 #include <fp16.h> 18 19 #include <xnnpack.h> 20 21 22 class SquareRootOperatorTester { 23 public: channels(size_t channels)24 inline SquareRootOperatorTester& channels(size_t channels) { 25 assert(channels != 0); 26 this->channels_ = channels; 27 return *this; 28 } 29 channels()30 inline size_t channels() const { 31 return this->channels_; 32 } 33 input_stride(size_t input_stride)34 inline SquareRootOperatorTester& input_stride(size_t input_stride) { 35 assert(input_stride != 0); 36 this->input_stride_ = input_stride; 37 return *this; 38 } 39 input_stride()40 inline size_t input_stride() const { 41 if (this->input_stride_ == 0) { 42 return this->channels_; 43 } else { 44 assert(this->input_stride_ >= this->channels_); 45 return this->input_stride_; 46 } 47 } 48 output_stride(size_t output_stride)49 inline SquareRootOperatorTester& output_stride(size_t output_stride) { 50 assert(output_stride != 0); 51 this->output_stride_ = output_stride; 52 return *this; 53 } 54 output_stride()55 inline size_t output_stride() const { 56 if (this->output_stride_ == 0) { 57 return this->channels_; 58 } else { 59 assert(this->output_stride_ >= this->channels_); 60 return this->output_stride_; 61 } 62 } 63 batch_size(size_t batch_size)64 inline SquareRootOperatorTester& batch_size(size_t batch_size) { 65 assert(batch_size != 0); 66 this->batch_size_ = batch_size; 67 return *this; 68 } 69 batch_size()70 inline size_t batch_size() const { 71 return this->batch_size_; 72 } 73 iterations(size_t iterations)74 inline SquareRootOperatorTester& iterations(size_t iterations) { 75 this->iterations_ = iterations; 76 return *this; 77 } 78 iterations()79 inline size_t iterations() const { 80 return this->iterations_; 81 } 82 TestF16()83 void TestF16() const { 84 std::random_device random_device; 85 auto rng = std::mt19937(random_device()); 86 std::uniform_real_distribution<float> f32dist(0.1f, 5.0f); 87 88 std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 89 (batch_size() - 1) * input_stride() + channels()); 90 std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels()); 91 std::vector<float> output_ref(batch_size() * channels()); 92 for (size_t iteration = 0; iteration < iterations(); iteration++) { 93 std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 94 std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 95 96 // Compute reference results. 97 for (size_t i = 0; i < batch_size(); i++) { 98 for (size_t c = 0; c < channels(); c++) { 99 output_ref[i * channels() + c] = std::sqrt(fp16_ieee_to_fp32_value(input[i * input_stride() + c])); 100 } 101 } 102 103 // Create, setup, run, and destroy Square operator. 104 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 105 xnn_operator_t sqrt_op = nullptr; 106 107 const xnn_status status = xnn_create_square_root_nc_f16( 108 channels(), input_stride(), output_stride(), 109 0, &sqrt_op); 110 if (status == xnn_status_unsupported_hardware) { 111 GTEST_SKIP(); 112 } 113 ASSERT_EQ(xnn_status_success, status); 114 ASSERT_NE(nullptr, sqrt_op); 115 116 // Smart pointer to automatically delete sqrt_op. 117 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sqrt_op(sqrt_op, xnn_delete_operator); 118 119 ASSERT_EQ(xnn_status_success, 120 xnn_setup_square_root_nc_f16( 121 sqrt_op, 122 batch_size(), 123 input.data(), output.data(), 124 nullptr /* thread pool */)); 125 126 ASSERT_EQ(xnn_status_success, 127 xnn_run_operator(sqrt_op, nullptr /* thread pool */)); 128 129 // Verify results. 130 for (size_t i = 0; i < batch_size(); i++) { 131 for (size_t c = 0; c < channels(); c++) { 132 ASSERT_NEAR( 133 fp16_ieee_to_fp32_value(output[i * output_stride() + c]), 134 output_ref[i * channels() + c], 135 std::abs(output_ref[i * channels() + c]) * 5.0e-3f) 136 << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels() 137 << ", input " << fp16_ieee_to_fp32_value(input[i * input_stride() + c]); 138 } 139 } 140 } 141 } 142 TestF32()143 void TestF32() const { 144 std::random_device random_device; 145 auto rng = std::mt19937(random_device()); 146 std::uniform_real_distribution<float> f32dist(0.0f, 5.0f); 147 148 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 149 (batch_size() - 1) * input_stride() + channels()); 150 std::vector<float> output((batch_size() - 1) * output_stride() + channels()); 151 std::vector<float> output_ref(batch_size() * channels()); 152 for (size_t iteration = 0; iteration < iterations(); iteration++) { 153 std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 154 std::fill(output.begin(), output.end(), std::nanf("")); 155 156 // Compute reference results. 157 for (size_t i = 0; i < batch_size(); i++) { 158 for (size_t c = 0; c < channels(); c++) { 159 output_ref[i * channels() + c] = std::sqrt(input[i * input_stride() + c]); 160 } 161 } 162 163 // Create, setup, run, and destroy Square operator. 164 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 165 xnn_operator_t sqrt_op = nullptr; 166 167 ASSERT_EQ(xnn_status_success, 168 xnn_create_square_root_nc_f32( 169 channels(), input_stride(), output_stride(), 170 0, &sqrt_op)); 171 ASSERT_NE(nullptr, sqrt_op); 172 173 // Smart pointer to automatically delete sqrt_op. 174 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sqrt_op(sqrt_op, xnn_delete_operator); 175 176 ASSERT_EQ(xnn_status_success, 177 xnn_setup_square_root_nc_f32( 178 sqrt_op, 179 batch_size(), 180 input.data(), output.data(), 181 nullptr /* thread pool */)); 182 183 ASSERT_EQ(xnn_status_success, 184 xnn_run_operator(sqrt_op, nullptr /* thread pool */)); 185 186 // Verify results. 187 for (size_t i = 0; i < batch_size(); i++) { 188 for (size_t c = 0; c < channels(); c++) { 189 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c]) 190 << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels() 191 << ", input " << input[i * input_stride() + c]; 192 } 193 } 194 } 195 } 196 197 private: 198 size_t batch_size_{1}; 199 size_t channels_{1}; 200 size_t input_stride_{0}; 201 size_t output_stride_{0}; 202 size_t iterations_{15}; 203 }; 204