// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/cache.h>


class FullyConnectedOperatorTester {
 public:
  enum class WeightsType {
    Default,
    FP32,
  };

  inline FullyConnectedOperatorTester& input_channels(size_t input_channels) {
    assert(input_channels >= 1);
    this->input_channels_ = input_channels;
    return *this;
  }

  inline size_t input_channels() const {
    return this->input_channels_;
  }

  inline FullyConnectedOperatorTester& output_channels(size_t output_channels) {
    assert(output_channels >= 1);
    this->output_channels_ = output_channels;
    return *this;
  }

  inline size_t output_channels() const {
    return this->output_channels_;
  }

  inline FullyConnectedOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size >= 1);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline FullyConnectedOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride >= 1);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return input_channels();
    } else {
      assert(this->input_stride_ >= input_channels());
      return this->input_stride_;
    }
  }

  inline FullyConnectedOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride >= 1);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return output_channels();
    } else {
      assert(this->output_stride_ >= output_channels());
      return this->output_stride_;
    }
  }

  inline FullyConnectedOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline FullyConnectedOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
    this->transpose_weights_ = transpose_weights;
    return *this;
  }

  inline bool transpose_weights() const {
    return this->transpose_weights_;
  }

  inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
    this->has_bias_ = has_bias;
    return *this;
  }

  inline bool has_bias() const {
    return this->has_bias_;
  }

  inline FullyConnectedOperatorTester& weights_type(WeightsType weights_type) {
    this->weights_type_ = weights_type;
    return *this;
  }

  inline WeightsType weights_type() const {
    return this->weights_type_;
  }

  inline FullyConnectedOperatorTester& use_weights_cache(bool use_weights_cache) {
    this->use_weights_cache_ = use_weights_cache;
    return *this;
  }

  inline bool use_weights_cache() const {
    return this->use_weights_cache_;
  }

  inline FullyConnectedOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

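  // Typical usage (illustrative sketch; the sizes and iteration count below
  // are arbitrary, not taken from any particular test):
  //
  //   FullyConnectedOperatorTester()
  //     .batch_size(12)
  //     .input_channels(23)
  //     .output_channels(19)
  //     .iterations(3)
  //     .TestQS8();
  //
  // Each Test* method generates random inputs, computes a reference result,
  // runs the corresponding XNNPACK operator, and compares the two.
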
  void TestQS8() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
    std::uniform_int_distribution<int32_t> i8dist(
      std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    std::uniform_int_distribution<int32_t> w8dist(
      -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());

    std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<int8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<int8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    const int8_t input_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
      std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
      std::fill(output.begin(), output.end(), INT8_C(0xA5));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                int32_t(kernel[ic * output_channels() + oc]);
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                int32_t(kernel[oc * input_channels() + ic]);
            }
          }
        }
      }

      // Compute renormalization parameters.
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const int8_t output_zero_point = int8_t(std::max(std::min(
        lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));

      // Renormalize reference results.
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
        });

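      // A note on the renormalization above: output_scale spreads the
      // accumulator range [accumulated_min, accumulated_max] over the ~255
      // representable 8-bit steps, and output_zero_point (clamped to the int8
      // range) is chosen so that the midpoint of that range maps near zero.
      // output_ref therefore holds the expected result in quantized output
      // units relative to output_zero_point, pre-clamped to the configured
      // [qmin, qmax] window shifted by -0x80 into the signed domain.
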
      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

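      // When the weights cache is enabled, the operator is created twice with
      // identical weights: the first creation below populates the cache, the
      // soft finalization freezes it while keeping it usable, and the second
      // creation further down is expected to hit the cache without growing it,
      // which VerifyWeightsCache checks.
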
      const xnn_status status = xnn_create_fully_connected_nc_qs8(
          input_channels(), output_channels(),
          input_stride(), output_stride(),
          input_zero_point, 1.0f /* input scale */,
          1.0f /* kernel scale */,
          kernel.data(), has_bias() ? bias.data() : nullptr,
          output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
          &caches,
          &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_qs8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      VerifyQS8(output, output_ref, double(output_zero_point));

      if (use_weights_cache()) {
        // Create another operator with the same weights cache.
        xnn_operator_t fully_connected_op2 = nullptr;
        size_t old_weights_cache_size = weights_cache.cache.weights.size;

        ASSERT_EQ(xnn_status_success,
                  xnn_create_fully_connected_nc_qs8(
                      input_channels(), output_channels(), input_stride(),
                      output_stride(), input_zero_point, 1.0f /* input scale */,
                      1.0f /* kernel scale */, kernel.data(),
                      has_bias() ? bias.data() : nullptr, output_zero_point,
                      output_scale, int8_t(qmin() - 0x80),
                      int8_t(qmax() - 0x80),
                      transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
                      &caches, &fully_connected_op2));
        ASSERT_NE(nullptr, fully_connected_op2);

        // Smart pointer to automatically delete fully_connected_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
            auto_fully_connected_op(fully_connected_op2, xnn_delete_operator);
        std::vector<int8_t> output2(output.size(), INT8_C(0xA5));

        ASSERT_EQ(xnn_status_success,
                  xnn_setup_fully_connected_nc_qs8(
                      fully_connected_op2,
                      batch_size(),
                      input.data(), output2.data(),
                      nullptr /* thread pool */));

        ASSERT_EQ(
            xnn_status_success,
            xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));

        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);

        VerifyQS8(output2, output_ref, double(output_zero_point));
      }
    }
  }

  void VerifyQS8(const std::vector<int8_t>& output,
                 const std::vector<double>& output_ref,
                 double output_zero_point) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < output_channels(); c++) {
        ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80))
            << "batch index = " << i << ", channel = " << c;
        ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80))
            << "batch index = " << i << ", channel = " << c;
        ASSERT_NEAR(output_ref[i * output_channels() + c],
                    double(output[i * output_stride() + c]) - output_zero_point,
                    0.9)
            << "batch index = " << i << ", channel = " << c;
      }
    }
  }

  void TestQU8() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
    std::uniform_int_distribution<int32_t> u8dist(
      std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());

    std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<uint8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    const uint8_t input_zero_point = 127;
    const uint8_t kernel_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
      std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
      std::fill(output.begin(), output.end(), UINT8_C(0xA5));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
            }
          }
        }
      }

      // Compute renormalization parameters.
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const uint8_t output_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      // Renormalize reference results.
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
        });

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const xnn_status status = xnn_create_fully_connected_nc_qu8(
          input_channels(), output_channels(),
          input_stride(), output_stride(),
          input_zero_point, 1.0f /* input scale */,
          kernel_zero_point, 1.0f /* kernel scale */,
          kernel.data(), has_bias() ? bias.data() : nullptr,
          output_zero_point, output_scale, qmin(), qmax(),
          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
          &caches,
          &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_qu8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      VerifyQU8(output, output_ref, double(output_zero_point));

      if (use_weights_cache()) {
        // Create another operator with the same weights cache.
        xnn_operator_t fully_connected_op2 = nullptr;
        size_t old_weights_cache_size = weights_cache.cache.weights.size;

        ASSERT_EQ(xnn_status_success,
                  xnn_create_fully_connected_nc_qu8(
                      input_channels(), output_channels(), input_stride(),
                      output_stride(), input_zero_point, 1.0f /* input scale */,
                      kernel_zero_point, 1.0f /* kernel scale */, kernel.data(),
                      has_bias() ? bias.data() : nullptr, output_zero_point,
                      output_scale, qmin(), qmax(),
                      transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
                      &caches, &fully_connected_op2));
        ASSERT_NE(nullptr, fully_connected_op2);

        // Smart pointer to automatically delete fully_connected_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
            auto_fully_connected_op(fully_connected_op2, xnn_delete_operator);
        std::vector<uint8_t> output2(output.size(), UINT8_C(0xA5));

        ASSERT_EQ(xnn_status_success,
                  xnn_setup_fully_connected_nc_qu8(
                      fully_connected_op2, batch_size(), input.data(),
                      output2.data(), nullptr /* thread pool */));

        ASSERT_EQ(
            xnn_status_success,
            xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));

        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);

        VerifyQU8(output2, output_ref, double(output_zero_point));
      }

    }
  }

  void VerifyQU8(const std::vector<uint8_t>& output,
                 const std::vector<double>& output_ref,
                 double output_zero_point) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < output_channels(); c++) {
        ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax()))
            << "batch index = " << i << ", channel = " << c;
        ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin()))
            << "batch index = " << i << ", channel = " << c;
        ASSERT_NEAR(output_ref[i * output_channels() + c],
                    double(output[i * output_stride() + c]) - output_zero_point,
                    0.9)
            << "batch index = " << i << ", channel = " << c;
      }
    }
  }

  void TestF32() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);

    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<float> kernel(output_channels() * input_channels());
    std::vector<float> bias(output_channels());
    std::vector<float> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<float> output_ref(batch_size() * output_channels());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
      std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
      std::fill(output.begin(), output.end(), nanf(""));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            output_ref[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
            }
          }
        }
      }

      // Compute clamping parameters.
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());

      const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
        accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
        accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

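      // Here qmin()/qmax() are reused as fractions of the accumulated range to
      // derive float clamping bounds: qmin() == 0 and qmax() == 255 mean "no
      // clamping" (infinite bound); otherwise the bound cuts qmin()/255
      // (respectively (255 - qmax())/255) of the range off the corresponding end.
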
      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const xnn_status status = xnn_create_fully_connected_nc_f32(
          input_channels(), output_channels(),
          input_stride(), output_stride(),
          kernel.data(), has_bias() ? bias.data() : nullptr,
          output_min, output_max,
          transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
          &caches,
          &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_f32(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      VerifyF32(output, output_ref, output_max, output_min);

      if (use_weights_cache()) {
        // Create another operator with the same weights cache.
        xnn_operator_t fully_connected_op2 = nullptr;
        size_t old_weights_cache_size = weights_cache.cache.weights.size;
        ASSERT_EQ(xnn_status_success,
                  xnn_create_fully_connected_nc_f32(
                      input_channels(), output_channels(), input_stride(),
                      output_stride(), kernel.data(),
                      has_bias() ? bias.data() : nullptr, output_min,
                      output_max,
                      transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
                      &caches, &fully_connected_op2));
        ASSERT_NE(nullptr, fully_connected_op2);

        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op2, xnn_delete_operator);

        std::vector<float> output2(output.size(), nanf(""));
        ASSERT_EQ(xnn_status_success,
                  xnn_setup_fully_connected_nc_f32(
                      fully_connected_op2,
                      batch_size(),
                      input.data(), output2.data(),
                      nullptr /* thread pool */));

        ASSERT_EQ(xnn_status_success,
                  xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));
        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);

        VerifyF32(output2, output_ref, output_max, output_min);
      }
    }
  }

  void VerifyF32(const std::vector<float>& output,
                 const std::vector<float>& output_ref,
                 float output_max,
                 float output_min) const {
    // Verify results.
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < output_channels(); c++) {
        ASSERT_LE(output[i * output_stride() + c], output_max)
            << "batch index = " << i << ", channel = " << c;
        ASSERT_GE(output[i * output_stride() + c], output_min)
            << "batch index = " << i << ", channel = " << c;
        ASSERT_NEAR(output_ref[i * output_channels() + c],
                    output[i * output_stride() + c],
                    1.0e-4 * std::abs(output_ref[i * output_channels() + c]))
            << "batch index = " << i << ", channel = " << c;
      }
    }
  }

  void TestF16() const {
    switch (weights_type()) {
      case WeightsType::Default:
        break;
      case WeightsType::FP32:
        break;
      default:
        GTEST_FAIL() << "unexpected weights type";
    }

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);

    std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<uint16_t> kernel(output_channels() * input_channels());
    std::vector<float> kernel_as_float(kernel.size());
    std::vector<uint16_t> bias(output_channels());
    std::vector<float> bias_as_float(bias.size());
    std::vector<uint16_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<float> output_ref(batch_size() * output_channels());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::transform(kernel.cbegin(), kernel.cend(), kernel_as_float.begin(), fp16_ieee_to_fp32_value);
      std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::transform(bias.cbegin(), bias.cend(), bias_as_float.begin(), fp16_ieee_to_fp32_value);
      std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            output_ref[i * output_channels() + oc] = fp16_ieee_to_fp32_value(bias[oc]);
          }
        }
      } else {
        std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[ic * output_channels() + oc]);
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[oc * input_channels() + ic]);
            }
          }
        }
      }

      // Compute clamping parameters.
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
      const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
      const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
      const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const void* kernel_data = kernel.data();
      const void* bias_data = bias.data();
      if (weights_type() == WeightsType::FP32) {
        kernel_data = kernel_as_float.data();
        bias_data = bias_as_float.data();
      }
      uint32_t flags = 0;
      if (transpose_weights()) {
        flags |= XNN_FLAG_TRANSPOSE_WEIGHTS;
      }
      if (weights_type() == WeightsType::FP32) {
        flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
      }
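      // With WeightsType::FP32 the f16 operator is handed fp32 kernel and bias
      // data; XNN_FLAG_FP32_STATIC_WEIGHTS asks XNNPACK to convert these static
      // weights to fp16 internally when packing them.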
      const xnn_status status = xnn_create_fully_connected_nc_f16(
          input_channels(), output_channels(),
          input_stride(), output_stride(),
          kernel_data, has_bias() ? bias_data : nullptr,
          output_min, output_max,
          flags,
          &caches,
          &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_f16(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      VerifyF16(output, output_ref, output_max, output_min);

      if (use_weights_cache()) {
        xnn_operator_t fully_connected_op2 = nullptr;
        size_t old_weights_cache_size = weights_cache.cache.weights.size;
        ASSERT_EQ(xnn_status_success,
                  xnn_create_fully_connected_nc_f16(
                      input_channels(), output_channels(), input_stride(),
                      output_stride(), kernel_data,
                      has_bias() ? bias_data : nullptr, output_min, output_max,
                      flags, &caches, &fully_connected_op2));
        ASSERT_NE(nullptr, fully_connected_op2);

        // Smart pointer to automatically delete fully_connected_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op2, xnn_delete_operator);
        std::vector<uint16_t> output2(output.size(), UINT16_C(0x7E00) /* NaN */);

        ASSERT_EQ(xnn_status_success,
                  xnn_setup_fully_connected_nc_f16(
                      fully_connected_op2,
                      batch_size(),
                      input.data(), output2.data(),
                      nullptr /* thread pool */));

        ASSERT_EQ(xnn_status_success,
                  xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));

        // Verify results.
        VerifyF16(output2, output_ref, output_max, output_min);
        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);
      }
    }
  }

  void VerifyF16(const std::vector<uint16_t>& output,
                 const std::vector<float>& output_ref,
                 const float output_max,
                 const float output_min) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < output_channels(); c++) {
        ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max)
          << "batch index = " << i << ", channel = " << c;
        ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min)
          << "batch index = " << i << ", channel = " << c;
        ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            fp16_ieee_to_fp32_value(output[i * output_stride() + c]),
            1.0e-2f * std::abs(output_ref[i * output_channels() + c]))
          << "batch index = " << i << ", channel = " << c;
      }
    }
  }

  void VerifyWeightsCache(const xnn_weights_cache& weights_cache, size_t old_size) const {
    ASSERT_EQ(weights_cache.cache.hits, 1);
    // Ensure that we did not write more weights to the cache because it was a cache hit.
    ASSERT_EQ(old_size, weights_cache.cache.weights.size);
  };

 private:
  size_t input_channels_{1};
  size_t input_stride_{0};
  size_t output_channels_{1};
  size_t output_stride_{0};
  size_t batch_size_{1};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  bool transpose_weights_{false};
  bool has_bias_{true};
  WeightsType weights_type_{WeightsType::Default};
  bool use_weights_cache_{false};
  size_t iterations_{1};
};