xref: /aosp_15_r20/external/XNNPACK/test/square-root-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
20 
21 
22 class SquareRootOperatorTester {
23  public:
channels(size_t channels)24   inline SquareRootOperatorTester& channels(size_t channels) {
25     assert(channels != 0);
26     this->channels_ = channels;
27     return *this;
28   }
29 
channels()30   inline size_t channels() const {
31     return this->channels_;
32   }
33 
input_stride(size_t input_stride)34   inline SquareRootOperatorTester& input_stride(size_t input_stride) {
35     assert(input_stride != 0);
36     this->input_stride_ = input_stride;
37     return *this;
38   }
39 
input_stride()40   inline size_t input_stride() const {
41     if (this->input_stride_ == 0) {
42       return this->channels_;
43     } else {
44       assert(this->input_stride_ >= this->channels_);
45       return this->input_stride_;
46     }
47   }
48 
output_stride(size_t output_stride)49   inline SquareRootOperatorTester& output_stride(size_t output_stride) {
50     assert(output_stride != 0);
51     this->output_stride_ = output_stride;
52     return *this;
53   }
54 
output_stride()55   inline size_t output_stride() const {
56     if (this->output_stride_ == 0) {
57       return this->channels_;
58     } else {
59       assert(this->output_stride_ >= this->channels_);
60       return this->output_stride_;
61     }
62   }
63 
batch_size(size_t batch_size)64   inline SquareRootOperatorTester& batch_size(size_t batch_size) {
65     assert(batch_size != 0);
66     this->batch_size_ = batch_size;
67     return *this;
68   }
69 
batch_size()70   inline size_t batch_size() const {
71     return this->batch_size_;
72   }
73 
iterations(size_t iterations)74   inline SquareRootOperatorTester& iterations(size_t iterations) {
75     this->iterations_ = iterations;
76     return *this;
77   }
78 
iterations()79   inline size_t iterations() const {
80     return this->iterations_;
81   }
82 
TestF16()83   void TestF16() const {
84     std::random_device random_device;
85     auto rng = std::mt19937(random_device());
86     std::uniform_real_distribution<float> f32dist(0.1f, 5.0f);
87 
88     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
89       (batch_size() - 1) * input_stride() + channels());
90     std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels());
91     std::vector<float> output_ref(batch_size() * channels());
92     for (size_t iteration = 0; iteration < iterations(); iteration++) {
93       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
94       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
95 
96       // Compute reference results.
97       for (size_t i = 0; i < batch_size(); i++) {
98         for (size_t c = 0; c < channels(); c++) {
99           output_ref[i * channels() + c] = std::sqrt(fp16_ieee_to_fp32_value(input[i * input_stride() + c]));
100         }
101       }
102 
103       // Create, setup, run, and destroy Square operator.
104       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
105       xnn_operator_t sqrt_op = nullptr;
106 
107       const xnn_status status = xnn_create_square_root_nc_f16(
108         channels(), input_stride(), output_stride(),
109         0, &sqrt_op);
110       if (status == xnn_status_unsupported_hardware) {
111         GTEST_SKIP();
112       }
113       ASSERT_EQ(xnn_status_success, status);
114       ASSERT_NE(nullptr, sqrt_op);
115 
116       // Smart pointer to automatically delete sqrt_op.
117       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sqrt_op(sqrt_op, xnn_delete_operator);
118 
119       ASSERT_EQ(xnn_status_success,
120         xnn_setup_square_root_nc_f16(
121           sqrt_op,
122           batch_size(),
123           input.data(), output.data(),
124           nullptr /* thread pool */));
125 
126       ASSERT_EQ(xnn_status_success,
127         xnn_run_operator(sqrt_op, nullptr /* thread pool */));
128 
129       // Verify results.
130       for (size_t i = 0; i < batch_size(); i++) {
131         for (size_t c = 0; c < channels(); c++) {
132           ASSERT_NEAR(
133               fp16_ieee_to_fp32_value(output[i * output_stride() + c]),
134               output_ref[i * channels() + c],
135               std::abs(output_ref[i * channels() + c]) * 5.0e-3f)
136             << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels()
137             << ", input " << fp16_ieee_to_fp32_value(input[i * input_stride() + c]);
138         }
139       }
140     }
141   }
142 
TestF32()143   void TestF32() const {
144     std::random_device random_device;
145     auto rng = std::mt19937(random_device());
146     std::uniform_real_distribution<float> f32dist(0.0f, 5.0f);
147 
148     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
149       (batch_size() - 1) * input_stride() + channels());
150     std::vector<float> output((batch_size() - 1) * output_stride() + channels());
151     std::vector<float> output_ref(batch_size() * channels());
152     for (size_t iteration = 0; iteration < iterations(); iteration++) {
153       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
154       std::fill(output.begin(), output.end(), std::nanf(""));
155 
156       // Compute reference results.
157       for (size_t i = 0; i < batch_size(); i++) {
158         for (size_t c = 0; c < channels(); c++) {
159           output_ref[i * channels() + c] = std::sqrt(input[i * input_stride() + c]);
160         }
161       }
162 
163       // Create, setup, run, and destroy Square operator.
164       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
165       xnn_operator_t sqrt_op = nullptr;
166 
167       ASSERT_EQ(xnn_status_success,
168         xnn_create_square_root_nc_f32(
169           channels(), input_stride(), output_stride(),
170           0, &sqrt_op));
171       ASSERT_NE(nullptr, sqrt_op);
172 
173       // Smart pointer to automatically delete sqrt_op.
174       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_sqrt_op(sqrt_op, xnn_delete_operator);
175 
176       ASSERT_EQ(xnn_status_success,
177         xnn_setup_square_root_nc_f32(
178           sqrt_op,
179           batch_size(),
180           input.data(), output.data(),
181           nullptr /* thread pool */));
182 
183       ASSERT_EQ(xnn_status_success,
184         xnn_run_operator(sqrt_op, nullptr /* thread pool */));
185 
186       // Verify results.
187       for (size_t i = 0; i < batch_size(); i++) {
188         for (size_t c = 0; c < channels(); c++) {
189           ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
190             << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels()
191             << ", input " << input[i * input_stride() + c];
192         }
193       }
194     }
195   }
196 
197  private:
198   size_t batch_size_{1};
199   size_t channels_{1};
200   size_t input_stride_{0};
201   size_t output_stride_{0};
202   size_t iterations_{15};
203 };
204