xref: /aosp_15_r20/external/XNNPACK/test/gemm-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <cstddef>
12 #include <cstdint>
13 
14 #include <xnnpack/microfnptr.h>
15 #include <xnnpack/requantization.h>
16 
17 
18 class GemmMicrokernelTester {
19  public:
mr(size_t mr)20   inline GemmMicrokernelTester& mr(size_t mr) {
21     this->mr_ = mr;
22     return *this;
23   }
24 
mr()25   inline size_t mr() const {
26     return this->mr_;
27   }
28 
nr(size_t nr)29   inline GemmMicrokernelTester& nr(size_t nr) {
30     this->nr_ = nr;
31     return *this;
32   }
33 
nr()34   inline size_t nr() const {
35     return this->nr_;
36   }
37 
38 
kr(size_t kr)39   inline GemmMicrokernelTester& kr(size_t kr) {
40     this->kr_ = kr;
41     return *this;
42   }
43 
kr()44   inline size_t kr() const {
45     return this->kr_;
46   }
47 
sr(size_t sr)48   inline GemmMicrokernelTester& sr(size_t sr) {
49     this->sr_ = sr;
50     return *this;
51   }
52 
sr()53   inline size_t sr() const {
54     return this->sr_;
55   }
56 
m(size_t m)57   inline GemmMicrokernelTester& m(size_t m) {
58     this->m_ = m;
59     return *this;
60   }
61 
m()62   inline size_t m() const {
63     return this->m_;
64   }
65 
n(size_t n)66   inline GemmMicrokernelTester& n(size_t n) {
67     this->n_ = n;
68     return *this;
69   }
70 
n()71   inline size_t n() const {
72     return this->n_;
73   }
74 
k(size_t k)75   inline GemmMicrokernelTester& k(size_t k) {
76     this->k_ = k;
77     return *this;
78   }
79 
k()80   inline size_t k() const {
81     return this->k_;
82   }
83 
ks(size_t ks)84   inline GemmMicrokernelTester& ks(size_t ks) {
85     this->ks_ = ks;
86     return *this;
87   }
88 
ks()89   inline size_t ks() const {
90     return this->ks_;
91   }
92 
packed_k()93   inline size_t packed_k() const {
94     return round_up_po2(k(), kr() * sr());
95   }
96 
packed_n()97   inline size_t packed_n() const {
98     return round_up(n(), nr());
99   }
100 
a_stride(size_t a_stride)101   inline GemmMicrokernelTester& a_stride(size_t a_stride) {
102     this->a_stride_ = a_stride;
103     return *this;
104   }
105 
a_stride()106   inline size_t a_stride() const {
107     return this->a_stride_ == 0 ? k() : this->a_stride_;
108   }
109 
cm_stride(size_t cm_stride)110   inline GemmMicrokernelTester& cm_stride(size_t cm_stride) {
111     this->cm_stride_ = cm_stride;
112     return *this;
113   }
114 
cm_stride()115   inline size_t cm_stride() const {
116     return this->cm_stride_ == 0 ? cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_;
117   }
118 
cn_stride(size_t cn_stride)119   inline GemmMicrokernelTester& cn_stride(size_t cn_stride) {
120     this->cn_stride_ = cn_stride;
121     return *this;
122   }
123 
cn_stride()124   inline size_t cn_stride() const {
125     return this->cn_stride_ == 0 ? nr() : this->cn_stride_;
126   }
127 
a_zero_point(uint8_t a_zero_point)128   inline GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) {
129     this->a_zero_point_ = a_zero_point;
130     return *this;
131   }
132 
a_zero_point()133   inline uint8_t a_zero_point() const {
134     return this->a_zero_point_;
135   }
136 
b_zero_point(uint8_t b_zero_point)137   inline GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) {
138     this->b_zero_point_ = b_zero_point;
139     return *this;
140   }
141 
b_zero_point()142   inline uint8_t b_zero_point() const {
143     return this->b_zero_point_;
144   }
145 
qmin(uint8_t qmin)146   inline GemmMicrokernelTester& qmin(uint8_t qmin) {
147     this->qmin_ = qmin;
148     return *this;
149   }
150 
qmin()151   inline uint8_t qmin() const {
152     return this->qmin_;
153   }
154 
qmax(uint8_t qmax)155   inline GemmMicrokernelTester& qmax(uint8_t qmax) {
156     this->qmax_ = qmax;
157     return *this;
158   }
159 
qmax()160   inline uint8_t qmax() const {
161     return this->qmax_;
162   }
163 
a_offset(size_t a_offset)164   inline GemmMicrokernelTester& a_offset(size_t a_offset) {
165     this->a_offset_ = a_offset;
166     return *this;
167   }
168 
a_offset()169   inline size_t a_offset() const {
170     return this->a_offset_;
171   }
172 
zero_index(size_t zero_index)173   inline GemmMicrokernelTester& zero_index(size_t zero_index) {
174     this->zero_index_ = zero_index;
175     return *this;
176   }
177 
zero_index()178   inline size_t zero_index() const {
179     return this->zero_index_;
180   }
181 
extended_weights(bool extended_weights)182   inline GemmMicrokernelTester& extended_weights(bool extended_weights) {
183     this->extended_weights_ = extended_weights;
184     return *this;
185   }
186 
extended_weights()187   inline bool extended_weights() const {
188     return this->extended_weights_;
189   }
190 
iterations(size_t iterations)191   inline GemmMicrokernelTester& iterations(size_t iterations) {
192     this->iterations_ = iterations;
193     return *this;
194   }
195 
iterations()196   inline size_t iterations() const {
197     return this->iterations_;
198   }
199 
200   void Test(
201     xnn_qu8_gemm_minmax_ukernel_function gemm,
202     xnn_init_qu8_conv_minmax_params_fn init_params,
203     xnn_qu8_requantize_fn requantize) const;
204 
205   void Test(
206     xnn_qu8_igemm_minmax_ukernel_function igemm,
207     xnn_init_qu8_conv_minmax_params_fn init_params,
208     xnn_qu8_requantize_fn requantize);
209 
210   void Test(
211     xnn_qc8_gemm_minmax_ukernel_function gemm,
212     xnn_init_qc8_conv_minmax_params_fn init_params,
213     xnn_qs8_requantize_fn requantize) const;
214 
215   void Test(
216     xnn_qc8_igemm_minmax_ukernel_function igemm,
217     xnn_init_qc8_conv_minmax_params_fn init_params,
218     xnn_qs8_requantize_fn requantize) const;
219 
220   void Test(
221     xnn_qs8_gemm_minmax_ukernel_function gemm,
222     xnn_init_qs8_conv_minmax_params_fn init_params,
223     xnn_qs8_requantize_fn requantize) const;
224 
225   void Test(
226     xnn_qs8_igemm_minmax_ukernel_function igemm,
227     xnn_init_qs8_conv_minmax_params_fn init_params,
228     xnn_qs8_requantize_fn requantize) const;
229 
230   void Test(xnn_bf16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_bf16_minmax_params_fn init_params) const;
231 
232   void Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_minmax_params_fn init_params) const;
233 
234   void Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f16_minmax_params_fn init_params) const;
235 
236   void Test(xnn_f32_ppmm_minmax_ukernel_function ppmm_minmax, xnn_init_f32_minmax_params_fn init_params) const;
237 
238   void Test(xnn_f32_gemm_ukernel_function gemm) const;
239 
240   void Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const;
241 
242   void Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f32_minmax_params_fn init_params) const;
243 
244   void Test(xnn_f32_gemminc_minmax_ukernel_function gemminc, xnn_init_f32_minmax_params_fn init_params) const;
245 
246   void Test(xnn_f32_igemm_ukernel_function igemm) const;
247 
248   void Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const;
249 
250   void Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f32_minmax_params_fn init_params) const;
251 
252 #if XNN_PLATFORM_JIT
253   void Test(
254     xnn_jit_gemm_code_generator_function gemm_generator,
255     xnn_init_f32_minmax_params_fn init_params) const;
256   void Test(
257     xnn_jit_igemm_code_generator_function igemm_generator,
258     xnn_init_f32_minmax_params_fn init_params) const;
259   void Test(
260     xnn_jit_gemm_code_generator_function gemm_generator,
261     xnn_init_qc8_conv_minmax_params_fn init_params,
262     xnn_qs8_requantize_fn requantize) const;
263   void Test(
264     xnn_jit_igemm_code_generator_function igemm_generator,
265     xnn_init_qc8_conv_minmax_params_fn init_params,
266     xnn_qs8_requantize_fn requantize) const;
267   void Test(
268     xnn_jit_gemm_code_generator_function gemm_generator,
269     xnn_init_qs8_conv_minmax_params_fn init_params,
270     xnn_qs8_requantize_fn requantize) const;
271   void Test(
272     xnn_jit_igemm_code_generator_function igemm_generator,
273     xnn_init_qs8_conv_minmax_params_fn init_params,
274     xnn_qs8_requantize_fn requantize) const;
275 #endif  // XNN_PLATFORM_JIT
276 
277  private:
278   size_t mr_{1};
279   size_t nr_{1};
280   size_t kr_{1};
281   size_t sr_{1};
282   size_t m_{1};
283   size_t n_{1};
284   size_t k_{1};
285   size_t ks_{1};
286   size_t a_stride_{0};
287   size_t cm_stride_{0};
288   size_t cn_stride_{0};
289   uint8_t a_zero_point_{127};
290   uint8_t b_zero_point_{127};
291   uint8_t qmin_{0};
292   uint8_t qmax_{255};
293   size_t a_offset_{0};
294   size_t zero_index_{SIZE_MAX};
295   bool extended_weights_{false};
296   size_t iterations_{15};
297 };
298