xref: /aosp_15_r20/external/XNNPACK/test/binary-elementwise-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
8 #include <gtest/gtest.h>
9 
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <initializer_list>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>
20 
21 #include <fp16.h>
22 
23 #include <xnnpack.h>
24 
25 
26 class BinaryElementwiseOperatorTester {
27  public:
  // Binary elementwise operation under test. Unknown is the default and must
  // be replaced via operation_type() before calling any Test* method; the
  // quantized tests (QS8/QU8) support only Add, Multiply, and Subtract.
  enum class OperationType {
    Unknown,
    Add,
    Divide,
    Maximum,
    Minimum,
    Multiply,
    Subtract,
    SquaredDifference,
  };
38 
input1_shape(std::initializer_list<size_t> input1_shape)39   inline BinaryElementwiseOperatorTester& input1_shape(std::initializer_list<size_t> input1_shape) {
40     assert(input1_shape.size() <= XNN_MAX_TENSOR_DIMS);
41     this->input1_shape_ = std::vector<size_t>(input1_shape);
42     return *this;
43   }
44 
input1_shape()45   inline const std::vector<size_t>& input1_shape() const {
46     return this->input1_shape_;
47   }
48 
input1_dim(size_t i)49   inline size_t input1_dim(size_t i) const {
50     return i < num_input1_dims() ? this->input1_shape_[i] : 1;
51   }
52 
num_input1_dims()53   inline size_t num_input1_dims() const {
54     return this->input1_shape_.size();
55   }
56 
num_input1_elements()57   inline size_t num_input1_elements() const {
58     return std::accumulate(
59       this->input1_shape_.begin(), this->input1_shape_.end(), size_t(1), std::multiplies<size_t>());
60   }
61 
input1_zero_point(int16_t input1_zero_point)62   inline BinaryElementwiseOperatorTester& input1_zero_point(int16_t input1_zero_point) {
63     this->input1_zero_point_ = input1_zero_point;
64     return *this;
65   }
66 
input1_zero_point()67   inline int16_t input1_zero_point() const {
68     return this->input1_zero_point_;
69   }
70 
input1_scale(float input1_scale)71   inline BinaryElementwiseOperatorTester& input1_scale(float input1_scale) {
72     assert(std::isfinite(input1_scale));
73     this->input1_scale_ = input1_scale;
74     return *this;
75   }
76 
input1_scale()77   inline float input1_scale() const {
78     return this->input1_scale_;
79   }
80 
input2_shape(std::initializer_list<size_t> input2_shape)81   inline BinaryElementwiseOperatorTester& input2_shape(std::initializer_list<size_t> input2_shape) {
82     assert(input2_shape.size() <= XNN_MAX_TENSOR_DIMS);
83     this->input2_shape_ = std::vector<size_t>(input2_shape);
84     return *this;
85   }
86 
input2_shape()87   inline const std::vector<size_t>& input2_shape() const {
88     return this->input2_shape_;
89   }
90 
input2_dim(size_t i)91   inline size_t input2_dim(size_t i) const {
92     return i < num_input2_dims() ? this->input2_shape_[i] : 1;
93   }
94 
num_input2_dims()95   inline size_t num_input2_dims() const {
96     return this->input2_shape_.size();
97   }
98 
num_input2_elements()99   inline size_t num_input2_elements() const {
100     return std::accumulate(
101       this->input2_shape_.begin(), this->input2_shape_.end(), size_t(1), std::multiplies<size_t>());
102   }
103 
input2_zero_point(int16_t input2_zero_point)104   inline BinaryElementwiseOperatorTester& input2_zero_point(int16_t input2_zero_point) {
105     this->input2_zero_point_ = input2_zero_point;
106     return *this;
107   }
108 
input2_zero_point()109   inline int16_t input2_zero_point() const {
110     return this->input2_zero_point_;
111   }
112 
input2_scale(float input2_scale)113   inline BinaryElementwiseOperatorTester& input2_scale(float input2_scale) {
114     assert(std::isfinite(input2_scale));
115     this->input2_scale_ = input2_scale;
116     return *this;
117   }
118 
input2_scale()119   inline float input2_scale() const {
120     return this->input2_scale_;
121   }
122 
output_zero_point(int16_t output_zero_point)123   inline BinaryElementwiseOperatorTester& output_zero_point(int16_t output_zero_point) {
124     this->output_zero_point_ = output_zero_point;
125     return *this;
126   }
127 
output_zero_point()128   inline int16_t output_zero_point() const {
129     return this->output_zero_point_;
130   }
131 
output_scale(float output_scale)132   inline BinaryElementwiseOperatorTester& output_scale(float output_scale) {
133     assert(std::isfinite(output_scale));
134     this->output_scale_ = output_scale;
135     return *this;
136   }
137 
output_scale()138   inline float output_scale() const {
139     return this->output_scale_;
140   }
141 
qmin(uint8_t qmin)142   inline BinaryElementwiseOperatorTester& qmin(uint8_t qmin) {
143     this->qmin_ = qmin;
144     return *this;
145   }
146 
qmin()147   inline uint8_t qmin() const {
148     return this->qmin_;
149   }
150 
qmax(uint8_t qmax)151   inline BinaryElementwiseOperatorTester& qmax(uint8_t qmax) {
152     this->qmax_ = qmax;
153     return *this;
154   }
155 
qmax()156   inline uint8_t qmax() const {
157     return this->qmax_;
158   }
159 
operation_type(OperationType operation_type)160   inline BinaryElementwiseOperatorTester& operation_type(OperationType operation_type) {
161     this->operation_type_ = operation_type;
162     return *this;
163   }
164 
operation_type()165   inline OperationType operation_type() const {
166     return this->operation_type_;
167   }
168 
iterations(size_t iterations)169   inline BinaryElementwiseOperatorTester& iterations(size_t iterations) {
170     this->iterations_ = iterations;
171     return *this;
172   }
173 
iterations()174   inline size_t iterations() const {
175     return this->iterations_;
176   }
177 
Compute(float a,float b)178   float Compute(float a, float b) const {
179     switch (operation_type()) {
180       case OperationType::Add:
181         return a + b;
182       case OperationType::Divide:
183         return a / b;
184       case OperationType::Maximum:
185         return std::max<float>(a, b);
186       case OperationType::Minimum:
187         return std::min<float>(a, b);
188       case OperationType::Multiply:
189         return a * b;
190       case OperationType::Subtract:
191         return a - b;
192       case OperationType::SquaredDifference:
193         return (a - b) * (a - b);
194       default:
195         return std::nanf("");
196     }
197   }
198 
TestQS8()199   void TestQS8() const {
200     ASSERT_NE(operation_type(), OperationType::Unknown);
201     ASSERT_GE(input1_zero_point(), std::numeric_limits<int8_t>::min());
202     ASSERT_LE(input1_zero_point(), std::numeric_limits<int8_t>::max());
203     ASSERT_GE(input2_zero_point(), std::numeric_limits<int8_t>::min());
204     ASSERT_LE(input2_zero_point(), std::numeric_limits<int8_t>::max());
205     ASSERT_GE(output_zero_point(), std::numeric_limits<int8_t>::min());
206     ASSERT_LE(output_zero_point(), std::numeric_limits<int8_t>::max());
207 
208     std::random_device random_device;
209     auto rng = std::mt19937(random_device());
210     std::uniform_int_distribution<int32_t> i8dist(
211       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
212 
213     // Compute generalized shapes.
214     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
215     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
216     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
217     std::fill(input1_dims.begin(), input1_dims.end(), 1);
218     std::fill(input2_dims.begin(), input2_dims.end(), 1);
219     std::fill(output_dims.begin(), output_dims.end(), 1);
220     std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
221     std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
222     for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
223       if (input1_dims[i] != 1 && input2_dims[i] != 1) {
224         ASSERT_EQ(input1_dims[i], input2_dims[i]);
225       }
226       output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
227     }
228     const size_t num_output_elements =
229       std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
230 
231     // Compute generalized strides.
232     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
233     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
234     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
235     size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
236     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
237       input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
238       input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
239       output_strides[i - 1] = output_stride;
240       input1_stride *= input1_dims[i - 1];
241       input2_stride *= input2_dims[i - 1];
242       output_stride *= output_dims[i - 1];
243     }
244 
245     std::vector<int8_t> input1(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input1_elements());
246     std::vector<int8_t> input2(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input2_elements());
247     std::vector<int8_t> output(num_output_elements);
248     std::vector<float> output_ref(num_output_elements);
249     for (size_t iteration = 0; iteration < iterations(); iteration++) {
250       std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
251       std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
252       std::fill(output.begin(), output.end(), 0xAA);
253 
254       // Compute reference results.
255       for (size_t i = 0; i < output_dims[0]; i++) {
256         for (size_t j = 0; j < output_dims[1]; j++) {
257           for (size_t k = 0; k < output_dims[2]; k++) {
258             for (size_t l = 0; l < output_dims[3]; l++) {
259               for (size_t m = 0; m < output_dims[4]; m++) {
260                 for (size_t n = 0; n < output_dims[5]; n++) {
261                   output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
262                     input1_scale() * (int32_t(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]) - input1_zero_point()),
263                     input2_scale() * (int32_t(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]) - input2_zero_point())) /
264                       output_scale() + float(output_zero_point());
265                 }
266               }
267             }
268           }
269         }
270       }
271 
272       for (float& output_value : output_ref) {
273         output_value = std::min(std::max(output_value, float(int8_t(qmin() - 0x80))), float(int8_t(qmax() - 0x80)));
274       }
275 
276       // Create, setup, run, and destroy a binary elementwise operator.
277       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
278       xnn_operator_t binary_elementwise_op = nullptr;
279       xnn_status status = xnn_status_unsupported_parameter;
280       switch (operation_type()) {
281         case OperationType::Add:
282           status = xnn_create_add_nd_qs8(
283             input1_zero_point(), input1_scale(),
284             input2_zero_point(), input2_scale(),
285             output_zero_point(), output_scale(),
286             int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
287             0, &binary_elementwise_op);
288           break;
289         case OperationType::Multiply:
290           status = xnn_create_multiply_nd_qs8(
291             input1_zero_point(), input1_scale(),
292             input2_zero_point(), input2_scale(),
293             output_zero_point(), output_scale(),
294             int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
295             0, &binary_elementwise_op);
296           break;
297         case OperationType::Subtract:
298           status = xnn_create_subtract_nd_qs8(
299             input1_zero_point(), input1_scale(),
300             input2_zero_point(), input2_scale(),
301             output_zero_point(), output_scale(),
302             int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
303             0, &binary_elementwise_op);
304           break;
305         default:
306           FAIL() << "Unsupported operation type";
307       }
308       if (status == xnn_status_unsupported_hardware) {
309         GTEST_SKIP();
310       }
311       ASSERT_EQ(xnn_status_success, status);
312       ASSERT_NE(nullptr, binary_elementwise_op);
313 
314       // Smart pointer to automatically delete binary_elementwise_op.
315       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
316 
317       switch (operation_type()) {
318         case OperationType::Add:
319           ASSERT_EQ(xnn_status_success,
320             xnn_setup_add_nd_qs8(
321               binary_elementwise_op,
322               num_input1_dims(),
323               input1_shape().data(),
324               num_input2_dims(),
325               input2_shape().data(),
326               input1.data(), input2.data(), output.data(),
327               nullptr /* thread pool */));
328           break;
329         case OperationType::Multiply:
330           ASSERT_EQ(xnn_status_success,
331             xnn_setup_multiply_nd_qs8(
332               binary_elementwise_op,
333               num_input1_dims(),
334               input1_shape().data(),
335               num_input2_dims(),
336               input2_shape().data(),
337               input1.data(), input2.data(), output.data(),
338               nullptr /* thread pool */));
339           break;
340         case OperationType::Subtract:
341           ASSERT_EQ(xnn_status_success,
342             xnn_setup_subtract_nd_qs8(
343               binary_elementwise_op,
344               num_input1_dims(),
345               input1_shape().data(),
346               num_input2_dims(),
347               input2_shape().data(),
348               input1.data(), input2.data(), output.data(),
349               nullptr /* thread pool */));
350           break;
351         default:
352           FAIL() << "Unsupported operation type";
353       }
354 
355       ASSERT_EQ(xnn_status_success,
356         xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
357 
358       // Verify results.
359       for (size_t i = 0; i < output_dims[0]; i++) {
360         for (size_t j = 0; j < output_dims[1]; j++) {
361           for (size_t k = 0; k < output_dims[2]; k++) {
362             for (size_t l = 0; l < output_dims[3]; l++) {
363               for (size_t m = 0; m < output_dims[4]; m++) {
364                 for (size_t n = 0; n < output_dims[5]; n++) {
365                   const size_t index =
366                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
367                   ASSERT_NEAR(float(output[index]), output_ref[index], 0.6f)
368                     << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")"
369                     << ", input1 zero point = " << input1_zero_point() << ", input1 scale = " << input1_scale()
370                     << ", input2 zero point = " << input2_zero_point() << ", input2 scale = " << input2_scale()
371                     << ", output zero point = " << output_zero_point() << ", output scale = " << output_scale();
372                 }
373               }
374             }
375           }
376         }
377       }
378     }
379   }
380 
TestQU8()381   void TestQU8() const {
382     ASSERT_NE(operation_type(), OperationType::Unknown);
383     ASSERT_GE(input1_zero_point(), std::numeric_limits<uint8_t>::min());
384     ASSERT_LE(input1_zero_point(), std::numeric_limits<uint8_t>::max());
385     ASSERT_GE(input2_zero_point(), std::numeric_limits<uint8_t>::min());
386     ASSERT_LE(input2_zero_point(), std::numeric_limits<uint8_t>::max());
387     ASSERT_GE(output_zero_point(), std::numeric_limits<uint8_t>::min());
388     ASSERT_LE(output_zero_point(), std::numeric_limits<uint8_t>::max());
389 
390     std::random_device random_device;
391     auto rng = std::mt19937(random_device());
392     std::uniform_int_distribution<int32_t> u8dist(
393       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
394 
395     // Compute generalized shapes.
396     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
397     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
398     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
399     std::fill(input1_dims.begin(), input1_dims.end(), 1);
400     std::fill(input2_dims.begin(), input2_dims.end(), 1);
401     std::fill(output_dims.begin(), output_dims.end(), 1);
402     std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
403     std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
404     for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
405       if (input1_dims[i] != 1 && input2_dims[i] != 1) {
406         ASSERT_EQ(input1_dims[i], input2_dims[i]);
407       }
408       output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
409     }
410     const size_t num_output_elements =
411       std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
412 
413     // Compute generalized strides.
414     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
415     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
416     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
417     size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
418     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
419       input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
420       input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
421       output_strides[i - 1] = output_stride;
422       input1_stride *= input1_dims[i - 1];
423       input2_stride *= input2_dims[i - 1];
424       output_stride *= output_dims[i - 1];
425     }
426 
427     std::vector<uint8_t> input1(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input1_elements());
428     std::vector<uint8_t> input2(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input2_elements());
429     std::vector<uint8_t> output(num_output_elements);
430     std::vector<float> output_ref(num_output_elements);
431     for (size_t iteration = 0; iteration < iterations(); iteration++) {
432       std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
433       std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
434       std::fill(output.begin(), output.end(), 0xAA);
435 
436       // Compute reference results.
437       for (size_t i = 0; i < output_dims[0]; i++) {
438         for (size_t j = 0; j < output_dims[1]; j++) {
439           for (size_t k = 0; k < output_dims[2]; k++) {
440             for (size_t l = 0; l < output_dims[3]; l++) {
441               for (size_t m = 0; m < output_dims[4]; m++) {
442                 for (size_t n = 0; n < output_dims[5]; n++) {
443                   output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
444                     input1_scale() * (int32_t(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]) - input1_zero_point()),
445                     input2_scale() * (int32_t(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]) - input2_zero_point())) /
446                       output_scale() + float(output_zero_point());
447                 }
448               }
449             }
450           }
451         }
452       }
453 
454       for (float& output_value : output_ref) {
455         output_value = std::min(std::max(output_value, float(int32_t(qmin()))), float(int32_t(qmax())));
456       }
457 
458       // Create, setup, run, and destroy a binary elementwise operator.
459       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
460       xnn_operator_t binary_elementwise_op = nullptr;
461       xnn_status status = xnn_status_unsupported_parameter;
462       switch (operation_type()) {
463         case OperationType::Add:
464           status = xnn_create_add_nd_qu8(
465             input1_zero_point(), input1_scale(),
466             input2_zero_point(), input2_scale(),
467             output_zero_point(), output_scale(),
468             qmin(), qmax(),
469             0, &binary_elementwise_op);
470           break;
471         case OperationType::Multiply:
472           status = xnn_create_multiply_nd_qu8(
473             input1_zero_point(), input1_scale(),
474             input2_zero_point(), input2_scale(),
475             output_zero_point(), output_scale(),
476             qmin(), qmax(),
477             0, &binary_elementwise_op);
478           break;
479         case OperationType::Subtract:
480           status = xnn_create_subtract_nd_qu8(
481             input1_zero_point(), input1_scale(),
482             input2_zero_point(), input2_scale(),
483             output_zero_point(), output_scale(),
484             qmin(), qmax(),
485             0, &binary_elementwise_op);
486           break;
487         default:
488           FAIL() << "Unsupported operation type";
489       }
490       if (status == xnn_status_unsupported_hardware) {
491         GTEST_SKIP();
492       }
493       ASSERT_EQ(xnn_status_success, status);
494       ASSERT_NE(nullptr, binary_elementwise_op);
495 
496       // Smart pointer to automatically delete binary_elementwise_op.
497       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
498 
499       switch (operation_type()) {
500         case OperationType::Add:
501           ASSERT_EQ(xnn_status_success,
502             xnn_setup_add_nd_qu8(
503               binary_elementwise_op,
504               num_input1_dims(),
505               input1_shape().data(),
506               num_input2_dims(),
507               input2_shape().data(),
508               input1.data(), input2.data(), output.data(),
509               nullptr /* thread pool */));
510           break;
511         case OperationType::Multiply:
512           ASSERT_EQ(xnn_status_success,
513             xnn_setup_multiply_nd_qu8(
514               binary_elementwise_op,
515               num_input1_dims(),
516               input1_shape().data(),
517               num_input2_dims(),
518               input2_shape().data(),
519               input1.data(), input2.data(), output.data(),
520               nullptr /* thread pool */));
521           break;
522         case OperationType::Subtract:
523           ASSERT_EQ(xnn_status_success,
524             xnn_setup_subtract_nd_qu8(
525               binary_elementwise_op,
526               num_input1_dims(),
527               input1_shape().data(),
528               num_input2_dims(),
529               input2_shape().data(),
530               input1.data(), input2.data(), output.data(),
531               nullptr /* thread pool */));
532           break;
533         default:
534           FAIL() << "Unsupported operation type";
535       }
536 
537       ASSERT_EQ(xnn_status_success,
538         xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
539 
540       // Verify results.
541       for (size_t i = 0; i < output_dims[0]; i++) {
542         for (size_t j = 0; j < output_dims[1]; j++) {
543           for (size_t k = 0; k < output_dims[2]; k++) {
544             for (size_t l = 0; l < output_dims[3]; l++) {
545               for (size_t m = 0; m < output_dims[4]; m++) {
546                 for (size_t n = 0; n < output_dims[5]; n++) {
547                   const size_t index =
548                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
549                   ASSERT_NEAR(float(int32_t(output[index])), output_ref[index], 0.6f)
550                     << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")"
551                     << ", input1 zero point = " << input1_zero_point() << ", input1 scale = " << input1_scale()
552                     << ", input2 zero point = " << input2_zero_point() << ", input2 scale = " << input2_scale()
553                     << ", output zero point = " << output_zero_point() << ", output scale = " << output_scale();
554                 }
555               }
556             }
557           }
558         }
559       }
560     }
561   }
562 
TestF16()563   void TestF16() const {
564     ASSERT_NE(operation_type(), OperationType::Unknown);
565 
566     std::random_device random_device;
567     auto rng = std::mt19937(random_device());
568     std::uniform_real_distribution<float> f32dist(0.01f, 1.0f);
569 
570     // Compute generalized shapes.
571     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
572     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
573     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
574     std::fill(input1_dims.begin(), input1_dims.end(), 1);
575     std::fill(input2_dims.begin(), input2_dims.end(), 1);
576     std::fill(output_dims.begin(), output_dims.end(), 1);
577     std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
578     std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
579     for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
580       if (input1_dims[i] != 1 && input2_dims[i] != 1) {
581         ASSERT_EQ(input1_dims[i], input2_dims[i]);
582       }
583       output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
584     }
585     const size_t num_output_elements =
586       std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
587 
588     // Compute generalized strides.
589     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
590     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
591     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
592     size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
593     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
594       input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
595       input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
596       output_strides[i - 1] = output_stride;
597       input1_stride *= input1_dims[i - 1];
598       input2_stride *= input2_dims[i - 1];
599       output_stride *= output_dims[i - 1];
600     }
601 
602     std::vector<uint16_t> input1(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input1_elements());
603     std::vector<uint16_t> input2(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input2_elements());
604     std::vector<uint16_t> output(num_output_elements);
605     std::vector<float> output_ref(num_output_elements);
606     for (size_t iteration = 0; iteration < iterations(); iteration++) {
607       std::generate(input1.begin(), input1.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
608       std::generate(input2.begin(), input2.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
609       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
610 
611       // Compute reference results.
612       for (size_t i = 0; i < output_dims[0]; i++) {
613         for (size_t j = 0; j < output_dims[1]; j++) {
614           for (size_t k = 0; k < output_dims[2]; k++) {
615             for (size_t l = 0; l < output_dims[3]; l++) {
616               for (size_t m = 0; m < output_dims[4]; m++) {
617                 for (size_t n = 0; n < output_dims[5]; n++) {
618                   output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
619                     fp16_ieee_to_fp32_value(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]),
620                     fp16_ieee_to_fp32_value(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]));
621                 }
622               }
623             }
624           }
625         }
626       }
627 
628       // Compute clamping parameters.
629       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
630       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
631       const float accumulated_range = accumulated_max - accumulated_min;
632       const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
633       const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
634       const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
635       const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;
636 
637       for (float& output_value : output_ref) {
638         output_value = std::min(std::max(output_value, output_min), output_max);
639       }
640 
641       // Create, setup, run, and destroy a binary elementwise operator.
642       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
643       xnn_operator_t binary_elementwise_op = nullptr;
644       xnn_status status = xnn_status_unsupported_parameter;
645       switch (operation_type()) {
646         case OperationType::Add:
647           status = xnn_create_add_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
648           break;
649         case OperationType::Divide:
650           status = xnn_create_divide_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
651           break;
652         case OperationType::Maximum:
653           status = xnn_create_maximum_nd_f16(0, &binary_elementwise_op);
654           break;
655         case OperationType::Minimum:
656           status = xnn_create_minimum_nd_f16(0, &binary_elementwise_op);
657           break;
658         case OperationType::Multiply:
659           status = xnn_create_multiply_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
660           break;
661         case OperationType::SquaredDifference:
662           status = xnn_create_squared_difference_nd_f16(0, &binary_elementwise_op);
663           break;
664         case OperationType::Subtract:
665           status = xnn_create_subtract_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
666           break;
667         default:
668           FAIL() << "Unsupported operation type";
669       }
670       if (status == xnn_status_unsupported_hardware) {
671         GTEST_SKIP();
672       }
673       ASSERT_EQ(xnn_status_success, status);
674       ASSERT_NE(nullptr, binary_elementwise_op);
675 
676       // Smart pointer to automatically delete binary_elementwise_op.
677       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
678 
679       switch (operation_type()) {
680         case OperationType::Add:
681           ASSERT_EQ(xnn_status_success,
682             xnn_setup_add_nd_f16(
683               binary_elementwise_op,
684               num_input1_dims(),
685               input1_shape().data(),
686               num_input2_dims(),
687               input2_shape().data(),
688               input1.data(), input2.data(), output.data(),
689               nullptr /* thread pool */));
690           break;
691         case OperationType::Divide:
692           ASSERT_EQ(xnn_status_success,
693             xnn_setup_divide_nd_f16(
694               binary_elementwise_op,
695               num_input1_dims(),
696               input1_shape().data(),
697               num_input2_dims(),
698               input2_shape().data(),
699               input1.data(), input2.data(), output.data(),
700               nullptr /* thread pool */));
701           break;
702         case OperationType::Maximum:
703           ASSERT_EQ(xnn_status_success,
704             xnn_setup_maximum_nd_f16(
705               binary_elementwise_op,
706               num_input1_dims(),
707               input1_shape().data(),
708               num_input2_dims(),
709               input2_shape().data(),
710               input1.data(), input2.data(), output.data(),
711               nullptr /* thread pool */));
712           break;
713         case OperationType::Minimum:
714           ASSERT_EQ(xnn_status_success,
715             xnn_setup_minimum_nd_f16(
716               binary_elementwise_op,
717               num_input1_dims(),
718               input1_shape().data(),
719               num_input2_dims(),
720               input2_shape().data(),
721               input1.data(), input2.data(), output.data(),
722               nullptr /* thread pool */));
723           break;
724         case OperationType::Multiply:
725           ASSERT_EQ(xnn_status_success,
726             xnn_setup_multiply_nd_f16(
727               binary_elementwise_op,
728               num_input1_dims(),
729               input1_shape().data(),
730               num_input2_dims(),
731               input2_shape().data(),
732               input1.data(), input2.data(), output.data(),
733               nullptr /* thread pool */));
734           break;
735         case OperationType::SquaredDifference:
736           ASSERT_EQ(xnn_status_success,
737             xnn_setup_squared_difference_nd_f16(
738               binary_elementwise_op,
739               num_input1_dims(),
740               input1_shape().data(),
741               num_input2_dims(),
742               input2_shape().data(),
743               input1.data(), input2.data(), output.data(),
744               nullptr /* thread pool */));
745           break;
746         case OperationType::Subtract:
747           ASSERT_EQ(xnn_status_success,
748             xnn_setup_subtract_nd_f16(
749               binary_elementwise_op,
750               num_input1_dims(),
751               input1_shape().data(),
752               num_input2_dims(),
753               input2_shape().data(),
754               input1.data(), input2.data(), output.data(),
755               nullptr /* thread pool */));
756           break;
757         default:
758           FAIL() << "Unsupported operation type";
759       }
760 
761       ASSERT_EQ(xnn_status_success,
762         xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
763 
764       // Verify results.
765       for (size_t i = 0; i < output_dims[0]; i++) {
766         for (size_t j = 0; j < output_dims[1]; j++) {
767           for (size_t k = 0; k < output_dims[2]; k++) {
768             for (size_t l = 0; l < output_dims[3]; l++) {
769               for (size_t m = 0; m < output_dims[4]; m++) {
770                 for (size_t n = 0; n < output_dims[5]; n++) {
771                   const size_t index =
772                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
773                   ASSERT_NEAR(fp16_ieee_to_fp32_value(output[index]), output_ref[index], std::max(1.0e-4f, std::abs(output_ref[index]) * 1.0e-2f))
774                     << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
775                 }
776               }
777             }
778           }
779         }
780       }
781     }
782   }
783 
TestF32()784   void TestF32() const {
785     ASSERT_NE(operation_type(), OperationType::Unknown);
786 
787     std::random_device random_device;
788     auto rng = std::mt19937(random_device());
789     std::uniform_real_distribution<float> f32dist(0.01f, 1.0f);
790 
791     // Compute generalized shapes.
792     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
793     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
794     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
795     std::fill(input1_dims.begin(), input1_dims.end(), 1);
796     std::fill(input2_dims.begin(), input2_dims.end(), 1);
797     std::fill(output_dims.begin(), output_dims.end(), 1);
798     std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
799     std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
800     for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
801       if (input1_dims[i] != 1 && input2_dims[i] != 1) {
802         ASSERT_EQ(input1_dims[i], input2_dims[i]);
803       }
804       output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
805     }
806     const size_t num_output_elements =
807       std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
808 
809     // Compute generalized strides.
810     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
811     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
812     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
813     size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
814     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
815       input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
816       input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
817       output_strides[i - 1] = output_stride;
818       input1_stride *= input1_dims[i - 1];
819       input2_stride *= input2_dims[i - 1];
820       output_stride *= output_dims[i - 1];
821     }
822 
823     std::vector<float> input1(XNN_EXTRA_BYTES / sizeof(float) + num_input1_elements());
824     std::vector<float> input2(XNN_EXTRA_BYTES / sizeof(float) + num_input2_elements());
825     std::vector<float> output(num_output_elements);
826     std::vector<float> output_ref(num_output_elements);
827     for (size_t iteration = 0; iteration < iterations(); iteration++) {
828       std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
829       std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
830       std::fill(output.begin(), output.end(), nanf(""));
831 
832       // Compute reference results.
833       for (size_t i = 0; i < output_dims[0]; i++) {
834         for (size_t j = 0; j < output_dims[1]; j++) {
835           for (size_t k = 0; k < output_dims[2]; k++) {
836             for (size_t l = 0; l < output_dims[3]; l++) {
837               for (size_t m = 0; m < output_dims[4]; m++) {
838                 for (size_t n = 0; n < output_dims[5]; n++) {
839                   output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
840                     input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]],
841                     input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]);
842                 }
843               }
844             }
845           }
846         }
847       }
848       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
849       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
850       const float accumulated_range = accumulated_max - accumulated_min;
851       const float output_min = num_output_elements == 1 ?
852         -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
853       const float output_max = num_output_elements == 1 ?
854         +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
855       for (float& output_value : output_ref) {
856         output_value = std::min(std::max(output_value, output_min), output_max);
857       }
858 
859       // Create, setup, run, and destroy a binary elementwise operator.
860       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
861       xnn_operator_t binary_elementwise_op = nullptr;
862 
863       switch (operation_type()) {
864         case OperationType::Add:
865           ASSERT_EQ(xnn_status_success,
866             xnn_create_add_nd_f32(
867               output_min, output_max,
868               0, &binary_elementwise_op));
869           break;
870         case OperationType::Divide:
871           ASSERT_EQ(xnn_status_success,
872             xnn_create_divide_nd_f32(
873               output_min, output_max,
874               0, &binary_elementwise_op));
875           break;
876         case OperationType::Maximum:
877           ASSERT_EQ(xnn_status_success,
878             xnn_create_maximum_nd_f32(
879               0, &binary_elementwise_op));
880           break;
881         case OperationType::Minimum:
882           ASSERT_EQ(xnn_status_success,
883             xnn_create_minimum_nd_f32(
884               0, &binary_elementwise_op));
885           break;
886         case OperationType::Multiply:
887           ASSERT_EQ(xnn_status_success,
888             xnn_create_multiply_nd_f32(
889               output_min, output_max,
890               0, &binary_elementwise_op));
891           break;
892         case OperationType::Subtract:
893           ASSERT_EQ(xnn_status_success,
894             xnn_create_subtract_nd_f32(
895               output_min, output_max,
896               0, &binary_elementwise_op));
897           break;
898         case OperationType::SquaredDifference:
899           ASSERT_EQ(xnn_status_success,
900             xnn_create_squared_difference_nd_f32(
901               0, &binary_elementwise_op));
902           break;
903         default:
904           FAIL() << "Unsupported operation type";
905       }
906       ASSERT_NE(nullptr, binary_elementwise_op);
907 
908       // Smart pointer to automatically delete binary_elementwise_op.
909       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
910 
911       switch (operation_type()) {
912         case OperationType::Add:
913           ASSERT_EQ(xnn_status_success,
914             xnn_setup_add_nd_f32(
915               binary_elementwise_op,
916               num_input1_dims(),
917               input1_shape().data(),
918               num_input2_dims(),
919               input2_shape().data(),
920               input1.data(), input2.data(), output.data(),
921               nullptr /* thread pool */));
922           break;
923         case OperationType::Divide:
924           ASSERT_EQ(xnn_status_success,
925             xnn_setup_divide_nd_f32(
926               binary_elementwise_op,
927               num_input1_dims(),
928               input1_shape().data(),
929               num_input2_dims(),
930               input2_shape().data(),
931               input1.data(), input2.data(), output.data(),
932               nullptr /* thread pool */));
933           break;
934         case OperationType::Maximum:
935           ASSERT_EQ(xnn_status_success,
936             xnn_setup_maximum_nd_f32(
937               binary_elementwise_op,
938               num_input1_dims(),
939               input1_shape().data(),
940               num_input2_dims(),
941               input2_shape().data(),
942               input1.data(), input2.data(), output.data(),
943               nullptr /* thread pool */));
944           break;
945         case OperationType::Minimum:
946           ASSERT_EQ(xnn_status_success,
947             xnn_setup_minimum_nd_f32(
948               binary_elementwise_op,
949               num_input1_dims(),
950               input1_shape().data(),
951               num_input2_dims(),
952               input2_shape().data(),
953               input1.data(), input2.data(), output.data(),
954               nullptr /* thread pool */));
955           break;
956         case OperationType::Multiply:
957           ASSERT_EQ(xnn_status_success,
958             xnn_setup_multiply_nd_f32(
959               binary_elementwise_op,
960               num_input1_dims(),
961               input1_shape().data(),
962               num_input2_dims(),
963               input2_shape().data(),
964               input1.data(), input2.data(), output.data(),
965               nullptr /* thread pool */));
966           break;
967         case OperationType::Subtract:
968           ASSERT_EQ(xnn_status_success,
969             xnn_setup_subtract_nd_f32(
970               binary_elementwise_op,
971               num_input1_dims(),
972               input1_shape().data(),
973               num_input2_dims(),
974               input2_shape().data(),
975               input1.data(), input2.data(), output.data(),
976               nullptr /* thread pool */));
977           break;
978         case OperationType::SquaredDifference:
979           ASSERT_EQ(xnn_status_success,
980             xnn_setup_squared_difference_nd_f32(
981               binary_elementwise_op,
982               num_input1_dims(),
983               input1_shape().data(),
984               num_input2_dims(),
985               input2_shape().data(),
986               input1.data(), input2.data(), output.data(),
987               nullptr /* thread pool */));
988           break;
989         default:
990           FAIL() << "Unsupported operation type";
991       }
992 
993       ASSERT_EQ(xnn_status_success,
994         xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
995 
996       // Verify results.
997       for (size_t i = 0; i < output_dims[0]; i++) {
998         for (size_t j = 0; j < output_dims[1]; j++) {
999           for (size_t k = 0; k < output_dims[2]; k++) {
1000             for (size_t l = 0; l < output_dims[3]; l++) {
1001               for (size_t m = 0; m < output_dims[4]; m++) {
1002                 for (size_t n = 0; n < output_dims[5]; n++) {
1003                   const size_t index =
1004                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
1005                   ASSERT_NEAR(output[index], output_ref[index], 1.0e-6f * std::abs(output_ref[index]))
1006                     << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
1007                 }
1008               }
1009             }
1010           }
1011         }
1012       }
1013     }
1014   }
1015 
 private:
  // Logical shapes of the two inputs; padded with leading 1s up to
  // XNN_MAX_TENSOR_DIMS when the generalized broadcast shape is computed.
  std::vector<size_t> input1_shape_;
  std::vector<size_t> input2_shape_;
  // Quantization parameters — not referenced by the F16/F32 paths visible in
  // this chunk; presumably used by quantized test variants elsewhere in the
  // file (TODO confirm against the QS8/QU8 test methods).
  int16_t input1_zero_point_{0};
  float input1_scale_{1.0f};
  int16_t input2_zero_point_{0};
  float input2_scale_{1.0f};
  int16_t output_zero_point_{0};
  float output_scale_{1.0f};
  // qmin_/qmax_ select the operator's output clamping range as a fraction of
  // the accumulated reference output range (255 steps); 0/255 disable clamping.
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  // Which binary operation the tester exercises; must be set before Test*().
  OperationType operation_type_{OperationType::Unknown};
  // Number of test iterations, each with freshly generated random inputs.
  size_t iterations_{3};
};
1030