// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/cache.h>


class FullyConnectedOperatorTester {
 public:
  enum class WeightsType {
    Default,
    FP32,
  };

  inline FullyConnectedOperatorTester& input_channels(size_t input_channels) {
    assert(input_channels >= 1);
    this->input_channels_ = input_channels;
    return *this;
  }

  inline size_t input_channels() const {
    return this->input_channels_;
  }

  inline FullyConnectedOperatorTester& output_channels(size_t output_channels) {
    assert(output_channels >= 1);
    this->output_channels_ = output_channels;
    return *this;
  }

  inline size_t output_channels() const {
    return this->output_channels_;
  }

  inline FullyConnectedOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size >= 1);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline FullyConnectedOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride >= 1);
    this->input_stride_ = input_stride;
    return *this;
  }

  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return input_channels();
    } else {
      assert(this->input_stride_ >= input_channels());
      return this->input_stride_;
    }
  }

  inline FullyConnectedOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride >= 1);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return output_channels();
    } else {
      assert(this->output_stride_ >= output_channels());
      return this->output_stride_;
    }
  }

  inline FullyConnectedOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline FullyConnectedOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
    this->transpose_weights_ = transpose_weights;
    return *this;
  }

  inline bool transpose_weights() const {
    return this->transpose_weights_;
  }

  inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
    this->has_bias_ = has_bias;
    return *this;
  }

  inline bool has_bias() const {
    return this->has_bias_;
  }

  inline FullyConnectedOperatorTester& weights_type(WeightsType weights_type) {
    this->weights_type_ = weights_type;
    return *this;
  }

  inline WeightsType weights_type() const {
    return this->weights_type_;
  }

  inline FullyConnectedOperatorTester& use_weights_cache(bool use_weights_cache) {
    this->use_weights_cache_ = use_weights_cache;
    return *this;
  }

  inline bool use_weights_cache() const {
    return this->use_weights_cache_;
  }

  inline FullyConnectedOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void TestQS8() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
    std::uniform_int_distribution<int32_t> i8dist(
      std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    std::uniform_int_distribution<int32_t> w8dist(
      -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());

    std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<int8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<int8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    const int8_t input_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
      std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
      std::fill(output.begin(), output.end(), INT8_C(0xA5));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                int32_t(kernel[ic * output_channels() + oc]);
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                int32_t(kernel[oc * input_channels() + ic]);
            }
          }
        }
      }

      // Compute renormalization parameters.
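      // Note: the scale spreads the observed accumulator range over the 255 representable
      // quantization steps, and the zero point shifts the midpoint of that range to
      // (approximately) quantized zero within the signed int8 output range.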
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const int8_t output_zero_point = int8_t(std::max(std::min(
        lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));

      // Renormalize reference results.
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
        });

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const xnn_status status = xnn_create_fully_connected_nc_qs8(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        input_zero_point, 1.0f /* input scale */,
        1.0f /* kernel scale */,
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &caches,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_qs8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      VerifyQS8(output, output_ref, double(output_zero_point));

      if (use_weights_cache()) {
        // Create another operator with the same weights cache.
        xnn_operator_t fully_connected_op2 = nullptr;
        size_t old_weights_cache_size = weights_cache.cache.weights.size;

        ASSERT_EQ(xnn_status_success,
                  xnn_create_fully_connected_nc_qs8(
                      input_channels(), output_channels(), input_stride(),
                      output_stride(), input_zero_point, 1.0f /* input scale */,
                      1.0f /* kernel scale */, kernel.data(),
                      has_bias() ? bias.data() : nullptr, output_zero_point,
                      output_scale, int8_t(qmin() - 0x80),
                      int8_t(qmax() - 0x80),
                      transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
                      &caches, &fully_connected_op2));
        ASSERT_NE(nullptr, fully_connected_op2);

        // Smart pointer to automatically delete fully_connected_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
            auto_fully_connected_op2(fully_connected_op2, xnn_delete_operator);
        std::vector<int8_t> output2(output.size(), INT8_C(0xA5));

        ASSERT_EQ(xnn_status_success,
                  xnn_setup_fully_connected_nc_qs8(
                      fully_connected_op2,
                      batch_size(),
                      input.data(), output2.data(),
                      nullptr /* thread pool */));

        ASSERT_EQ(
            xnn_status_success,
            xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));

        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);

        VerifyQS8(output2, output_ref, double(output_zero_point));
      }
    }
  }

  void VerifyQS8(const std::vector<int8_t>& output,
                 const std::vector<double>& output_ref,
                 double output_zero_point) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < output_channels(); c++) {
        ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80))
          << "batch index = " << i << ", channel = " << c;
        ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80))
          << "batch index = " << i << ", channel = " << c;
        ASSERT_NEAR(output_ref[i * output_channels() + c],
                    double(output[i * output_stride() + c]) - output_zero_point,
                    0.9)
          << "batch index = " << i << ", channel = " << c;
      }
    }
  }

  void TestQU8() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
    std::uniform_int_distribution<int32_t> u8dist(
      std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());

    std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<uint8_t> kernel(output_channels() * input_channels());
    std::vector<int32_t> bias(output_channels());
    std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<int32_t> accumulators(batch_size() * output_channels());
    std::vector<double> output_ref(batch_size() * output_channels());

    const uint8_t input_zero_point = 127;
    const uint8_t kernel_zero_point = 127;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
      std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
      std::fill(output.begin(), output.end(), UINT8_C(0xA5));

      // Compute reference results, without renormalization.
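      // Unlike the QS8 path, both the input and the kernel are unsigned here, so the
      // reference accumulation subtracts a zero point from each operand before multiplying.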
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            accumulators[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(accumulators.begin(), accumulators.end(), 0);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              accumulators[i * output_channels() + oc] +=
                (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
                (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
            }
          }
        }
      }

      // Compute renormalization parameters.
      const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
      const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());

      const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
      const uint8_t output_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      // Renormalize reference results.
      std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
        [this, output_scale, output_zero_point](int32_t x) -> double {
          return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
        });

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const xnn_status status = xnn_create_fully_connected_nc_qu8(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        input_zero_point, 1.0f /* input scale */,
        kernel_zero_point, 1.0f /* kernel scale */,
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_zero_point, output_scale, qmin(), qmax(),
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &caches,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_qu8(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      VerifyQU8(output, output_ref, double(output_zero_point));

      if (use_weights_cache()) {
        // Create another operator with the same weights cache.
        xnn_operator_t fully_connected_op2 = nullptr;
        size_t old_weights_cache_size = weights_cache.cache.weights.size;

        ASSERT_EQ(xnn_status_success,
                  xnn_create_fully_connected_nc_qu8(
                      input_channels(), output_channels(), input_stride(),
                      output_stride(), input_zero_point, 1.0f /* input scale */,
                      kernel_zero_point, 1.0f /* kernel scale */, kernel.data(),
                      has_bias() ? bias.data() : nullptr, output_zero_point,
                      output_scale, qmin(), qmax(),
                      transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
                      &caches, &fully_connected_op2));
        ASSERT_NE(nullptr, fully_connected_op2);

        // Smart pointer to automatically delete fully_connected_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
            auto_fully_connected_op2(fully_connected_op2, xnn_delete_operator);
        std::vector<uint8_t> output2(output.size(), UINT8_C(0xA5));

        ASSERT_EQ(xnn_status_success,
                  xnn_setup_fully_connected_nc_qu8(
                      fully_connected_op2, batch_size(), input.data(),
                      output2.data(), nullptr /* thread pool */));

        ASSERT_EQ(
            xnn_status_success,
            xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));

        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);

        VerifyQU8(output2, output_ref, double(output_zero_point));
      }
    }
  }

  void VerifyQU8(const std::vector<uint8_t>& output,
                 const std::vector<double>& output_ref,
                 double output_zero_point) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < output_channels(); c++) {
        ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax()))
          << "batch index = " << i << ", channel = " << c;
        ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin()))
          << "batch index = " << i << ", channel = " << c;
        ASSERT_NEAR(output_ref[i * output_channels() + c],
                    double(output[i * output_stride() + c]) - output_zero_point,
                    0.9)
          << "batch index = " << i << ", channel = " << c;
      }
    }
  }

  void TestF32() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);

    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<float> kernel(output_channels() * input_channels());
    std::vector<float> bias(output_channels());
    std::vector<float> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<float> output_ref(batch_size() * output_channels());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
      std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
      std::fill(output.begin(), output.end(), nanf(""));

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            output_ref[i * output_channels() + oc] = bias[oc];
          }
        }
      } else {
        std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
            }
          }
        }
      }

      // Compute clamping parameters.
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());

      const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
        accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
        accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const xnn_status status = xnn_create_fully_connected_nc_f32(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        kernel.data(), has_bias() ? bias.data() : nullptr,
        output_min, output_max,
        transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
        &caches,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_f32(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      VerifyF32(output, output_ref, output_max, output_min);

      if (use_weights_cache()) {
        // Create another operator with the same weights cache.
        xnn_operator_t fully_connected_op2 = nullptr;
        size_t old_weights_cache_size = weights_cache.cache.weights.size;
        ASSERT_EQ(xnn_status_success,
                  xnn_create_fully_connected_nc_f32(
                      input_channels(), output_channels(), input_stride(),
                      output_stride(), kernel.data(),
                      has_bias() ? bias.data() : nullptr, output_min,
                      output_max,
                      transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
                      &caches, &fully_connected_op2));
        ASSERT_NE(nullptr, fully_connected_op2);

        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op2(fully_connected_op2, xnn_delete_operator);

        std::vector<float> output2(output.size(), nanf(""));
        ASSERT_EQ(xnn_status_success,
          xnn_setup_fully_connected_nc_f32(
            fully_connected_op2,
            batch_size(),
            input.data(), output2.data(),
            nullptr /* thread pool */));

        ASSERT_EQ(xnn_status_success,
          xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));
        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);

        VerifyF32(output2, output_ref, output_max, output_min);
      }
    }
  }

  void VerifyF32(const std::vector<float>& output,
                 const std::vector<float>& output_ref,
                 float output_max,
                 float output_min) const {
    // Verify results.
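    // The tolerance below is relative (1.0e-4 of the reference magnitude); inputs and
    // weights are drawn from [0.1, 1.0], so reference outputs stay bounded away from zero.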
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < output_channels(); c++) {
        ASSERT_LE(output[i * output_stride() + c], output_max)
          << "batch index = " << i << ", channel = " << c;
        ASSERT_GE(output[i * output_stride() + c], output_min)
          << "batch index = " << i << ", channel = " << c;
        ASSERT_NEAR(output_ref[i * output_channels() + c],
                    output[i * output_stride() + c],
                    1.0e-4 * std::abs(output_ref[i * output_channels() + c]))
          << "batch index = " << i << ", channel = " << c;
      }
    }
  }

  void TestF16() const {
    switch (weights_type()) {
      case WeightsType::Default:
        break;
      case WeightsType::FP32:
        break;
      default:
        GTEST_FAIL() << "unexpected weights type";
    }

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);

    std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
      (batch_size() - 1) * input_stride() + input_channels());
    std::vector<uint16_t> kernel(output_channels() * input_channels());
    std::vector<float> kernel_as_float(kernel.size());
    std::vector<uint16_t> bias(output_channels());
    std::vector<float> bias_as_float(bias.size());
    std::vector<uint16_t> output((batch_size() - 1) * output_stride() + output_channels());
    std::vector<float> output_ref(batch_size() * output_channels());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::transform(kernel.cbegin(), kernel.cend(), kernel_as_float.begin(), fp16_ieee_to_fp32_value);
      std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::transform(bias.cbegin(), bias.cend(), bias_as_float.begin(), fp16_ieee_to_fp32_value);
      std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results, without renormalization.
      if (has_bias()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            output_ref[i * output_channels() + oc] = fp16_ieee_to_fp32_value(bias[oc]);
          }
        }
      } else {
        std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      }
      if (transpose_weights()) {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[ic * output_channels() + oc]);
            }
          }
        }
      } else {
        for (size_t i = 0; i < batch_size(); i++) {
          for (size_t oc = 0; oc < output_channels(); oc++) {
            for (size_t ic = 0; ic < input_channels(); ic++) {
              output_ref[i * output_channels() + oc] +=
                fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[oc * input_channels() + ic]);
            }
          }
        }
      }

      // Compute clamping parameters.
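      // The clamp bounds are rounded through fp16 so that they are representable in the
      // output type; if rounding collapses them to the same value, clamping is effectively
      // disabled by substituting infinities.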
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
      const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
      const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
      const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;

      // Clamp reference results.
      for (float& value : output_ref) {
        value = std::max(std::min(value, output_max), output_min);
      }

      // Create, setup, run, and destroy Fully Connected operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t fully_connected_op = nullptr;

      xnn_caches caches = {
        .code_cache = NULL,
        .weights_cache = NULL,
      };
      xnn_weights_cache weights_cache;
      if (use_weights_cache()) {
        xnn_init_weights_cache(&weights_cache);
        caches.weights_cache = &weights_cache;
      }

      const void* kernel_data = kernel.data();
      const void* bias_data = bias.data();
      if (weights_type() == WeightsType::FP32) {
        kernel_data = kernel_as_float.data();
        bias_data = bias_as_float.data();
      }
      uint32_t flags = 0;
      if (transpose_weights()) {
        flags |= XNN_FLAG_TRANSPOSE_WEIGHTS;
      }
      if (weights_type() == WeightsType::FP32) {
        flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
      }
      const xnn_status status = xnn_create_fully_connected_nc_f16(
        input_channels(), output_channels(),
        input_stride(), output_stride(),
        kernel_data, has_bias() ? bias_data : nullptr,
        output_min, output_max,
        flags,
        &caches,
        &fully_connected_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, fully_connected_op);
      if (use_weights_cache()) {
        ASSERT_EQ(xnn_status_success,
                  xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
      }

      // Smart pointer to automatically delete fully_connected_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_fully_connected_nc_f16(
          fully_connected_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(fully_connected_op, nullptr /* thread pool */));

      // Verify results.
      VerifyF16(output, output_ref, output_max, output_min);

      if (use_weights_cache()) {
        xnn_operator_t fully_connected_op2 = nullptr;
        size_t old_weights_cache_size = weights_cache.cache.weights.size;
        ASSERT_EQ(xnn_status_success,
                  xnn_create_fully_connected_nc_f16(
                      input_channels(), output_channels(), input_stride(),
                      output_stride(), kernel_data,
                      has_bias() ? bias_data : nullptr, output_min, output_max,
                      flags, &caches, &fully_connected_op2));
        if (status == xnn_status_unsupported_hardware) {
          GTEST_SKIP();
        }
        ASSERT_EQ(xnn_status_success, status);
        ASSERT_NE(nullptr, fully_connected_op2);

        // Smart pointer to automatically delete fully_connected_op2.
        std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op2(fully_connected_op2, xnn_delete_operator);
        std::vector<uint16_t> output2(output.size(), UINT16_C(0x7E00) /* NaN */);

        ASSERT_EQ(xnn_status_success,
          xnn_setup_fully_connected_nc_f16(
            fully_connected_op2,
            batch_size(),
            input.data(), output2.data(),
            nullptr /* thread pool */));

        ASSERT_EQ(xnn_status_success,
          xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));

        // Verify results.
        VerifyF16(output2, output_ref, output_max, output_min);
        VerifyWeightsCache(weights_cache, old_weights_cache_size);
        xnn_release_weights_cache(&weights_cache);
      }
    }
  }

  void VerifyF16(const std::vector<uint16_t>& output,
                 const std::vector<float>& output_ref,
                 const float output_max,
                 const float output_min) const {
    for (size_t i = 0; i < batch_size(); i++) {
      for (size_t c = 0; c < output_channels(); c++) {
        ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max)
          << "batch index = " << i << ", channel = " << c;
        ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min)
          << "batch index = " << i << ", channel = " << c;
        ASSERT_NEAR(
            output_ref[i * output_channels() + c],
            fp16_ieee_to_fp32_value(output[i * output_stride() + c]),
            1.0e-2f * std::abs(output_ref[i * output_channels() + c]))
          << "batch index = " << i << ", channel = " << c;
      }
    }
  }

  void VerifyWeightsCache(const xnn_weights_cache& weights_cache, size_t old_size) const {
    ASSERT_EQ(weights_cache.cache.hits, 1);
    // Ensure that we did not write more weights to the cache because it was a cache hit.
    ASSERT_EQ(old_size, weights_cache.cache.weights.size);
  }

 private:
  size_t input_channels_{1};
  size_t input_stride_{0};
  size_t output_channels_{1};
  size_t output_stride_{0};
  size_t batch_size_{1};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  bool transpose_weights_{false};
  bool has_bias_{true};
  WeightsType weights_type_{WeightsType::Default};
  bool use_weights_cache_{false};
  size_t iterations_{1};
};
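
// Hypothetical usage sketch (the actual tests live in the fully-connected operator test
// sources and may use different names and shapes): a gtest case configures the tester
// through its fluent setters and then invokes one of the Test* methods.
//
//   TEST(FULLY_CONNECTED_NC_F32, example_with_weights_cache) {
//     FullyConnectedOperatorTester()
//       .batch_size(4)
//       .input_channels(23)
//       .output_channels(19)
//       .use_weights_cache(true)
//       .iterations(3)
//       .TestF32();
//   }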