1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <cstddef> 12 #include <cstdint> 13 14 #include <xnnpack/microfnptr.h> 15 #include <xnnpack/requantization.h> 16 17 18 class GemmMicrokernelTester { 19 public: mr(size_t mr)20 inline GemmMicrokernelTester& mr(size_t mr) { 21 this->mr_ = mr; 22 return *this; 23 } 24 mr()25 inline size_t mr() const { 26 return this->mr_; 27 } 28 nr(size_t nr)29 inline GemmMicrokernelTester& nr(size_t nr) { 30 this->nr_ = nr; 31 return *this; 32 } 33 nr()34 inline size_t nr() const { 35 return this->nr_; 36 } 37 38 kr(size_t kr)39 inline GemmMicrokernelTester& kr(size_t kr) { 40 this->kr_ = kr; 41 return *this; 42 } 43 kr()44 inline size_t kr() const { 45 return this->kr_; 46 } 47 sr(size_t sr)48 inline GemmMicrokernelTester& sr(size_t sr) { 49 this->sr_ = sr; 50 return *this; 51 } 52 sr()53 inline size_t sr() const { 54 return this->sr_; 55 } 56 m(size_t m)57 inline GemmMicrokernelTester& m(size_t m) { 58 this->m_ = m; 59 return *this; 60 } 61 m()62 inline size_t m() const { 63 return this->m_; 64 } 65 n(size_t n)66 inline GemmMicrokernelTester& n(size_t n) { 67 this->n_ = n; 68 return *this; 69 } 70 n()71 inline size_t n() const { 72 return this->n_; 73 } 74 k(size_t k)75 inline GemmMicrokernelTester& k(size_t k) { 76 this->k_ = k; 77 return *this; 78 } 79 k()80 inline size_t k() const { 81 return this->k_; 82 } 83 ks(size_t ks)84 inline GemmMicrokernelTester& ks(size_t ks) { 85 this->ks_ = ks; 86 return *this; 87 } 88 ks()89 inline size_t ks() const { 90 return this->ks_; 91 } 92 packed_k()93 inline size_t packed_k() const { 94 return round_up_po2(k(), kr() * sr()); 95 } 96 packed_n()97 inline size_t packed_n() const { 98 return round_up(n(), nr()); 99 } 100 a_stride(size_t a_stride)101 inline GemmMicrokernelTester& a_stride(size_t a_stride) { 102 this->a_stride_ = a_stride; 103 return *this; 104 } 105 a_stride()106 inline size_t a_stride() const { 107 return this->a_stride_ == 0 ? k() : this->a_stride_; 108 } 109 cm_stride(size_t cm_stride)110 inline GemmMicrokernelTester& cm_stride(size_t cm_stride) { 111 this->cm_stride_ = cm_stride; 112 return *this; 113 } 114 cm_stride()115 inline size_t cm_stride() const { 116 return this->cm_stride_ == 0 ? cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_; 117 } 118 cn_stride(size_t cn_stride)119 inline GemmMicrokernelTester& cn_stride(size_t cn_stride) { 120 this->cn_stride_ = cn_stride; 121 return *this; 122 } 123 cn_stride()124 inline size_t cn_stride() const { 125 return this->cn_stride_ == 0 ? nr() : this->cn_stride_; 126 } 127 a_zero_point(uint8_t a_zero_point)128 inline GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) { 129 this->a_zero_point_ = a_zero_point; 130 return *this; 131 } 132 a_zero_point()133 inline uint8_t a_zero_point() const { 134 return this->a_zero_point_; 135 } 136 b_zero_point(uint8_t b_zero_point)137 inline GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) { 138 this->b_zero_point_ = b_zero_point; 139 return *this; 140 } 141 b_zero_point()142 inline uint8_t b_zero_point() const { 143 return this->b_zero_point_; 144 } 145 qmin(uint8_t qmin)146 inline GemmMicrokernelTester& qmin(uint8_t qmin) { 147 this->qmin_ = qmin; 148 return *this; 149 } 150 qmin()151 inline uint8_t qmin() const { 152 return this->qmin_; 153 } 154 qmax(uint8_t qmax)155 inline GemmMicrokernelTester& qmax(uint8_t qmax) { 156 this->qmax_ = qmax; 157 return *this; 158 } 159 qmax()160 inline uint8_t qmax() const { 161 return this->qmax_; 162 } 163 a_offset(size_t a_offset)164 inline GemmMicrokernelTester& a_offset(size_t a_offset) { 165 this->a_offset_ = a_offset; 166 return *this; 167 } 168 a_offset()169 inline size_t a_offset() const { 170 return this->a_offset_; 171 } 172 zero_index(size_t zero_index)173 inline GemmMicrokernelTester& zero_index(size_t zero_index) { 174 this->zero_index_ = zero_index; 175 return *this; 176 } 177 zero_index()178 inline size_t zero_index() const { 179 return this->zero_index_; 180 } 181 extended_weights(bool extended_weights)182 inline GemmMicrokernelTester& extended_weights(bool extended_weights) { 183 this->extended_weights_ = extended_weights; 184 return *this; 185 } 186 extended_weights()187 inline bool extended_weights() const { 188 return this->extended_weights_; 189 } 190 iterations(size_t iterations)191 inline GemmMicrokernelTester& iterations(size_t iterations) { 192 this->iterations_ = iterations; 193 return *this; 194 } 195 iterations()196 inline size_t iterations() const { 197 return this->iterations_; 198 } 199 200 void Test( 201 xnn_qu8_gemm_minmax_ukernel_function gemm, 202 xnn_init_qu8_conv_minmax_params_fn init_params, 203 xnn_qu8_requantize_fn requantize) const; 204 205 void Test( 206 xnn_qu8_igemm_minmax_ukernel_function igemm, 207 xnn_init_qu8_conv_minmax_params_fn init_params, 208 xnn_qu8_requantize_fn requantize); 209 210 void Test( 211 xnn_qc8_gemm_minmax_ukernel_function gemm, 212 xnn_init_qc8_conv_minmax_params_fn init_params, 213 xnn_qs8_requantize_fn requantize) const; 214 215 void Test( 216 xnn_qc8_igemm_minmax_ukernel_function igemm, 217 xnn_init_qc8_conv_minmax_params_fn init_params, 218 xnn_qs8_requantize_fn requantize) const; 219 220 void Test( 221 xnn_qs8_gemm_minmax_ukernel_function gemm, 222 xnn_init_qs8_conv_minmax_params_fn init_params, 223 xnn_qs8_requantize_fn requantize) const; 224 225 void Test( 226 xnn_qs8_igemm_minmax_ukernel_function igemm, 227 xnn_init_qs8_conv_minmax_params_fn init_params, 228 xnn_qs8_requantize_fn requantize) const; 229 230 void Test(xnn_bf16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_bf16_minmax_params_fn init_params) const; 231 232 void Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_minmax_params_fn init_params) const; 233 234 void Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f16_minmax_params_fn init_params) const; 235 236 void Test(xnn_f32_ppmm_minmax_ukernel_function ppmm_minmax, xnn_init_f32_minmax_params_fn init_params) const; 237 238 void Test(xnn_f32_gemm_ukernel_function gemm) const; 239 240 void Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const; 241 242 void Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f32_minmax_params_fn init_params) const; 243 244 void Test(xnn_f32_gemminc_minmax_ukernel_function gemminc, xnn_init_f32_minmax_params_fn init_params) const; 245 246 void Test(xnn_f32_igemm_ukernel_function igemm) const; 247 248 void Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const; 249 250 void Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f32_minmax_params_fn init_params) const; 251 252 #if XNN_PLATFORM_JIT 253 void Test( 254 xnn_jit_gemm_code_generator_function gemm_generator, 255 xnn_init_f32_minmax_params_fn init_params) const; 256 void Test( 257 xnn_jit_igemm_code_generator_function igemm_generator, 258 xnn_init_f32_minmax_params_fn init_params) const; 259 void Test( 260 xnn_jit_gemm_code_generator_function gemm_generator, 261 xnn_init_qc8_conv_minmax_params_fn init_params, 262 xnn_qs8_requantize_fn requantize) const; 263 void Test( 264 xnn_jit_igemm_code_generator_function igemm_generator, 265 xnn_init_qc8_conv_minmax_params_fn init_params, 266 xnn_qs8_requantize_fn requantize) const; 267 void Test( 268 xnn_jit_gemm_code_generator_function gemm_generator, 269 xnn_init_qs8_conv_minmax_params_fn init_params, 270 xnn_qs8_requantize_fn requantize) const; 271 void Test( 272 xnn_jit_igemm_code_generator_function igemm_generator, 273 xnn_init_qs8_conv_minmax_params_fn init_params, 274 xnn_qs8_requantize_fn requantize) const; 275 #endif // XNN_PLATFORM_JIT 276 277 private: 278 size_t mr_{1}; 279 size_t nr_{1}; 280 size_t kr_{1}; 281 size_t sr_{1}; 282 size_t m_{1}; 283 size_t n_{1}; 284 size_t k_{1}; 285 size_t ks_{1}; 286 size_t a_stride_{0}; 287 size_t cm_stride_{0}; 288 size_t cn_stride_{0}; 289 uint8_t a_zero_point_{127}; 290 uint8_t b_zero_point_{127}; 291 uint8_t qmin_{0}; 292 uint8_t qmax_{255}; 293 size_t a_offset_{0}; 294 size_t zero_index_{SIZE_MAX}; 295 bool extended_weights_{false}; 296 size_t iterations_{15}; 297 }; 298