1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <stdint.h> 12 #include <stddef.h> 13 14 #include <xnnpack/common.h> 15 #include <xnnpack/operator.h> 16 17 18 #ifdef __cplusplus 19 extern "C" { 20 #endif 21 22 23 struct xnn_qu8_packing_params { 24 uint8_t input_zero_point; 25 uint8_t kernel_zero_point; 26 }; 27 28 struct xnn_qs8_packing_params { 29 int8_t input_zero_point; 30 }; 31 32 33 typedef void (*xnn_pack_gemm_goi_w_function)( 34 size_t g, 35 size_t nc, 36 size_t kc, 37 size_t nr, 38 size_t kr, 39 size_t sr, 40 const void* k, 41 const void* b, 42 void* packed_w, 43 size_t extra_bytes, 44 const void* params); 45 46 XNN_INTERNAL void xnn_pack_f32_gemm_goi_w( 47 size_t g, 48 size_t nc, 49 size_t kc, 50 size_t nr, 51 size_t kr, 52 size_t sr, 53 const float* k, 54 const float* b, 55 float* packed_w, 56 size_t extra_bytes, 57 const void* params); 58 59 XNN_INTERNAL void xnn_pack_f16_gemm_goi_w( 60 size_t g, 61 size_t nc, 62 size_t kc, 63 size_t nr, 64 size_t kr, 65 size_t sr, 66 const uint16_t* k, 67 const uint16_t* b, 68 uint16_t* packed_w, 69 size_t extra_bytes, 70 const void* params); 71 72 XNN_INTERNAL void xnn_pack_f32_to_f16_gemm_goi_w( 73 size_t g, 74 size_t nc, 75 size_t kc, 76 size_t nr, 77 size_t kr, 78 size_t sr, 79 const float* k, 80 const float* b, 81 uint16_t* packed_w, 82 size_t extra_bytes, 83 const void* params); 84 85 XNN_INTERNAL void xnn_pack_qu8_gemm_goi_w( 86 size_t g, 87 size_t nc, 88 size_t kc, 89 size_t nr, 90 size_t kr, 91 size_t sr, 92 const uint8_t* k, 93 const int32_t* b, 94 void* packed_w, 95 size_t extra_bytes, 96 const struct xnn_qu8_packing_params* params); 97 98 XNN_INTERNAL void xnn_pack_qs8_gemm_goi_w( 99 size_t g, 100 size_t nc, 101 size_t kc, 102 size_t nr, 103 size_t kr, 104 size_t sr, 105 const int8_t* k, 106 const int32_t* b, 107 void* packed_w, 108 size_t extra_bytes, 109 const struct xnn_qs8_packing_params* params); 110 111 XNN_INTERNAL void xnn_pack_qs8_gemm_xw_goi_w( 112 size_t g, 113 size_t nc, 114 size_t kc, 115 size_t nr, 116 size_t kr, 117 size_t sr, 118 const int8_t* k, 119 const int32_t* b, 120 void* packed_w, 121 size_t extra_bytes, 122 const struct xnn_qs8_packing_params* params); 123 124 125 typedef void (*xnn_pack_gemm_io_w_function)( 126 size_t nc, 127 size_t kc, 128 size_t nr, 129 size_t kr, 130 size_t sr, 131 const void* k, 132 const void* b, 133 void* packed_w, 134 const void* params); 135 136 XNN_INTERNAL void xnn_pack_f32_gemm_io_w( 137 size_t nc, 138 size_t kc, 139 size_t nr, 140 size_t kr, 141 size_t sr, 142 const float* k, 143 const float* b, 144 float* packed_w, 145 const void* params); 146 147 XNN_INTERNAL void xnn_pack_f16_gemm_io_w( 148 size_t nc, 149 size_t kc, 150 size_t nr, 151 size_t kr, 152 size_t sr, 153 const uint16_t* k, 154 const uint16_t* b, 155 uint16_t* packed_w, 156 const void* params); 157 158 XNN_INTERNAL void xnn_pack_f32_to_f16_gemm_io_w( 159 size_t nc, 160 size_t kc, 161 size_t nr, 162 size_t kr, 163 size_t sr, 164 const float* k, 165 const float* b, 166 uint16_t* packed_w, 167 const void* params); 168 169 XNN_INTERNAL void xnn_pack_qu8_gemm_io_w( 170 size_t nc, 171 size_t kc, 172 size_t nr, 173 size_t kr, 174 size_t sr, 175 const uint8_t* k, 176 const int32_t* b, 177 void* packed_w, 178 const struct xnn_qu8_packing_params* params); 179 180 XNN_INTERNAL void xnn_pack_qs8_gemm_io_w( 181 size_t nc, 182 size_t kc, 183 size_t nr, 184 size_t kr, 185 size_t sr, 186 const int8_t* k, 187 const int32_t* b, 188 void* packed_w, 189 const struct xnn_qs8_packing_params* params); 190 191 192 typedef void (*xnn_pack_conv_goki_w_function)( 193 size_t g, 194 size_t nc, 195 size_t ks, 196 size_t kc, 197 size_t nr, 198 size_t kr, 199 size_t sr, 200 const void* k, 201 const void* b, 202 void* packed_w, 203 size_t extra_bytes, 204 const void* params); 205 206 XNN_INTERNAL void xnn_pack_f32_conv_goki_w( 207 size_t g, 208 size_t nc, 209 size_t ks, 210 size_t kc, 211 size_t nr, 212 size_t kr, 213 size_t sr, 214 const float* k, 215 const float* b, 216 float* packed_w, 217 size_t extra_bytes, 218 const void* params); 219 220 XNN_INTERNAL void xnn_pack_f16_conv_goki_w( 221 size_t g, 222 size_t nc, 223 size_t ks, 224 size_t kc, 225 size_t nr, 226 size_t kr, 227 size_t sr, 228 const uint16_t* k, 229 const uint16_t* b, 230 uint16_t* packed_w, 231 size_t extra_bytes, 232 const void* params); 233 234 XNN_INTERNAL void xnn_pack_f32_to_f16_conv_goki_w( 235 size_t g, 236 size_t nc, 237 size_t ks, 238 size_t kc, 239 size_t nr, 240 size_t kr, 241 size_t sr, 242 const float* k, 243 const float* b, 244 uint16_t* packed_w, 245 size_t extra_bytes, 246 const void* params); 247 248 XNN_INTERNAL void xnn_pack_qu8_conv_goki_w( 249 size_t g, 250 size_t nc, 251 size_t ks, 252 size_t kc, 253 size_t nr, 254 size_t kr, 255 size_t sr, 256 const uint8_t* k, 257 const int32_t* b, 258 void* packed_w, 259 size_t extra_bytes, 260 const struct xnn_qu8_packing_params* params); 261 262 XNN_INTERNAL void xnn_pack_qs8_conv_goki_w( 263 size_t g, 264 size_t nc, 265 size_t ks, 266 size_t kc, 267 size_t nr, 268 size_t kr, 269 size_t sr, 270 const int8_t* k, 271 const int32_t* b, 272 void* packed_w, 273 size_t extra_bytes, 274 const struct xnn_qs8_packing_params* params); 275 276 277 typedef void (*xnn_pack_conv_kgo_w_function)( 278 size_t g, 279 size_t nc, 280 size_t ks, 281 size_t nr, 282 size_t kr, 283 size_t sr, 284 const void* k, 285 const void* b, 286 void* packed_w, 287 size_t extra_bytes, 288 const void* params); 289 290 XNN_INTERNAL void xnn_pack_f32_conv_kgo_w( 291 size_t g, 292 size_t nc, 293 size_t ks, 294 size_t nr, 295 size_t kr, 296 size_t sr, 297 const float* k, 298 const float* b, 299 float* packed_w, 300 size_t extra_bytes, 301 const void* params); 302 303 XNN_INTERNAL void xnn_pack_f16_conv_kgo_w( 304 size_t g, 305 size_t nc, 306 size_t ks, 307 size_t nr, 308 size_t kr, 309 size_t sr, 310 const uint16_t* k, 311 const uint16_t* b, 312 uint16_t* packed_w, 313 size_t extra_bytes, 314 const void* params); 315 316 XNN_INTERNAL void xnn_pack_f32_to_f16_conv_kgo_w( 317 size_t g, 318 size_t nc, 319 size_t ks, 320 size_t nr, 321 size_t kr, 322 size_t sr, 323 const float* k, 324 const float* b, 325 uint16_t* packed_w, 326 size_t extra_bytes, 327 const void* params); 328 329 XNN_INTERNAL void xnn_pack_qu8_conv_kgo_w( 330 size_t g, 331 size_t nc, 332 size_t ks, 333 size_t nr, 334 size_t kr, 335 size_t sr, 336 const uint8_t* k, 337 const int32_t* b, 338 void* packed_w, 339 size_t extra_bytes, 340 const struct xnn_qu8_packing_params* params); 341 342 XNN_INTERNAL void xnn_pack_qs8_conv_kgo_w( 343 size_t g, 344 size_t nc, 345 size_t ks, 346 size_t nr, 347 size_t kr, 348 size_t sr, 349 const int8_t* k, 350 const int32_t* b, 351 void* packed_w, 352 size_t extra_bytes, 353 const struct xnn_qs8_packing_params* params); 354 355 356 typedef void (*xnn_pack_deconv_goki_w_function)( 357 size_t g, 358 size_t nc, 359 size_t kh, 360 size_t kw, 361 size_t kc, 362 size_t sh, 363 size_t sw, 364 size_t nr, 365 size_t kr, 366 size_t sr, 367 const void* k, 368 const void* b, 369 void* packed_w, 370 struct subconvolution_params* subconv_params, 371 const void* params); 372 373 XNN_INTERNAL void xnn_pack_f32_deconv_goki_w( 374 size_t g, 375 size_t nc, 376 size_t kh, 377 size_t kw, 378 size_t kc, 379 size_t sh, 380 size_t sw, 381 size_t nr, 382 size_t kr, 383 size_t sr, 384 const float* k, 385 const float* b, 386 float* packed_w, 387 struct subconvolution_params* subconv_params, 388 const void* params); 389 390 XNN_INTERNAL void xnn_pack_f16_deconv_goki_w( 391 size_t g, 392 size_t nc, 393 size_t kh, 394 size_t kw, 395 size_t kc, 396 size_t sh, 397 size_t sw, 398 size_t nr, 399 size_t kr, 400 size_t sr, 401 const uint16_t* k, 402 const uint16_t* b, 403 uint16_t* packed_w, 404 struct subconvolution_params* subconv_params, 405 const void* params); 406 407 XNN_INTERNAL void xnn_pack_f32_to_f16_deconv_goki_w( 408 size_t g, 409 size_t nc, 410 size_t kh, 411 size_t kw, 412 size_t kc, 413 size_t sh, 414 size_t sw, 415 size_t nr, 416 size_t kr, 417 size_t sr, 418 const float* k, 419 const float* b, 420 uint16_t* packed_w, 421 struct subconvolution_params* subconv_params, 422 const void* params); 423 424 XNN_INTERNAL void xnn_pack_qs8_deconv_goki_w( 425 size_t g, 426 size_t nc, 427 size_t kh, 428 size_t kw, 429 size_t kc, 430 size_t sh, 431 size_t sw, 432 size_t nr, 433 size_t kr, 434 size_t sr, 435 const int8_t* k, 436 const int32_t* b, 437 void* packed_w, 438 struct subconvolution_params* subconv_params, 439 const struct xnn_qs8_packing_params* params); 440 441 XNN_INTERNAL void xnn_pack_qu8_deconv_goki_w( 442 size_t g, 443 size_t nc, 444 size_t kh, 445 size_t kw, 446 size_t kc, 447 size_t sh, 448 size_t sw, 449 size_t nr, 450 size_t kr, 451 size_t sr, 452 const uint8_t* k, 453 const int32_t* b, 454 void* packed_w, 455 struct subconvolution_params* subconv_params, 456 const struct xnn_qu8_packing_params* params); 457 458 459 // Pack weights and bias such that: 460 // 1. Each block contains `cr` bias and `cr * h * w` weights. 461 // 2. Within each "cr block", `cr` biases are at the beginning of the block. 462 // 3. Weights are written such that the channel values at the same x-y is are adjacent in memory. 463 // 4. The weights are then written column major (WHC layout). 464 // "ghw" in the function name is the layout of the weights, (g)roups, (h)eight, (w)idth. 465 typedef void (*xnn_pack_dwconv_ghw_w_function)( 466 size_t primary_tile, 467 size_t h, 468 size_t w, 469 size_t c, 470 size_t cr, 471 const void* k, 472 const void* b, 473 void* packed_w, 474 size_t extra_bytes, 475 const void* params); 476 477 XNN_INTERNAL void xnn_pack_f32_dwconv_ghw_w( 478 size_t primary_tile, 479 size_t h, 480 size_t w, 481 size_t c, 482 size_t cr, 483 const float* k, 484 const float* b, 485 float* packed_w, 486 size_t extra_bytes, 487 const void* params); 488 489 XNN_INTERNAL void xnn_pack_f16_dwconv_ghw_w( 490 size_t primary_tile, 491 size_t h, 492 size_t w, 493 size_t c, 494 size_t cr, 495 const uint16_t* k, 496 const uint16_t* b, 497 uint16_t* packed_w, 498 size_t extra_bytes, 499 const void* params); 500 501 XNN_INTERNAL void xnn_pack_f32_to_f16_dwconv_ghw_w( 502 size_t primary_tile, 503 size_t h, 504 size_t w, 505 size_t c, 506 size_t cr, 507 const float* k, 508 const float* b, 509 uint16_t* packed_w, 510 size_t extra_bytes, 511 const void* params); 512 513 XNN_INTERNAL void xnn_pack_qu8_dwconv_ghw_w( 514 size_t primary_tile, 515 size_t h, 516 size_t w, 517 size_t c, 518 size_t cr, 519 const uint8_t* k, 520 const int32_t* b, 521 void* packed_w, 522 size_t extra_bytes, 523 const struct xnn_qu8_packing_params* params); 524 525 XNN_INTERNAL void xnn_pack_qs8_dwconv_ghw_w( 526 size_t primary_tile, 527 size_t h, 528 size_t w, 529 size_t c, 530 size_t cr, 531 const int8_t* k, 532 const int32_t* b, 533 void* packed_w, 534 size_t extra_bytes, 535 const struct xnn_qs8_packing_params* params); 536 537 538 typedef void (*xnn_pack_dwconv_hwg_w_function)( 539 size_t primary_tile, 540 size_t h, 541 size_t w, 542 size_t c, 543 size_t cr, 544 const void* k, 545 const void* b, 546 void* packed_w, 547 size_t extra_bytes, 548 const void* params); 549 550 XNN_INTERNAL void xnn_pack_f32_dwconv_hwg_w( 551 size_t primary_tile, 552 size_t h, 553 size_t w, 554 size_t c, 555 size_t cr, 556 const float* k, 557 const float* b, 558 float* packed_w, 559 size_t extra_bytes, 560 const void* params); 561 562 XNN_INTERNAL void xnn_pack_f16_dwconv_hwg_w( 563 size_t primary_tile, 564 size_t h, 565 size_t w, 566 size_t c, 567 size_t cr, 568 const uint16_t* k, 569 const uint16_t* b, 570 uint16_t* packed_w, 571 size_t extra_bytes, 572 const void* params); 573 574 XNN_INTERNAL void xnn_pack_f32_to_f16_dwconv_hwg_w( 575 size_t primary_tile, 576 size_t h, 577 size_t w, 578 size_t c, 579 size_t cr, 580 const float* k, 581 const float* b, 582 uint16_t* packed_w, 583 size_t extra_bytes, 584 const void* params); 585 586 XNN_INTERNAL void xnn_pack_qu8_dwconv_hwg_w( 587 size_t primary_tile, 588 size_t h, 589 size_t w, 590 size_t c, 591 size_t cr, 592 const uint8_t* k, 593 const int32_t* b, 594 void* packed_w, 595 size_t extra_bytes, 596 const struct xnn_qu8_packing_params* params); 597 598 XNN_INTERNAL void xnn_pack_qs8_dwconv_hwg_w( 599 size_t primary_tile, 600 size_t h, 601 size_t w, 602 size_t c, 603 size_t cr, 604 const int8_t* k, 605 const int32_t* b, 606 void* packed_w, 607 size_t extra_bytes, 608 const struct xnn_qs8_packing_params* params); 609 610 611 XNN_INTERNAL void xnn_pack_f32_gemminc_goi_w( 612 size_t g, 613 size_t nc, 614 size_t kc, 615 size_t nr, 616 size_t kr, 617 size_t sr, 618 const float* k, 619 float* packed_w, 620 const void* params); 621 622 XNN_INTERNAL void xnn_pack_f16_gemminc_goi_w( 623 size_t g, 624 size_t nc, 625 size_t kc, 626 size_t nr, 627 size_t kr, 628 size_t sr, 629 const uint16_t* k, 630 uint16_t* packed_w, 631 const void* params); 632 633 634 XNN_INTERNAL void xnn_pack_f32_dconv_oki_w( 635 size_t nc, 636 size_t kc, 637 size_t nr, 638 size_t kh, 639 size_t kw, 640 const float* k, 641 const float* b, 642 float* packed_w, 643 const void* params); 644 645 XNN_INTERNAL void xnn_pack_f16_dconv_oki_w( 646 size_t nc, 647 size_t kc, 648 size_t nr, 649 size_t kh, 650 size_t kw, 651 const uint16_t* k, 652 const uint16_t* b, 653 uint16_t* packed_w, 654 const void* params); 655 656 657 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_ghw_w( 658 size_t kernel_size, 659 size_t groups, 660 const float* kernel, 661 const float* bias, 662 float* packed_weights, 663 const void* params); 664 665 XNN_INTERNAL void xnn_pack_f16_chw_dwconv_ghw_w( 666 size_t kernel_size, 667 size_t groups, 668 const uint16_t* kernel, 669 const uint16_t* bias, 670 uint16_t* packed_weights, 671 const void* params); 672 673 674 XNN_INTERNAL void xnn_pack_f32_chw_dwconv_hwg_w( 675 size_t kernel_size, 676 size_t groups, 677 const float* kernel, 678 const float* bias, 679 float* packed_weights, 680 const void* params); 681 682 683 typedef void (*xnn_pack_vmulcaddc_w_function)( 684 size_t c, 685 size_t cr, 686 const void* s, 687 const void* b, 688 void* packed_w, 689 const void* params); 690 691 XNN_INTERNAL void xnn_pack_f32_vmulcaddc_w( 692 size_t c, 693 size_t cr, 694 const float* s, 695 const float* b, 696 float* packed_w, 697 const void* params); 698 699 XNN_INTERNAL void xnn_pack_f16_vmulcaddc_w( 700 size_t c, 701 size_t cr, 702 const uint16_t* s, 703 const uint16_t* b, 704 uint16_t* packed_w, 705 const void* params); 706 707 XNN_INTERNAL void xnn_pack_f32_to_f16_vmulcaddc_w( 708 size_t c, 709 size_t cr, 710 const float* s, 711 const float* b, 712 uint16_t* packed_w, 713 const void* params); 714 715 716 typedef void (*xnn_pack_prelu_w_function)( 717 size_t c, 718 const void* s, 719 void* packed_w); 720 721 XNN_INTERNAL void xnn_pack_f32_prelu_w( 722 size_t c, 723 const float* s, 724 float* packed_w); 725 726 XNN_INTERNAL void xnn_pack_f16_prelu_w( 727 size_t c, 728 const uint16_t* s, 729 uint16_t* packed_w); 730 731 XNN_INTERNAL void xnn_pack_f32_to_f16_prelu_w( 732 size_t c, 733 const float* s, 734 uint16_t* packed_w); 735 736 737 #ifdef __cplusplus 738 } // extern "C" 739 #endif 740