// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>


class ELUOperatorTester {
 public:
  inline ELUOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline ELUOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride != 0);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->input_stride_ >= this->channels_);
      return this->input_stride_;
    }
  }

  inline ELUOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride != 0);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->output_stride_ >= this->channels_);
      return this->output_stride_;
    }
  }

  inline ELUOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline ELUOperatorTester& alpha(float alpha) {
    assert(alpha > 0.0f);
    assert(alpha < 1.0f);
    this->alpha_ = alpha;
    return *this;
  }

  inline float alpha() const {
    return this->alpha_;
  }

  inline ELUOperatorTester& input_scale(float input_scale) {
    assert(input_scale > 0.0f);
    assert(std::isnormal(input_scale));
    this->input_scale_ = input_scale;
    return *this;
  }

  inline float input_scale() const {
    return this->input_scale_;
  }

  inline ELUOperatorTester& input_zero_point(uint8_t input_zero_point) {
    this->input_zero_point_ = input_zero_point;
    return *this;
  }

  inline uint8_t input_zero_point() const {
    return this->input_zero_point_;
  }

  inline ELUOperatorTester& output_scale(float output_scale) {
    assert(output_scale > 0.0f);
    assert(std::isnormal(output_scale));
    this->output_scale_ = output_scale;
    return *this;
  }

  inline float output_scale() const {
    return this->output_scale_;
  }

  inline ELUOperatorTester& output_zero_point(uint8_t output_zero_point) {
    this->output_zero_point_ = output_zero_point;
    return *this;
  }

  inline uint8_t output_zero_point() const {
    return this->output_zero_point_;
  }

  inline ELUOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline ELUOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline ELUOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void TestF16() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_real_distribution<float> f32dist(-25.0f, 25.0f);

    std::vector<uint16_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results.
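      // ELU reference: f(x) = x for x >= 0 and alpha * (exp(x) - 1) for x < 0,
      // evaluated in single precision on the fp32 value of each fp16 input.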
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          const float x = fp16_ieee_to_fp32_value(input[i * input_stride() + c]);
          output_ref[i * channels() + c] = std::signbit(x) ? std::expm1(x) * alpha() : x;
        }
      }

      // Create, setup, run, and destroy ELU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t elu_op = nullptr;

      const xnn_status status = xnn_create_elu_nc_f16(
          channels(), input_stride(), output_stride(),
          alpha(),
          0, &elu_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, elu_op);

      // Smart pointer to automatically delete elu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_elu_op(elu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_elu_nc_f16(
          elu_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(elu_op, nullptr /* thread pool */));

      // Verify results.
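      // Half-precision outputs are compared against the fp32 reference with a mixed
      // tolerance: at least 1e-4 absolute, or 0.5% relative, whichever is larger.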
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(
              fp16_ieee_to_fp32_value(output[i * output_stride() + c]),
              output_ref[i * channels() + c],
              std::max(1.0e-4f, std::abs(output_ref[i * channels() + c]) * 5.0e-3f));
        }
      }
    }
  }

  void TestF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_real_distribution<float> f32dist(-20.0f, 20.0f);

    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * input_stride() + channels());
    std::vector<float> output((batch_size() - 1) * output_stride() + channels());
    std::vector<double> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
      std::fill(output.begin(), output.end(), std::nanf(""));

      // Compute reference results.
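      // The reference is evaluated in double precision so that reference error stays
      // negligible relative to the single-precision operator output.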
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          const double x = double(input[i * input_stride() + c]);
          output_ref[i * channels() + c] = std::signbit(x) ? std::expm1(x) * alpha() : x;
        }
      }

      // Create, setup, run, and destroy ELU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t elu_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_elu_nc_f32(
          channels(), input_stride(), output_stride(),
          alpha(),
          0, &elu_op));
      ASSERT_NE(nullptr, elu_op);

      // Smart pointer to automatically delete elu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_elu_op(elu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_elu_nc_f32(
          elu_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(elu_op, nullptr /* thread pool */));

      // Verify results.
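      // Single-precision outputs must match the double-precision reference to within
      // a 1e-5 relative tolerance.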
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(output[i * output_stride() + c],
                      output_ref[i * channels() + c],
                      std::abs(output_ref[i * channels() + c]) * 1.0e-5)
            << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels()
            << ", input " << input[i * input_stride() + c] << ", alpha " << alpha();
        }
      }
    }
  }

  void TestQS8() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_int_distribution<int32_t> i8dist(
      std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());

    std::vector<int8_t> input((batch_size() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
    std::vector<int8_t> output((batch_size() - 1) * output_stride() + channels());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
      std::fill(output.begin(), output.end(), INT8_C(0xA5));

      // Compute reference results.
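      // The tester stores quantization parameters in the unsigned uint8 domain; the
      // (x - 0x80) offsets below map them into the signed int8 domain used by the QS8
      // operator. The reference dequantizes each input, applies ELU, rescales by the
      // output scale, clamps to [qmin, qmax], and re-adds the output zero point.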
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          const float x = input_scale() *
            (int32_t(input[i * input_stride() + c]) - int32_t(input_zero_point() - 0x80));
          const float elu_x = std::signbit(x) ? alpha() * std::expm1(x) : x;
          const float scaled_elu_x = elu_x / output_scale();
          float y = scaled_elu_x;
          y = std::min<float>(y, int32_t(qmax() - 0x80) - int32_t(output_zero_point() - 0x80));
          y = std::max<float>(y, int32_t(qmin() - 0x80) - int32_t(output_zero_point() - 0x80));
          output_ref[i * channels() + c] = y + int32_t(output_zero_point() - 0x80);
        }
      }

      // Create, setup, run, and destroy ELU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t elu_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_elu_nc_qs8(
          channels(), input_stride(), output_stride(),
          alpha(),
          int8_t(input_zero_point() - 0x80), input_scale(),
          int8_t(output_zero_point() - 0x80), output_scale(),
          int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
          0, &elu_op));
      ASSERT_NE(nullptr, elu_op);

      // Smart pointer to automatically delete elu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_elu_op(elu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_elu_nc_qs8(
          elu_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(elu_op, nullptr /* thread pool */));

      // Verify results.
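      // Quantized outputs may differ from the real-valued reference by rounding, so
      // allow up to 0.6 of a quantization step of error.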
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.6f);
        }
      }
    }
  }

 private:
  size_t batch_size_{1};
  size_t channels_{1};
  size_t input_stride_{0};
  size_t output_stride_{0};
  float alpha_{0.5f};
  float input_scale_{0.75f};
  uint8_t input_zero_point_{121};
  float output_scale_{0.75f};
  uint8_t output_zero_point_{121};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  size_t iterations_{15};
};
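
// A minimal usage sketch (hypothetical test case, not part of this header; the test
// name and parameter values are illustrative only):
//
//   TEST(ELU_NC_F32, example) {
//     ELUOperatorTester()
//       .batch_size(3)
//       .channels(100)
//       .alpha(0.5f)
//       .TestF32();
//   }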