#include "gemm-microkernel-tester.h"

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/aligned-allocator.h>
#include <xnnpack/pack.h>
#include <xnnpack/microfnptr.h>
#include <xnnpack/microparams-init.h>
#include <xnnpack/requantization.h>

void GemmMicrokernelTester::Test(
  xnn_qu8_gemm_minmax_ukernel_function gemm,
  xnn_init_qu8_conv_minmax_params_fn init_params,
  xnn_qu8_requantize_fn requantize) const
{
  ASSERT_LE(m(), mr());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
  auto u8rng = std::bind(
    std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));

  std::vector<uint8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
  std::vector<uint8_t> b(n() * k());
  std::vector<int32_t> bias(n());
  std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(uint8_t));
  std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<int32_t> acc(m() * n());
  std::vector<uint8_t> c_ref(m() * n());

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    do {
      std::generate(a.begin(), a.end(), std::ref(u8rng));
    } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
    do {
      std::generate(b.begin(), b.end(), std::ref(u8rng));
    } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
    std::generate(bias.begin(), bias.end(), std::ref(i32rng));
    std::fill(c.begin(), c.end(), 0xA5);

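    // Pack the weights and bias into the layout the micro-kernel expects; padding
    // entries are pre-filled with the weights' zero point so they contribute zero.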
    std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
    const xnn_qu8_packing_params packing_params = { a_zero_point(), b_zero_point() };
    xnn_pack_qu8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
      b.data(), bias.data(), packed_w.data(), 0, &packing_params);

    // Compute 32-bit results and output quantization arguments.
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t k_index = 0; k_index < k(); k_index++) {
          acc[m_index * n() + n_index] +=
              (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point())) *
              (int32_t(b[n_index * k() + k_index]) - int32_t(b_zero_point()));
        }
        acc[m_index * n() + n_index] += bias[n_index];
      }
    }

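    // Derive an output scale and zero point that map the accumulated 32-bit range
    // onto the 8-bit output range.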
    const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
    const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
    const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
    const uint8_t c_zero_point = uint8_t(std::max(std::min(
      lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
      long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

    const float requantization_scale = 1.0f / float(c_scale);
    union xnn_qu8_conv_minmax_params quantization_params;
    init_params(&quantization_params,
      b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());

    gemm(
      m(), n(), k(),
      a.data(), a_stride() * sizeof(uint8_t),
      packed_w.data(),
      c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
      &quantization_params);

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = requantize(
          acc[m_index * n() + n_index], requantization_scale, c_zero_point, qmin(), qmax());
      }
    }

    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
        ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
        ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
            << "at " << i << ", " << j << ": reference = " << (uint32_t) c_ref[i * n() + j]
            << " (accumulator = " << acc[i * n() + j]
            << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
            << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
            << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
      }
    }
  }
}

void GemmMicrokernelTester::Test(
  xnn_qu8_igemm_minmax_ukernel_function igemm,
  xnn_init_qu8_conv_minmax_params_fn init_params,
  xnn_qu8_requantize_fn requantize)
{
  ASSERT_LE(m(), mr());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
  auto u8rng = std::bind(
    std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));

  std::vector<uint8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
  std::vector<uint8_t> b(n() * ks() * k());
  std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(uint8_t));
  std::vector<int32_t> bias(n());
  std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<int32_t> acc(m() * n());
  std::vector<uint8_t> c_ref(m() * n());
  std::vector<uint8_t> junk(k() + 8);
  std::vector<const uint8_t*> im2col(mr() * ks());

  std::fill(junk.begin(), junk.end(), 0xA5);

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    do {
      std::generate(a.begin(), a.end(), std::ref(u8rng));
    } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
    do {
      std::generate(b.begin(), b.end(), std::ref(u8rng));
    } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
    std::generate(bias.begin(), bias.end(), std::ref(i32rng));
    std::fill(c.begin(), c.end(), 0xA5);

    std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
    const xnn_qu8_packing_params packing_params = { a_zero_point(), b_zero_point() };
    xnn_pack_qu8_conv_goki_w(
      1, n(), ks(), k(), nr(), kr(), sr(),
      b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);

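    // Build the indirection buffer: each of the ks() x mr() entries points at a row
    // of A, pre-biased by -a_offset() because the kernel adds a_offset() back.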
    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = 0; m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
      }
    }
    std::shuffle(im2col.begin(), im2col.end(), rng);
    if (zero_index() != SIZE_MAX) {
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        im2col[ks_index * mr() + zero_index()] = a.data();
      }
    }
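    // Rows beyond m() point at a junk buffer: the kernel may read them, but they
    // must not influence the outputs checked below.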
    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = m(); m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = junk.data();
      }
    }

    // Compute 32-bit results and output quantization arguments.
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            if (im2col[ks_index * mr() + m_index] == a.data()) {
              acc[m_index * n() + n_index] +=
                (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point())) *
                (int32_t(b[(n_index * ks() + ks_index) * k() + k_index]) - int32_t(b_zero_point()));
            } else {
              acc[m_index * n() + n_index] +=
                (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point())) *
                (int32_t(b[(n_index * ks() + ks_index) * k() + k_index]) - int32_t(b_zero_point()));
            }
          }
        }
        acc[m_index * n() + n_index] += bias[n_index];
      }
    }

    const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
    const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
    const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
    const uint8_t c_zero_point = uint8_t(std::max(std::min(
      lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
      long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

    const float requantization_scale = 1.0f / float(c_scale);
    union xnn_qu8_conv_minmax_params quantization_params;
    init_params(&quantization_params,
      b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());

    const uint8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

    igemm(
      m(), n(), k(), ks() * mr() * sizeof(void*),
      im2col.data(), packed_w.data(),
      c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
      a_offset() * sizeof(uint8_t), zero_pointer,
      &quantization_params);

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = requantize(
          acc[m_index * n() + n_index], requantization_scale, c_zero_point, qmin(), qmax());
      }
    }

    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
        ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
        ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
            << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
            << " (accumulator = " << acc[i * n() + j]
            << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
            << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
            << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
      }
    }
  }
}

void GemmMicrokernelTester::Test(
  xnn_qc8_gemm_minmax_ukernel_function gemm,
  xnn_init_qc8_conv_minmax_params_fn init_params,
  xnn_qs8_requantize_fn requantize) const
{
  ASSERT_LE(m(), mr());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
  auto i8rng = std::bind(
    std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));
  auto w8rng = std::bind(
    std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));

  std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
  std::vector<int8_t> b(n() * k());
  std::vector<int32_t> bias(n());
  std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
  std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t));
  std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<int32_t> acc(m() * n());
  std::vector<float> scale(n());
  std::vector<int8_t> c_ref(m() * n());

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    do {
      std::generate(a.begin(), a.end(), std::ref(i8rng));
    } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
    do {
      std::generate(b.begin(), b.end(), std::ref(w8rng));
    } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
    std::generate(bias.begin(), bias.end(), std::ref(i32rng));
    std::fill(c.begin(), c.end(), 0xA5);

    std::fill(packed_w.begin(), packed_w.end(), 0);
    const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
    if (extended_weights()) {
      xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_xw.data(), nr() * sizeof(float), &packing_params);
    } else {
      xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
    }

    // Compute 32-bit results and output quantization arguments.
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t k_index = 0; k_index < k(); k_index++) {
          acc[m_index * n() + n_index] +=
              (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
              int32_t(b[n_index * k() + k_index]);
        }
        acc[m_index * n() + n_index] += bias[n_index];
      }
    }

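    // QC8 quantization is per-channel: a fixed output zero point plus one
    // requantization scale per output channel, derived from that channel's
    // accumulated range.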
    const int8_t c_zero_point = -1;
    for (size_t n_index = 0; n_index < n(); n_index++) {
      int32_t accumulated_min = acc[n_index];
      int32_t accumulated_max = acc[n_index];
      for (size_t m_index = 0; m_index < m(); m_index++) {
        accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
        accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
      }
      const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
      const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
      scale[n_index] = 1.0f / c_scale;
    }

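    // Write the per-channel scales into the space reserved during packing
    // (the nr() * sizeof(float) extra bytes after each channel group's bias words).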
    if (extended_weights()) {
      xnn_init_qc8_scale_fp32_params(
        n(), nr(),
        nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
        (void*) ((uintptr_t) packed_xw.data() + nr() * (packed_k() * sizeof(int16_t) + sizeof(int32_t))));
    } else {
      xnn_init_qc8_scale_fp32_params(
        n(), nr(),
        nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
        (void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t))));
    }

    union xnn_qc8_conv_minmax_params minmax_params;
    init_params(&minmax_params,
      c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));

    gemm(
      m(), n(), k(),
      a.data(), a_stride() * sizeof(int8_t),
      extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
      c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
      &minmax_params);

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = requantize(
          acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
      }
    }

    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
        ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
        ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
            << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
            << " (accumulator = " << acc[i * n() + j]
            << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
            << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
            << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
      }
    }
  }
}

void GemmMicrokernelTester::Test(
  xnn_qc8_igemm_minmax_ukernel_function igemm,
  xnn_init_qc8_conv_minmax_params_fn init_params,
  xnn_qs8_requantize_fn requantize) const
{
  ASSERT_LE(m(), mr());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
  auto i8rng = std::bind(
    std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));
  auto w8rng = std::bind(
    std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));

  std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
  std::vector<int8_t> b(n() * ks() * k());
  std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
  std::vector<int32_t> bias(n());
  std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<int32_t> acc(m() * n());
  std::vector<float> scale(n());
  std::vector<int8_t> c_ref(m() * n());
  std::vector<int8_t> junk(k() + 8);
  std::vector<const int8_t*> im2col(mr() * ks());

  std::fill(junk.begin(), junk.end(), 0xA5);

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    do {
      std::generate(a.begin(), a.end(), std::ref(i8rng));
    } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
    do {
      std::generate(b.begin(), b.end(), std::ref(w8rng));
    } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
    std::generate(bias.begin(), bias.end(), std::ref(i32rng));
    std::fill(c.begin(), c.end(), 0xA5);

    std::fill(packed_w.begin(), packed_w.end(), 0);
    const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
    xnn_pack_qs8_conv_goki_w(
      1, n(), ks(), k(), nr(), kr(), sr(),
      b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);

    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = 0; m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
      }
    }
    std::shuffle(im2col.begin(), im2col.end(), rng);
    if (zero_index() != SIZE_MAX) {
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        im2col[ks_index * mr() + zero_index()] = a.data();
      }
    }
    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = m(); m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = junk.data();
      }
    }

    // Compute 32-bit results and output quantization arguments.
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            if (im2col[ks_index * mr() + m_index] == a.data()) {
              acc[m_index * n() + n_index] +=
                (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
                int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
            } else {
              acc[m_index * n() + n_index] +=
                (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
                int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
            }
          }
        }
        acc[m_index * n() + n_index] += bias[n_index];
      }
    }

    const int8_t c_zero_point = -1;
    for (size_t n_index = 0; n_index < n(); n_index++) {
      int32_t accumulated_min = acc[n_index];
      int32_t accumulated_max = acc[n_index];
      for (size_t m_index = 0; m_index < m(); m_index++) {
        accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
        accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
      }
      const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
      const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
      scale[n_index] = 1.0f / c_scale;
    }

    xnn_init_qc8_scale_fp32_params(
      n(), nr(),
      nr() * (ks() * packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
      (void*) ((uintptr_t) packed_w.data() + nr() * (ks() * packed_k() * sizeof(int8_t) + sizeof(int32_t))));

    union xnn_qc8_conv_minmax_params minmax_params;
    init_params(&minmax_params,
      c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));

    const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

    igemm(
      m(), n(), k(), ks() * mr() * sizeof(void*),
      im2col.data(), packed_w.data(),
      c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
      a_offset() * sizeof(uint8_t), zero_pointer,
      &minmax_params);

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = requantize(
          acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
      }
    }

    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
        ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
        ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
            << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
            << " (accumulator = " << acc[i * n() + j]
            << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
            << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
            << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
      }
    }
  }
}

void GemmMicrokernelTester::Test(
  xnn_qs8_gemm_minmax_ukernel_function gemm,
  xnn_init_qs8_conv_minmax_params_fn init_params,
  xnn_qs8_requantize_fn requantize) const
{
  ASSERT_LE(m(), mr());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
  auto i8rng = std::bind(
    std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));
  auto w8rng = std::bind(
    std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));

  std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
  std::vector<int8_t> b(n() * k());
  std::vector<int32_t> bias(n());
  std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
  std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t));
  std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<int32_t> acc(m() * n());
  std::vector<int8_t> c_ref(m() * n());

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    do {
      std::generate(a.begin(), a.end(), std::ref(i8rng));
    } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
    do {
      std::generate(b.begin(), b.end(), std::ref(w8rng));
    } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
    std::generate(bias.begin(), bias.end(), std::ref(i32rng));
    std::fill(c.begin(), c.end(), 0xA5);

    std::fill(packed_w.begin(), packed_w.end(), 0);
    const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
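    // The extended-weights variant consumes weights pre-widened to 16 bits, so it
    // packs into the int16 packed_xw buffer instead of the 8-bit packed_w buffer.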
    if (extended_weights()) {
      xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_xw.data(), 0, &packing_params);
    } else {
      xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data(), 0, &packing_params);
    }

    // Compute 32-bit results and output quantization arguments.
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t k_index = 0; k_index < k(); k_index++) {
          acc[m_index * n() + n_index] +=
              (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
              int32_t(b[n_index * k() + k_index]);
        }
        acc[m_index * n() + n_index] += bias[n_index];
      }
    }

    const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
    const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
    const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
    const int8_t c_zero_point = int8_t(std::max(std::min(
      lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
      long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));

    const float requantization_scale = 1.0f / float(c_scale);
    union xnn_qs8_conv_minmax_params quantization_params;
    init_params(&quantization_params,
      requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));

    gemm(
      m(), n(), k(),
      a.data(), a_stride() * sizeof(int8_t),
      extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
      c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
      &quantization_params);

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = requantize(
          acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
      }
    }

    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
        ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
        ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
            << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
            << " (accumulator = " << acc[i * n() + j]
            << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
            << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
            << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
      }
    }
  }
}

void GemmMicrokernelTester::Test(
  xnn_qs8_igemm_minmax_ukernel_function igemm,
  xnn_init_qs8_conv_minmax_params_fn init_params,
  xnn_qs8_requantize_fn requantize) const
{
  ASSERT_LE(m(), mr());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
  auto i8rng = std::bind(
    std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));
  auto w8rng = std::bind(
    std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
    std::ref(rng));

  std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
  std::vector<int8_t> b(n() * ks() * k());
  std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
  std::vector<int32_t> bias(n());
  std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<int32_t> acc(m() * n());
  std::vector<int8_t> c_ref(m() * n());
  std::vector<int8_t> junk(k() + 8);
  std::vector<const int8_t*> im2col(mr() * ks());

  std::fill(junk.begin(), junk.end(), 0xA5);

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    do {
      std::generate(a.begin(), a.end(), std::ref(i8rng));
    } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
    do {
      std::generate(b.begin(), b.end(), std::ref(w8rng));
    } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
    std::generate(bias.begin(), bias.end(), std::ref(i32rng));
    std::fill(c.begin(), c.end(), 0xA5);

    std::fill(packed_w.begin(), packed_w.end(), 0);
    const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
    xnn_pack_qs8_conv_goki_w(
      1, n(), ks(), k(), nr(), kr(), sr(),
      b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);

    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = 0; m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
      }
    }
    std::shuffle(im2col.begin(), im2col.end(), rng);
    if (zero_index() != SIZE_MAX) {
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        im2col[ks_index * mr() + zero_index()] = a.data();
      }
    }
    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = m(); m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = junk.data();
      }
    }

    // Compute 32-bit results and output quantization arguments.
    std::fill(acc.begin(), acc.end(), 0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            if (im2col[ks_index * mr() + m_index] == a.data()) {
              acc[m_index * n() + n_index] +=
                (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
                int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
            } else {
              acc[m_index * n() + n_index] +=
                (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
                int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
            }
          }
        }
        acc[m_index * n() + n_index] += bias[n_index];
      }
    }

    const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
    const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
    const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
    const int8_t c_zero_point = int8_t(std::max(std::min(
      lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
      long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));

    const float requantization_scale = 1.0f / float(c_scale);
    union xnn_qs8_conv_minmax_params quantization_params;
    init_params(&quantization_params,
      requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));

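    // Indirection entries equal to zero_pointer are read as-is; the kernel adds
    // a_offset() to every other pointer (mirroring the reference computation above).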
    const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

    igemm(
      m(), n(), k(), ks() * mr() * sizeof(void*),
      im2col.data(), packed_w.data(),
      c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
      a_offset() * sizeof(uint8_t), zero_pointer,
      &quantization_params);

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = requantize(
          acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
      }
    }

    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
        ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
        ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
            << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
            << " (accumulator = " << acc[i * n() + j]
            << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
            << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
            << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
      }
    }
  }
}

void GemmMicrokernelTester::Test(xnn_bf16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_bf16_minmax_params_fn init_params) const
{
  ASSERT_LE(m(), mr());
  ASSERT_GE(a_stride(), k());
  ASSERT_GE(cm_stride(), n());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(0.5f, 1.0f), std::ref(rng));

  std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
  std::vector<uint16_t> b(n() * k());
  std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + packed_n());
  std::vector<uint16_t> bias(n());
  std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<float> c_ref(m() * n());

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
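    // bfloat16 values are the upper 16 bits of an fp32 bit pattern, hence the
    // >> 16 truncation when generating inputs.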
    std::generate(a.begin(), a.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
    std::generate(b.begin(), b.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
    std::generate(bias.begin(), bias.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
    std::fill(c.begin(), c.end(), UINT16_C(0x7FC0) /* NaN */);
    std::fill(c_ref.begin(), c_ref.end(), 0.0f);

    std::fill(packed_w.begin(), packed_w.end(), 0);
    xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = fp32_from_bits(uint32_t(bias[n_index]) << 16);
        for (size_t k_index = 0; k_index < k(); k_index++) {
          ASSERT_LE(n(), packed_n());
          ASSERT_LT(m_index * n() + n_index, c_ref.size());
          ASSERT_LT(m_index * k() + k_index, a.size());
          c_ref[m_index * n() + n_index] +=
            fp32_from_bits(uint32_t(a[m_index * a_stride() + k_index]) << 16) *
            fp32_from_bits(uint32_t(b[n_index * k() + k_index]) << 16);
        }
      }
    }

    const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
    const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
    const float c_min = fp32_from_bits(fp32_to_bits(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())) & UINT32_C(0xFFFF0000));
    const float c_max = fp32_from_bits(fp32_to_bits(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())) & UINT32_C(0xFFFF0000));

    // Prepare parameters.
    xnn_bf16_minmax_params params;
    init_params(&params,
      fp32_to_bits(c_min) >> 16,
      fp32_to_bits(c_max) >> 16);

    for (float& c_value : c_ref) {
      c_value = std::max(std::min(c_value, c_max), c_min);
    }

    gemm_minmax(m(), n(), k() * sizeof(uint16_t),
      a.data(), a_stride() * sizeof(uint16_t),
      packed_w.data(),
      c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
      &params);

    // Validate micro-kernel outputs.
    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_NEAR(
            fp32_from_bits(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << 16),
            c_ref[i * n() + j],
            std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 3.0e-2f))
          << "at " << i << ", " << j << ": Mr x Nr x Kr = " << mr() << " x " << nr()
          << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
      }
    }
  }
}

void GemmMicrokernelTester::Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_minmax_params_fn init_params) const
{
  ASSERT_LE(m(), mr());
  ASSERT_GE(a_stride(), k());
  ASSERT_GE(cm_stride(), n());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
  auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

  std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
  std::vector<uint16_t> b(n() * k());
  std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + packed_n());
  std::vector<uint16_t> bias(n());
  std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<float> c_ref(m() * n());

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    std::generate(a.begin(), a.end(), std::ref(f16rng));
    std::generate(b.begin(), b.end(), std::ref(f16rng));
    std::generate(bias.begin(), bias.end(), std::ref(f16rng));
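    // Pre-fill the output with NaN so that any element the micro-kernel fails to
    // write trips the checks below.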
    std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
    std::fill(c_ref.begin(), c_ref.end(), 0.0f);

    std::fill(packed_w.begin(), packed_w.end(), 0);
    xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);

    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t k_index = 0; k_index < k(); k_index++) {
          ASSERT_LE(n(), packed_n());
          ASSERT_LT(m_index * n() + n_index, c_ref.size());
          ASSERT_LT(m_index * k() + k_index, a.size());
          c_ref[m_index * n() + n_index] +=
            fp16_ieee_to_fp32_value(a[m_index * a_stride() + k_index]) *
            fp16_ieee_to_fp32_value(b[n_index * k() + k_index]);
        }
        c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
      }
    }

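    // Derive clamping bounds from the accumulated range and the requested
    // qmin()/qmax(), rounded to representable fp16 values.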
    const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
    const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
    const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())));
    const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())));

    // Prepare parameters.
    xnn_f16_minmax_params params;
    init_params(&params,
      fp16_ieee_from_fp32_value(c_min),
      fp16_ieee_from_fp32_value(c_max));

    for (float& c_value : c_ref) {
      c_value = std::max(std::min(c_value, c_max), c_min);
    }

    gemm_minmax(m(), n(), k() * sizeof(uint16_t),
      a.data(), a_stride() * sizeof(uint16_t),
      packed_w.data(),
      c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
      &params);

    // Validate micro-kernel outputs.
    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_NEAR(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f))
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
      }
    }
  }
}

void GemmMicrokernelTester::Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f16_minmax_params_fn init_params) const {
  ASSERT_LE(m(), mr());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
  auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

  std::vector<uint16_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
  std::vector<uint16_t> b(n() * ks() * k());
  std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
  std::vector<uint16_t> bias(n());
  std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<float> c_ref(m() * n());
  std::vector<uint16_t> junk(k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
  std::vector<const uint16_t*> im2col(mr() * ks());
  std::fill(junk.begin(), junk.end(), UINT16_C(0x7E00) /* NaN */);

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    std::generate(a.begin(), a.end(), std::ref(f16rng));
    std::generate(b.begin(), b.end(), std::ref(f16rng));
    std::generate(bias.begin(), bias.end(), std::ref(f16rng));
    std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
    std::fill(c_ref.begin(), c_ref.end(), 0);

    std::fill(packed_w.begin(), packed_w.end(), 0);
    xnn_pack_f16_conv_goki_w(
      1, n(), ks(), k(), nr(), kr(), sr(),
      b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);

    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = 0; m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
      }
    }
    std::shuffle(im2col.begin(), im2col.end(), rng);
    if (zero_index() != SIZE_MAX) {
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        im2col[ks_index * mr() + zero_index()] = a.data();
      }
    }
    for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
      for (size_t m_index = m(); m_index < mr(); m_index++) {
        im2col[ks_index * mr() + m_index] = junk.data();
      }
    }

    std::fill(c_ref.begin(), c_ref.end(), 0.0);
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LT(ks_index * mr() + m_index, im2col.size());
            ASSERT_LT(k_index, k());
            ASSERT_LT(k_index, a_stride());
            if (im2col[ks_index * mr() + m_index] == a.data()) {
              c_ref[m_index * n() + n_index] +=
                fp16_ieee_to_fp32_value(im2col[ks_index * mr() + m_index][k_index]) *
                fp16_ieee_to_fp32_value(b[(n_index * ks() + ks_index) * k() + k_index]);
            } else {
              c_ref[m_index * n() + n_index] +=
                fp16_ieee_to_fp32_value(im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
                fp16_ieee_to_fp32_value(b[(n_index * ks() + ks_index) * k() + k_index]);
            }
          }
        }
        c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
      }
    }

    const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
    const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
    const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * uint16_t(qmin())));
    const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * uint16_t(255 - qmax())));
    for (size_t m_index = 0; m_index < m(); m_index++) {
      for (size_t n_index = 0; n_index < n(); n_index++) {
        c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
        c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
      }
    }

    // Prepare parameters.
    xnn_f16_minmax_params params;
    init_params(&params,
      fp16_ieee_from_fp32_value(c_min),
      fp16_ieee_from_fp32_value(c_max));

    for (float& c_value : c_ref) {
      c_value = std::max(std::min(c_value, c_max), c_min);
    }

    const uint16_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

    igemm_minmax(
      m(), n(), k() * sizeof(uint16_t), ks() * mr() * sizeof(void*),
      reinterpret_cast<const void**>(im2col.data()), packed_w.data(),
      c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
      a_offset() * sizeof(uint16_t), zero_pointer,
      &params);

    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_LE(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_max)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
        ASSERT_GE(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_min)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
        ASSERT_NEAR(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f))
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
      }
    }
  }
}

void GemmMicrokernelTester::Test(xnn_f32_ppmm_minmax_ukernel_function ppmm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
  ASSERT_LE(m(), mr());
  ASSERT_GE(cm_stride(), n());

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  std::uniform_real_distribution<float> f32dist;

  std::vector<float> a(packed_k() * mr());
  std::vector<float> b(n() * k());
  std::vector<float> bias(n());
  std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
  std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
  std::vector<float> c_ref(m() * n());

  for (size_t iteration = 0; iteration < iterations(); iteration++) {
    std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
    std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
    std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
    std::fill(c.begin(), c.end(), nanf(""));
    std::fill(c_ref.begin(), c_ref.end(), 0.0f);

    std::fill(packed_w.begin(), packed_w.end(), 0.0f);
    xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);

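    // The PPMM input A is a pre-packed mr()-row panel: rows m() through mr()-1
    // replicate the last valid row so the kernel can safely process a full panel.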
    for (size_t i = m(); i < mr(); i++) {
      for (size_t l = 0; l < k(); l++) {
        a[l * mr() + i] = a[l * mr() + m() - 1];
      }
    }

    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        for (size_t l = 0; l < k(); l++) {
          c_ref[i * n() + j] +=
            a[l * mr() + i] *
            b[j * k() + l];
        }
        c_ref[i * n() + j] += bias[j];
      }
    }

    const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
    const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
    const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
    const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

    // Prepare parameters.
    xnn_f32_minmax_params params;
    init_params(&params, c_min, c_max);

    for (float& c_value : c_ref) {
      c_value = std::max(std::min(c_value, c_max), c_min);
    }

    ppmm_minmax(m(), n(), k() * sizeof(float),
      a.data(), packed_w.data(),
      c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
      &params);

    // Validate micro-kernel outputs.
    for (size_t i = 0; i < m(); i++) {
      for (size_t j = 0; j < n(); j++) {
        ASSERT_NEAR(
            c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
            c_ref[i * n() + j],
            std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
      }
    }
  }
}

1064 void GemmMicrokernelTester::Test(xnn_f32_gemm_ukernel_function gemm) const {
1065   ASSERT_LE(m(), mr());
1066   ASSERT_GE(a_stride(), k());
1067   ASSERT_GE(cm_stride(), n());
1068 
1069   std::random_device random_device;
1070   auto rng = std::mt19937(random_device());
1071   std::uniform_real_distribution<float> f32dist;
1072 
1073   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1074   std::vector<float> b(n() * k());
1075   std::vector<float> bias(n());
1076   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1077   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1078   std::vector<float> c_ref(m() * n());
1079 
1080   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1081     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1082     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1083     std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1084     std::fill(c.begin(), c.end(), nanf(""));
1085     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1086 
1087     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1088     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1089 
1090     for (size_t m_index = 0; m_index < m(); m_index++) {
1091       for (size_t n_index = 0; n_index < n(); n_index++) {
1092         for (size_t k_index = 0; k_index < k(); k_index++) {
1093           ASSERT_LE(n(), packed_n());
1094           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1095           c_ref[m_index * n() + n_index] +=
1096             a[m_index * a_stride() + k_index] *
1097             b[n_index * k() + k_index];
1098         }
1099         c_ref[m_index * n() + n_index] += bias[n_index];
1100       }
1101     }
1102 
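    // Invoke the microkernel under test: KC and the A/C strides are passed in bytes, and params is null because this variant applies no activation clamping.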
1103     gemm(m(), n(), k() * sizeof(float),
1104       a.data(), a_stride() * sizeof(float),
1105       packed_w.data(),
1106       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1107       nullptr);
1108 
1109     // Validate micro-kernel outputs.
1110     for (size_t i = 0; i < m(); i++) {
1111       for (size_t j = 0; j < n(); j++) {
1112         ASSERT_NEAR(
1113             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1114             c_ref[i * n() + j],
1115             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1116             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1117             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1118             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1119       }
1120     }
1121   }
1122 }
1123 
1124 void GemmMicrokernelTester::Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const {
1125   ASSERT_LE(m(), mr());
1126   ASSERT_GE(a_stride(), k());
1127   ASSERT_GE(cm_stride(), n());
1128 
1129   std::random_device random_device;
1130   auto rng = std::mt19937(random_device());
1131   std::uniform_real_distribution<float> f32dist;
1132 
1133   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1134   std::vector<float> b(n() * k());
1135   std::vector<float> bias(n());
1136   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1137   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1138   std::vector<float> c_ref(m() * n());
1139 
1140   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1141     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1142     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1143     std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1144     std::fill(c.begin(), c.end(), nanf(""));
1145     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1146 
1147     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1148     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1149 
1150     for (size_t m_index = 0; m_index < m(); m_index++) {
1151       for (size_t n_index = 0; n_index < n(); n_index++) {
1152         for (size_t k_index = 0; k_index < k(); k_index++) {
1153           ASSERT_LE(n(), packed_n());
1154           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1155           c_ref[m_index * n() + n_index] +=
1156             a[m_index * a_stride() + k_index] *
1157             b[n_index * k() + k_index];
1158         }
1159         c_ref[m_index * n() + n_index] = std::max(0.0f, c_ref[m_index * n() + n_index] + bias[n_index]);
1160       }
1161     }
1162 
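    // The ReLU variant clamps outputs at zero internally, so no params structure is passed.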
1163     gemm_relu(m(), n(), k() * sizeof(float),
1164       a.data(), a_stride() * sizeof(float),
1165       packed_w.data(),
1166       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1167       nullptr);
1168 
1169     // Validate micro-kernel outputs.
1170     for (size_t i = 0; i < m(); i++) {
1171       for (size_t j = 0; j < n(); j++) {
1172         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f)
1173             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1174             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1175             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1176         ASSERT_NEAR(
1177             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1178             c_ref[i * n() + j],
1179             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1180             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1181             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1182             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1183       }
1184     }
1185   }
1186 }
1187 
1188 void GemmMicrokernelTester::Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
1189   ASSERT_LE(m(), mr());
1190   ASSERT_GE(a_stride(), k());
1191   ASSERT_GE(cm_stride(), n());
1192 
1193   std::random_device random_device;
1194   auto rng = std::mt19937(random_device());
1195   std::uniform_real_distribution<float> f32dist;
1196 
1197   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1198   std::vector<float> b(n() * k());
1199   std::vector<float> bias(n());
1200   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1201   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1202   std::vector<float> c_ref(m() * n());
1203 
1204   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1205     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1206     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1207     std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1208     std::fill(c.begin(), c.end(), nanf(""));
1209     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1210 
1211     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1212     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1213 
1214     for (size_t m_index = 0; m_index < m(); m_index++) {
1215       for (size_t n_index = 0; n_index < n(); n_index++) {
1216         for (size_t k_index = 0; k_index < k(); k_index++) {
1217           ASSERT_LE(n(), packed_n());
1218           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1219           c_ref[m_index * n() + n_index] +=
1220             a[m_index * a_stride() + k_index] *
1221             b[n_index * k() + k_index];
1222         }
1223         c_ref[m_index * n() + n_index] += bias[n_index];
1224       }
1225     }
1226 
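    // Derive clamping bounds from the accumulated range: qmin()/qmax() select the clipped fraction, and the extreme values 0 and 255 disable the corresponding bound (infinity).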
1227     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1228     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1229     const float c_min =
1230         qmin() == std::numeric_limits<uint8_t>::min() ? -std::numeric_limits<float>::infinity()
1231                     : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1232     const float c_max =
1233         qmax() == std::numeric_limits<uint8_t>::max() ? +std::numeric_limits<float>::infinity()
1234                       : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1235 
1236     // Prepare parameters.
1237     xnn_f32_minmax_params params;
1238     init_params(&params, c_min, c_max);
1239 
1240     for (size_t m_index = 0; m_index < m(); m_index++) {
1241       for (size_t n_index = 0; n_index < n(); n_index++) {
1242         c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1243       }
1244     }
1245 
1246     gemm_minmax(m(), n(), k() * sizeof(float),
1247       a.data(), a_stride() * sizeof(float),
1248       packed_w.data(),
1249       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1250       &params);
1251 
1252     // Validate micro-kernel outputs.
1253     for (size_t i = 0; i < m(); i++) {
1254       for (size_t j = 0; j < n(); j++) {
1255         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1256             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1257             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1258             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1259         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1260             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1261             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1262             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1263         ASSERT_NEAR(
1264             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1265             c_ref[i * n() + j],
1266             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1267             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1268             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1269             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1270       }
1271     }
1272   }
1273 }
1274 
1275 void GemmMicrokernelTester::Test(xnn_f32_gemminc_minmax_ukernel_function gemminc, xnn_init_f32_minmax_params_fn init_params) const {
1276   ASSERT_LE(m(), mr());
1277   ASSERT_GE(a_stride(), k());
1278   ASSERT_GE(cm_stride(), n());
1279 
1280   std::random_device random_device;
1281   auto rng = std::mt19937(random_device());
1282   std::uniform_real_distribution<float> f32dist;
1283 
1284   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1285   std::vector<float> b(n() * k());
1286   std::vector<float> bias(n());
1287   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k());  // no packed_n() bias block: GEMMINC reads external accumulators instead of bias
1288   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1289   std::vector<float> c_ref(m() * n());
1290   std::vector<float, AlignedAllocator<float, 64>> acc(mr() * packed_n());
1291 
1292   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1293     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1294     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1295     std::fill(c.begin(), c.end(), nanf(""));
1296     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1297     std::generate(acc.begin(), acc.end(), [&]() { return f32dist(rng); });
1298 
1299     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1300     xnn_pack_f32_gemminc_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), packed_w.data(), nullptr);
1301 
1302     for (size_t m_index = 0; m_index < m(); m_index++) {
1303       for (size_t n_index = 0; n_index < n(); n_index++) {
1304         for (size_t k_index = 0; k_index < k(); k_index++) {
1305           ASSERT_LE(n(), packed_n());
1306           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1307           c_ref[m_index * n() + n_index] +=
1308             a[m_index * a_stride() + k_index] *
1309             b[n_index * k() + k_index];
1310         }
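        // acc supplies the initial accumulators in MR x NR tiles along N, the same layout the GEMMINC microkernel reads them in.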
1311         c_ref[m_index * n() + n_index] += acc[n_index / nr() * nr() * mr() + m_index % mr() * nr() + n_index % nr()];
1312       }
1313     }
1314 
1315     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1316     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1317     const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1318     const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1319 
1320     // Prepare parameters.
1321     xnn_f32_minmax_params params;
1322     init_params(&params, c_min, c_max);
1323 
1324     for (size_t m_index = 0; m_index < m(); m_index++) {
1325       for (size_t n_index = 0; n_index < n(); n_index++) {
1326         c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1327       }
1328     }
1329 
1330     gemminc(m(), n(), k() * sizeof(float),
1331       a.data(), a_stride() * sizeof(float),
1332       packed_w.data(),
1333       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1334       acc.data(),
1335       &params);
1336 
1337     // Validate micro-kernel outputs.
1338     for (size_t i = 0; i < m(); i++) {
1339       for (size_t j = 0; j < n(); j++) {
1340         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1341             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1342             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1343             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1344         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1345             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1346             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1347             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1348         ASSERT_NEAR(
1349             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1350             c_ref[i * n() + j],
1351             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1352             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1353             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1354             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1355       }
1356     }
1357   }
1358 }
1359 
1360 void GemmMicrokernelTester::Test(xnn_f32_igemm_ukernel_function igemm) const {
1361   ASSERT_LE(m(), mr());
1362 
1363   std::random_device random_device;
1364   auto rng = std::mt19937(random_device());
1365   std::uniform_real_distribution<float> f32dist;
1366 
1367   std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1368   std::vector<float> b(n() * ks() * k());
1369   std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1370   std::vector<float> bias(n());
1371   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1372   std::vector<float> c_ref(m() * n());
1373   std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1374   std::vector<const float*> im2col(mr() * ks());
1375   std::fill(junk.begin(), junk.end(), nanf(""));
1376 
1377   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1378     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1379     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1380     std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1381     std::fill(c.begin(), c.end(), nanf(""));
1382     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1383 
1384     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1385     xnn_pack_f32_conv_goki_w(
1386       1, n(), ks(), k(), nr(), kr(), sr(),
1387       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1388 
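    // Build the indirection (im2col) buffer: each of the ks() x mr() entries points at a row of A, pre-biased by -a_offset() because the microkernel adds a_offset() back to every pointer except the zero pointer.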
1389     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1390       for (size_t m_index = 0; m_index < mr(); m_index++) {
1391         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1392       }
1393     }
1394     std::shuffle(im2col.begin(), im2col.end(), rng);
1395     if (zero_index() != SIZE_MAX) {
1396       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1397         im2col[ks_index * mr() + zero_index()] = a.data();
1398       }
1399     }
1400     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1401       for (size_t m_index = m(); m_index < mr(); m_index++) {
1402         im2col[ks_index * mr() + m_index] = junk.data();
1403       }
1404     }
1405 
1406     std::fill(c_ref.begin(), c_ref.end(), 0.0);
1407     for (size_t m_index = 0; m_index < m(); m_index++) {
1408       for (size_t n_index = 0; n_index < n(); n_index++) {
1409         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1410           for (size_t k_index = 0; k_index < k(); k_index++) {
1411             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1412             ASSERT_LT(k_index, k());
1413             ASSERT_LT(k_index, a_stride());
1414             if (im2col[ks_index * mr() + m_index] == a.data()) {
1415               c_ref[m_index * n() + n_index] +=
1416                 (im2col[ks_index * mr() + m_index][k_index]) *
1417                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1418             } else {
1419               c_ref[m_index * n() + n_index] +=
1420                 (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1421                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1422             }
1423           }
1424         }
1425         c_ref[m_index * n() + n_index] += bias[n_index];
1426       }
1427     }
1428 
1429     const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;
1430 
1431     igemm(
1432       m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1433       im2col.data(), packed_w.data(),
1434       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1435       a_offset() * sizeof(float), zero_pointer,
1436       nullptr);
1437 
1438     for (size_t i = 0; i < m(); i++) {
1439       for (size_t j = 0; j < n(); j++) {
1440         ASSERT_NEAR(
1441             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1442             c_ref[i * n() + j],
1443             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1444             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1445             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1446             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1447       }
1448     }
1449   }
1450 }
1451 
1452 void GemmMicrokernelTester::Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const {
1453   ASSERT_LE(m(), mr());
1454 
1455   std::random_device random_device;
1456   auto rng = std::mt19937(random_device());
1457   std::uniform_real_distribution<float> f32dist;
1458 
1459   std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1460   std::vector<float> b(n() * ks() * k());
1461   std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1462   std::vector<float> bias(n());
1463   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1464   std::vector<float> c_ref(m() * n());
1465   std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1466   std::vector<const float*> im2col(mr() * ks());
1467   std::fill(junk.begin(), junk.end(), nanf(""));
1468 
1469   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1470     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1471     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1472     std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1473     std::fill(c.begin(), c.end(), nanf(""));
1474     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1475 
1476     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1477     xnn_pack_f32_conv_goki_w(
1478       1, n(), ks(), k(), nr(), kr(), sr(),
1479       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1480 
1481     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1482       for (size_t m_index = 0; m_index < mr(); m_index++) {
1483         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1484       }
1485     }
1486     std::shuffle(im2col.begin(), im2col.end(), rng);
1487     if (zero_index() != SIZE_MAX) {
1488       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1489         im2col[ks_index * mr() + zero_index()] = a.data();
1490       }
1491     }
1492     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1493       for (size_t m_index = m(); m_index < mr(); m_index++) {
1494         im2col[ks_index * mr() + m_index] = junk.data();
1495       }
1496     }
1497 
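    // Reference computation: entries that alias the zero pointer (a.data()) are read without the a_offset() adjustment, mirroring the microkernel's zero-pointer handling.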
1498     std::fill(c_ref.begin(), c_ref.end(), 0.0);
1499     for (size_t m_index = 0; m_index < m(); m_index++) {
1500       for (size_t n_index = 0; n_index < n(); n_index++) {
1501         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1502           for (size_t k_index = 0; k_index < k(); k_index++) {
1503             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1504             ASSERT_LT(k_index, k());
1505             ASSERT_LT(k_index, a_stride());
1506             if (im2col[ks_index * mr() + m_index] == a.data()) {
1507               c_ref[m_index * n() + n_index] +=
1508                 (im2col[ks_index * mr() + m_index][k_index]) *
1509                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1510             } else {
1511               c_ref[m_index * n() + n_index] +=
1512                 (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1513                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1514             }
1515           }
1516         }
1517         c_ref[m_index * n() + n_index] = std::max(0.0f, bias[n_index] + c_ref[m_index * n() + n_index]);
1518       }
1519     }
1520 
1521     const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;
1522 
1523     igemm_relu(
1524       m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1525       im2col.data(), packed_w.data(),
1526       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1527       a_offset() * sizeof(float), zero_pointer,
1528       nullptr);
1529 
1530     for (size_t i = 0; i < m(); i++) {
1531       for (size_t j = 0; j < n(); j++) {
1532         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f)
1533             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1534             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1535             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1536         ASSERT_NEAR(
1537             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1538             c_ref[i * n() + j],
1539             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1540             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1541             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1542             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1543       }
1544     }
1545   }
1546 }
1547 
1548 void GemmMicrokernelTester::Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
1549   ASSERT_LE(m(), mr());
1550 
1551   std::random_device random_device;
1552   auto rng = std::mt19937(random_device());
1553   std::uniform_real_distribution<float> f32dist;
1554 
1555   std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1556   std::vector<float> b(n() * ks() * k());
1557   std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1558   std::vector<float> bias(n());
1559   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1560   std::vector<float> c_ref(m() * n());
1561   std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1562   std::vector<const float*> im2col(mr() * ks());
1563   std::fill(junk.begin(), junk.end(), nanf(""));
1564 
1565   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1566     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1567     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1568     std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1569     std::fill(c.begin(), c.end(), nanf(""));
1570     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1571 
1572     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1573     xnn_pack_f32_conv_goki_w(
1574       1, n(), ks(), k(), nr(), kr(), sr(),
1575       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1576 
1577     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1578       for (size_t m_index = 0; m_index < mr(); m_index++) {
1579         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1580       }
1581     }
1582     std::shuffle(im2col.begin(), im2col.end(), rng);
1583     if (zero_index() != SIZE_MAX) {
1584       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1585         im2col[ks_index * mr() + zero_index()] = a.data();
1586       }
1587     }
1588     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1589       for (size_t m_index = m(); m_index < mr(); m_index++) {
1590         im2col[ks_index * mr() + m_index] = junk.data();
1591       }
1592     }
1593 
1594     std::fill(c_ref.begin(), c_ref.end(), 0.0);
1595     for (size_t m_index = 0; m_index < m(); m_index++) {
1596       for (size_t n_index = 0; n_index < n(); n_index++) {
1597         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1598           for (size_t k_index = 0; k_index < k(); k_index++) {
1599             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1600             ASSERT_LT(k_index, k());
1601             ASSERT_LT(k_index, a_stride());
1602             if (im2col[ks_index * mr() + m_index] == a.data()) {
1603               c_ref[m_index * n() + n_index] +=
1604                 (im2col[ks_index * mr() + m_index][k_index]) *
1605                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1606             } else {
1607               c_ref[m_index * n() + n_index] +=
1608                 (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1609                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1610             }
1611           }
1612         }
1613         c_ref[m_index * n() + n_index] += bias[n_index];
1614       }
1615     }
1616 
1617     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1618     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1619     const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1620     const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
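    // Clamp the reference results to the same [c_min, c_max] window the microkernel will apply.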
1621     for (size_t m_index = 0; m_index < m(); m_index++) {
1622       for (size_t n_index = 0; n_index < n(); n_index++) {
1623         c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
1624         c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
1625       }
1626     }
1627 
1628     // Prepare parameters.
1629     xnn_f32_minmax_params params;
1630     init_params(&params, c_min, c_max);
1631 
1632     const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;
1633 
1634     igemm_minmax(
1635       m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1636       im2col.data(), packed_w.data(),
1637       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1638       a_offset() * sizeof(float), zero_pointer,
1639       &params);
1640 
1641     for (size_t i = 0; i < m(); i++) {
1642       for (size_t j = 0; j < n(); j++) {
1643         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1644             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1645             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1646             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1647         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1648             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1649             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1650             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1651         ASSERT_NEAR(
1652             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1653             c_ref[i * n() + j],
1654             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1655             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1656             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1657             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1658       }
1659     }
1660   }
1661 }
1662 
1663 #if XNN_PLATFORM_JIT
1664 void GemmMicrokernelTester::Test(
1665     xnn_jit_gemm_code_generator_function gemm_generator,
1666     xnn_init_f32_minmax_params_fn init_params) const
1667 {
1668   ASSERT_LE(m(), mr());
1669   ASSERT_GE(a_stride(), k());
1670   ASSERT_GE(cm_stride(), n());
1671 
1672   std::random_device random_device;
1673   auto rng = std::mt19937(random_device());
1674   std::uniform_real_distribution<float> f32dist;
1675 
1676   std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1677   std::vector<float> b(n() * k());
1678   std::vector<float> bias(n());
1679   std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1680   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1681   std::vector<float> c_ref(m() * n());
1682 
1683   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1684     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1685     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1686     std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1687     std::fill(c.begin(), c.end(), nanf(""));
1688     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1689 
1690     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1691     xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1692 
1693     for (size_t m_index = 0; m_index < m(); m_index++) {
1694       for (size_t n_index = 0; n_index < n(); n_index++) {
1695         for (size_t k_index = 0; k_index < k(); k_index++) {
1696           ASSERT_LE(n(), packed_n());
1697           ASSERT_LT(m_index * n() + n_index, c_ref.size());
1698           c_ref[m_index * n() + n_index] +=
1699             a[m_index * a_stride() + k_index] *
1700             b[n_index * k() + k_index];
1701         }
1702         c_ref[m_index * n() + n_index] += bias[n_index];
1703       }
1704     }
1705 
1706     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1707     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1708     const float c_min =
1709         qmin() == std::numeric_limits<uint8_t>::min() ? -std::numeric_limits<float>::infinity()
1710                     : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1711     const float c_max =
1712         qmax() == std::numeric_limits<uint8_t>::max() ? +std::numeric_limits<float>::infinity()
1713                       : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1714 
1715     // Prepare parameters.
1716     xnn_f32_minmax_params params;
1717     init_params(&params, c_min, c_max);
1718 
1719     for (size_t m_index = 0; m_index < m(); m_index++) {
1720       for (size_t n_index = 0; n_index < n(); n_index++) {
1721         c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1722       }
1723     }
1724 
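    // JIT path: emit the microkernel into an executable code buffer, finalize it, and call it through the regular GEMM minmax function-pointer signature; n() % nr() tells the generator the remainder-column count.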
1725     ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
1726     struct xnn_code_buffer code_buffer;
1727     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1728     jit_gemm_params p = (jit_gemm_params) {
1729       .f32_minmax = {
1730         .min = c_min,
1731         .max = c_max
1732       }
1733     };
1734     ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, mr(), n() % nr(), k() * sizeof(float), &p));
1735     ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
1736     xnn_f32_gemm_minmax_ukernel_function gemm_minmax =
1737         reinterpret_cast<xnn_f32_gemm_minmax_ukernel_function>(code_buffer.start);
1738 
1739     gemm_minmax(m(), n(), k() * sizeof(float),
1740       a.data(), a_stride() * sizeof(float),
1741       packed_w.data(),
1742       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1743       &params);
1744 
1745     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
1746 
1747     // Validate micro-kernel outputs.
1748     for (size_t i = 0; i < m(); i++) {
1749       for (size_t j = 0; j < n(); j++) {
1750         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1751             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1752             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1753             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1754         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1755             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1756             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1757             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1758         ASSERT_NEAR(
1759             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1760             c_ref[i * n() + j],
1761             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1762             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1763             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1764             << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1765       }
1766     }
1767   }
1768 }
1769 
1770 void GemmMicrokernelTester::Test(
1771     xnn_jit_igemm_code_generator_function igemm_generator,
1772     xnn_init_f32_minmax_params_fn init_params) const
1773 {
1774   ASSERT_LE(m(), mr());
1775 
1776   std::random_device random_device;
1777   auto rng = std::mt19937(random_device());
1778   std::uniform_real_distribution<float> f32dist;
1779 
1780   std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1781   std::vector<float> b(n() * ks() * k());
1782   std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1783   std::vector<float> bias(n());
1784   std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1785   std::vector<float> c_ref(m() * n());
1786   std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1787   std::vector<const float*> im2col(mr() * ks());
1788   std::fill(junk.begin(), junk.end(), nanf(""));
1789 
1790   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1791     std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1792     std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1793     std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1794     std::fill(c.begin(), c.end(), nanf(""));
1795     std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1796 
1797     std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1798     xnn_pack_f32_conv_goki_w(
1799       1, n(), ks(), k(), nr(), kr(), sr(),
1800       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1801 
1802     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1803       for (size_t m_index = 0; m_index < mr(); m_index++) {
1804         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1805       }
1806     }
1807     std::shuffle(im2col.begin(), im2col.end(), rng);
1808     if (zero_index() != SIZE_MAX) {
1809       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1810         im2col[ks_index * mr() + zero_index()] = a.data();
1811       }
1812     }
1813     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1814       for (size_t m_index = m(); m_index < mr(); m_index++) {
1815         im2col[ks_index * mr() + m_index] = junk.data();
1816       }
1817     }
1818 
1819     std::fill(c_ref.begin(), c_ref.end(), 0.0);
1820     for (size_t m_index = 0; m_index < m(); m_index++) {
1821       for (size_t n_index = 0; n_index < n(); n_index++) {
1822         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1823           for (size_t k_index = 0; k_index < k(); k_index++) {
1824             ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1825             ASSERT_LT(k_index, k());
1826             ASSERT_LT(k_index, a_stride());
1827             if (im2col[ks_index * mr() + m_index] == a.data()) {
1828               c_ref[m_index * n() + n_index] +=
1829                 (im2col[ks_index * mr() + m_index][k_index]) *
1830                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1831             } else {
1832               c_ref[m_index * n() + n_index] +=
1833                 (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1834                 (b[(n_index * ks() + ks_index) * k() + k_index]);
1835             }
1836           }
1837         }
1838         c_ref[m_index * n() + n_index] += bias[n_index];
1839       }
1840     }
1841 
1842     const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1843     const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1844     const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1845     const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1846     for (size_t m_index = 0; m_index < m(); m_index++) {
1847       for (size_t n_index = 0; n_index < n(); n_index++) {
1848         c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
1849         c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
1850       }
1851     }
1852 
1853     // Prepare parameters.
1854     xnn_f32_minmax_params params;
1855     init_params(&params, c_min, c_max);
1856 
1857     const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;
1858 
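    // JIT path: generate the IGEMM microkernel at runtime with the min/max bounds baked into jit_gemm_params.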
1859     ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
1860     struct xnn_code_buffer code_buffer;
1861     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1862     jit_gemm_params p = (jit_gemm_params) {
1863       .f32_minmax = {
1864         .min = c_min,
1865         .max = c_max
1866       }
1867     };
1868     ASSERT_EQ(xnn_status_success,
1869               igemm_generator(&code_buffer, mr(), n() % nr(), k() * sizeof(float), ks() * mr() * sizeof(void *), &p));
1870     ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
1871     xnn_f32_igemm_minmax_ukernel_function igemm_minmax =
1872         reinterpret_cast<xnn_f32_igemm_minmax_ukernel_function>(code_buffer.start);
1873 
1874     igemm_minmax(
1875       m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1876       im2col.data(), packed_w.data(),
1877       c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1878       a_offset() * sizeof(float), zero_pointer,
1879       &params);
1880 
1881     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
1882 
1883     for (size_t i = 0; i < m(); i++) {
1884       for (size_t j = 0; j < n(); j++) {
1885         ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1886             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1887             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1888             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1889         ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1890             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1891             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1892             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1893         ASSERT_NEAR(
1894             c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1895             c_ref[i * n() + j],
1896             std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1897             << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1898             << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1899             << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1900       }
1901     }
1902   }
1903 }
1904 
1905 void GemmMicrokernelTester::Test(
1906   xnn_jit_gemm_code_generator_function gemm_generator,
1907   xnn_init_qc8_conv_minmax_params_fn init_params,
1908   xnn_qs8_requantize_fn requantize) const
1909 {
1910   ASSERT_LE(m(), mr());
1911 
1912   std::random_device random_device;
1913   auto rng = std::mt19937(random_device());
1914   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
1915   auto i8rng = std::bind(
1916     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
1917     std::ref(rng));
1918   auto w8rng = std::bind(
1919     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
1920     std::ref(rng));
1921 
1922   std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
1923   std::vector<int8_t> b(n() * k());
1924   std::vector<int32_t> bias(n());
1925   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
1926   std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t));
1927   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1928   std::vector<int32_t> acc(m() * n());
1929   std::vector<float> scale(n());
1930   std::vector<int8_t> c_ref(m() * n());
1931 
1932   for (size_t iteration = 0; iteration < iterations(); iteration++) {
1933     do {
1934       std::generate(a.begin(), a.end(), std::ref(i8rng));
1935     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
1936     do {
1937       std::generate(b.begin(), b.end(), std::ref(w8rng));
1938     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
1939     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
1940     std::fill(c.begin(), c.end(), 0xA5);
1941 
1942     std::fill(packed_w.begin(), packed_w.end(), 0);
1943     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
1944     if (extended_weights()) {
1945       xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
1946         b.data(), bias.data(), packed_xw.data(), nr() * sizeof(float), &packing_params);
1947     } else {
1948       xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
1949         b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
1950     }
1951 
1952     // Compute 32-bit results and output quantization arguments.
1953     std::fill(acc.begin(), acc.end(), 0);
1954     for (size_t m_index = 0; m_index < m(); m_index++) {
1955       for (size_t n_index = 0; n_index < n(); n_index++) {
1956         for (size_t k_index = 0; k_index < k(); k_index++) {
1957           acc[m_index * n() + n_index] +=
1958               (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
1959               int32_t(b[n_index * k() + k_index]);
1960         }
1961         acc[m_index * n() + n_index] += bias[n_index];
1962       }
1963     }
1964 
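    // Pick a per-channel requantization scale so that each output column's accumulated range maps into the int8 output range.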
1965     const int8_t c_zero_point = -1;
1966     for (size_t n_index = 0; n_index < n(); n_index++) {
1967       int32_t accumulated_min = acc[n_index];
1968       int32_t accumulated_max = acc[n_index];
1969       for (size_t m_index = 0; m_index < m(); m_index++) {
1970         accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
1971         accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
1972       }
1973       const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
1974       const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
1975       scale[n_index] = 1.0f / c_scale;
1976     }
1977 
1978     if (extended_weights()) {
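    // Write the per-channel scales into the packed weights, immediately after each NR-wide block's int32 bias words.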
1979       xnn_init_qc8_scale_fp32_params(
1980         n(), nr(),
1981         nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
1982         (void*) ((uintptr_t) packed_xw.data() + nr() * (packed_k() * sizeof(int16_t) + sizeof(int32_t))));
1983     } else {
1984       xnn_init_qc8_scale_fp32_params(
1985         n(), nr(),
1986         nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
1987         (void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t))));
1988     }
1989 
1990     union xnn_qc8_conv_minmax_params minmax_params;
1991     init_params(&minmax_params,
1992       c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
1993 
1994     ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
1995     struct xnn_code_buffer code_buffer;
1996     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1997     ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, mr(), n() % nr(), k(), nullptr));
1998     ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
1999     xnn_qc8_gemm_minmax_ukernel_function gemm = reinterpret_cast<xnn_qc8_gemm_minmax_ukernel_function>(code_buffer.start);
2000 
2001     gemm(
2002       m(), n(), k(),
2003       a.data(), a_stride() * sizeof(int8_t),
2004       extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
2005       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2006       &minmax_params);
2007 
2008     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2009 
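    // Requantize the 32-bit reference accumulators with the same per-channel scales to obtain the expected int8 outputs.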
2010     for (size_t m_index = 0; m_index < m(); m_index++) {
2011       for (size_t n_index = 0; n_index < n(); n_index++) {
2012         c_ref[m_index * n() + n_index] = requantize(
2013           acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2014       }
2015     }
2016 
2017     for (size_t i = 0; i < m(); i++) {
2018       for (size_t j = 0; j < n(); j++) {
2019         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2020         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2021         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2022             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
2023             << " (accumulator = " << acc[i * n() + j]
2024             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
2025             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2026             << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
2027       }
2028     }
2029   }
2030 }
2031 
2032 void GemmMicrokernelTester::Test(
2033   xnn_jit_igemm_code_generator_function igemm_generator,
2034   xnn_init_qc8_conv_minmax_params_fn init_params,
2035   xnn_qs8_requantize_fn requantize) const
2036 {
2037   ASSERT_LE(m(), mr());
2038 
2039   std::random_device random_device;
2040   auto rng = std::mt19937(random_device());
2041   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
2042   auto i8rng = std::bind(
2043     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
2044     std::ref(rng));
2045   auto w8rng = std::bind(
2046     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
2047     std::ref(rng));
2048 
2049   std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
2050   std::vector<int8_t> b(n() * ks() * k());
2051   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
2052   std::vector<int32_t> bias(n());
2053   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
2054   std::vector<int32_t> acc(m() * n());
2055   std::vector<float> scale(n());
2056   std::vector<int8_t> c_ref(m() * n());
2057   std::vector<int8_t> junk(k() + 8);
2058   std::vector<const int8_t*> im2col(mr() * ks());
2059 
2060   std::fill(junk.begin(), junk.end(), 0xA5);
2061 
2062   for (size_t iteration = 0; iteration < iterations(); iteration++) {
2063     do {
2064       std::generate(a.begin(), a.end(), std::ref(i8rng));
2065     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
2066     do {
2067       std::generate(b.begin(), b.end(), std::ref(w8rng));
2068     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
2069     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
2070     std::fill(c.begin(), c.end(), 0xA5);
2071 
2072     std::fill(packed_w.begin(), packed_w.end(), 0);
2073     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
2074     xnn_pack_qs8_conv_goki_w(
2075       1, n(), ks(), k(), nr(), kr(), sr(),
2076       b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
2077 
2078     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2079       for (size_t m_index = 0; m_index < mr(); m_index++) {
2080         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
2081       }
2082     }
2083     std::shuffle(im2col.begin(), im2col.end(), rng);
2084     if (zero_index() != SIZE_MAX) {
2085       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2086         im2col[ks_index * mr() + zero_index()] = a.data();
2087       }
2088     }
2089     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2090       for (size_t m_index = m(); m_index < mr(); m_index++) {
2091         im2col[ks_index * mr() + m_index] = junk.data();
2092       }
2093     }
2094 
2095     // Compute 32-bit results and output quantization arguments.
2096     std::fill(acc.begin(), acc.end(), 0);
2097     for (size_t m_index = 0; m_index < m(); m_index++) {
2098       for (size_t n_index = 0; n_index < n(); n_index++) {
2099         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2100           for (size_t k_index = 0; k_index < k(); k_index++) {
2101             if (im2col[ks_index * mr() + m_index] == a.data()) {
2102               acc[m_index * n() + n_index] +=
2103                 (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
2104                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2105             } else {
2106               acc[m_index * n() + n_index] +=
2107                 (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
2108                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2109             }
2110           }
2111         }
2112         acc[m_index * n() + n_index] += bias[n_index];
2113       }
2114     }
2115 
2116     const int8_t c_zero_point = -1;
2117     for (size_t n_index = 0; n_index < n(); n_index++) {
2118       int32_t accumulated_min = acc[n_index];
2119       int32_t accumulated_max = acc[n_index];
2120       for (size_t m_index = 0; m_index < m(); m_index++) {
2121         accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
2122         accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
2123       }
2124       const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
2125       const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
2126       scale[n_index] = 1.0f / c_scale;
2127     }
2128 
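    // Patch the per-channel float scales into the packed weights. The stride
    // and offset arguments reflect the goki packing layout used above: for each
    // group of nr() channels, int32 biases come first, then ks() * packed_k()
    // weight bytes, then the float scales written here.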
2129     xnn_init_qc8_scale_fp32_params(
2130       n(), nr(),
2131       nr() * (ks() * packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
2132       (void*) ((uintptr_t) packed_w.data() + nr() * (ks() * packed_k() * sizeof(int8_t) + sizeof(int32_t))));
2133 
2134     union xnn_qc8_conv_minmax_params minmax_params;
2135     init_params(&minmax_params,
2136       c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2137 
2138     const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;
2139 
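    // JIT path: allocate a code buffer, have the generator emit an
    // mr() x (n() % nr()) IGEMM microkernel specialized for this k and
    // indirection stride, finalize the buffer (mark it executable), and call
    // through the generated entry point.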
2140     ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
2141     struct xnn_code_buffer code_buffer;
2142     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
2143     ASSERT_EQ(xnn_status_success, igemm_generator(&code_buffer, mr(), n() % nr(), k(), ks() * mr() * sizeof(void*), nullptr));
2144     ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
2145     xnn_qc8_igemm_minmax_ukernel_function igemm = reinterpret_cast<xnn_qc8_igemm_minmax_ukernel_function>(code_buffer.start);
2146 
2147     igemm(
2148       m(), n(), k(), ks() * mr() * sizeof(void*),
2149       im2col.data(), packed_w.data(),
2150       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2151       a_offset() * sizeof(int8_t), zero_pointer,
2152       &minmax_params);
2153 
2154     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2155 
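    // Compute the reference output by requantizing the 32-bit accumulators
    // with the same per-channel scales and clamping bounds the kernel was given.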
2156     for (size_t m_index = 0; m_index < m(); m_index++) {
2157       for (size_t n_index = 0; n_index < n(); n_index++) {
2158         c_ref[m_index * n() + n_index] = requantize(
2159           acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2160       }
2161     }
2162 
2163     for (size_t i = 0; i < m(); i++) {
2164       for (size_t j = 0; j < n(); j++) {
2165         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2166         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2167         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2168             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
2169             << " (accumulator = " << acc[i * n() + j]
2170             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
2171             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2172             << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
2173       }
2174     }
2175   }
2176 }
2177 
2178 void GemmMicrokernelTester::Test(
2179   xnn_jit_gemm_code_generator_function gemm_generator,
2180   xnn_init_qs8_conv_minmax_params_fn init_params,
2181   xnn_qs8_requantize_fn requantize) const
2182 {
2183   ASSERT_LE(m(), mr());
2184 
2185   std::random_device random_device;
2186   auto rng = std::mt19937(random_device());
2187   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
2188   auto i8rng = std::bind(
2189     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
2190     std::ref(rng));
2191   auto w8rng = std::bind(
2192     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
2193     std::ref(rng));
2194 
2195   std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
2196   std::vector<int8_t> b(n() * k());
2197   std::vector<int32_t> bias(n());
2198   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
2199   std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t));
2200   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
2201   std::vector<int32_t> acc(m() * n());
2202   std::vector<int8_t> c_ref(m() * n());
2203 
2204   for (size_t iteration = 0; iteration < iterations(); iteration++) {
2205     do {
2206       std::generate(a.begin(), a.end(), std::ref(i8rng));
2207     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
2208     do {
2209       std::generate(b.begin(), b.end(), std::ref(w8rng));
2210     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
2211     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
2212     std::fill(c.begin(), c.end(), 0xA5);
2213 
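    // Pack weights in one of two layouts: the "extended weights" (xw) variant
    // pre-widens the int8 weights to int16 for kernels that consume 16-bit
    // weights; otherwise the standard int8 goi packing is used.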
2214     std::fill(packed_w.begin(), packed_w.end(), 0);
2215     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
2216     if (extended_weights()) {
2217       xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
2218         b.data(), bias.data(), packed_xw.data(), 0, &packing_params);
2219     } else {
2220       xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
2221         b.data(), bias.data(), packed_w.data(), 0, &packing_params);
2222     }
2223 
2224     // Compute 32-bit results and output quantization arguments.
2225     std::fill(acc.begin(), acc.end(), 0);
2226     for (size_t m_index = 0; m_index < m(); m_index++) {
2227       for (size_t n_index = 0; n_index < n(); n_index++) {
2228         for (size_t k_index = 0; k_index < k(); k_index++) {
2229           acc[m_index * n() + n_index] +=
2230               (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
2231               int32_t(b[n_index * k() + k_index]);
2232         }
2233         acc[m_index * n() + n_index] += bias[n_index];
2234       }
2235     }
2236 
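    // Derive output quantization from the observed accumulator range: pick a
    // scale that maps the range onto ~255 steps (or ~1.0 if the range already
    // fits in 8 bits), and a zero point at the negated midpoint of the range,
    // clamped to the int8 range.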
2237     const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
2238     const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
2239     const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
2240     const int8_t c_zero_point = int8_t(std::max(std::min(
2241       lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
2242       long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
2243 
2244     const float requantization_scale = 1.0f / float(c_scale);
2245     union xnn_qs8_conv_minmax_params quantization_params;
2246     init_params(&quantization_params,
2247       requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2248 
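    // JIT path: generate an mr() x (n() % nr()) GEMM microkernel for this k,
    // make the code buffer executable, and call the generated entry point.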
2249     ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
2250     struct xnn_code_buffer code_buffer;
2251     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
2252     ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, mr(), n() % nr(), k(), nullptr));
2253     ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
2254     xnn_qs8_gemm_minmax_ukernel_function gemm = reinterpret_cast<xnn_qs8_gemm_minmax_ukernel_function>(code_buffer.start);
2255 
2256     gemm(
2257       m(), n(), k(),
2258       a.data(), a_stride() * sizeof(int8_t),
2259       extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
2260       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2261       &quantization_params);
2262 
2263     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2264 
2265     for (size_t m_index = 0; m_index < m(); m_index++) {
2266       for (size_t n_index = 0; n_index < n(); n_index++) {
2267         c_ref[m_index * n() + n_index] = requantize(
2268           acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2269       }
2270     }
2271 
2272     for (size_t i = 0; i < m(); i++) {
2273       for (size_t j = 0; j < n(); j++) {
2274         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2275         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2276         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2277             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
2278             << " (accumulator = " << acc[i * n() + j]
2279             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
2280             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2281             << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
2282       }
2283     }
2284   }
2285 }
2286 
2287 void GemmMicrokernelTester::Test(
2288   xnn_jit_igemm_code_generator_function igemm_generator,
2289   xnn_init_qs8_conv_minmax_params_fn init_params,
2290   xnn_qs8_requantize_fn requantize) const
2291 {
2292   ASSERT_LE(m(), mr());
2293 
2294   std::random_device random_device;
2295   auto rng = std::mt19937(random_device());
2296   auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
2297   auto i8rng = std::bind(
2298     std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
2299     std::ref(rng));
2300   auto w8rng = std::bind(
2301     std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
2302     std::ref(rng));
2303 
2304   std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
2305   std::vector<int8_t> b(n() * ks() * k());
2306   std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
2307   std::vector<int32_t> bias(n());
2308   std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
2309   std::vector<int32_t> acc(m() * n());
2310   std::vector<int8_t> c_ref(m() * n());
2311   std::vector<int8_t> junk(k() + 8);
2312   std::vector<const int8_t*> im2col(mr() * ks());
2313 
2314   std::fill(junk.begin(), junk.end(), 0xA5);
2315 
2316   for (size_t iteration = 0; iteration < iterations(); iteration++) {
2317     do {
2318       std::generate(a.begin(), a.end(), std::ref(i8rng));
2319     } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
2320     do {
2321       std::generate(b.begin(), b.end(), std::ref(w8rng));
2322     } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
2323     std::generate(bias.begin(), bias.end(), std::ref(i32rng));
2324     std::fill(c.begin(), c.end(), 0xA5);
2325 
2326     std::fill(packed_w.begin(), packed_w.end(), 0);
2327     const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
2328     xnn_pack_qs8_conv_goki_w(
2329       1, n(), ks(), k(), nr(), kr(), sr(),
2330       b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);
2331 
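    // Build the im2col indirection buffer the same way as in the QC8 test
    // above: shuffled pointers into `a` biased back by a_offset(), zero_index()
    // slots pinned to a.data(), and junk pointers for rows m()..mr()-1.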
2332     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2333       for (size_t m_index = 0; m_index < mr(); m_index++) {
2334         im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
2335       }
2336     }
2337     std::shuffle(im2col.begin(), im2col.end(), rng);
2338     if (zero_index() != SIZE_MAX) {
2339       for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2340         im2col[ks_index * mr() + zero_index()] = a.data();
2341       }
2342     }
2343     for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2344       for (size_t m_index = m(); m_index < mr(); m_index++) {
2345         im2col[ks_index * mr() + m_index] = junk.data();
2346       }
2347     }
2348 
2349     // Compute 32-bit results and output quantization arguments.
2350     std::fill(acc.begin(), acc.end(), 0);
2351     for (size_t m_index = 0; m_index < m(); m_index++) {
2352       for (size_t n_index = 0; n_index < n(); n_index++) {
2353         for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2354           for (size_t k_index = 0; k_index < k(); k_index++) {
2355             if (im2col[ks_index * mr() + m_index] == a.data()) {
2356               acc[m_index * n() + n_index] +=
2357                 (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
2358                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2359             } else {
2360               acc[m_index * n() + n_index] +=
2361                 (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
2362                 int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2363             }
2364           }
2365         }
2366         acc[m_index * n() + n_index] += bias[n_index];
2367       }
2368     }
2369 
2370     const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
2371     const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
2372     const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
2373     const int8_t c_zero_point = int8_t(std::max(std::min(
2374       lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
2375       long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
2376 
2377     const float requantization_scale = 1.0f / float(c_scale);
2378     union xnn_qs8_conv_minmax_params quantization_params;
2379     init_params(&quantization_params,
2380       requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2381 
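    // The kernel compares each indirection entry against zero_pointer and skips
    // the a_offset() addition for matches; the reference loop above mirrors
    // this by reading entries equal to a.data() without the offset.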
2382     const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : nullptr;
2383 
2384     ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
2385     struct xnn_code_buffer code_buffer;
2386     ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
2387     ASSERT_EQ(xnn_status_success, igemm_generator(&code_buffer, mr(), n() % nr(), k(), ks() * mr() * sizeof(void*), nullptr));
2388     ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
2389     xnn_qs8_igemm_minmax_ukernel_function igemm = reinterpret_cast<xnn_qs8_igemm_minmax_ukernel_function>(code_buffer.start);
2390 
2391     igemm(
2392       m(), n(), k(), ks() * mr() * sizeof(void*),
2393       im2col.data(), packed_w.data(),
2394       c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2395       a_offset() * sizeof(int8_t), zero_pointer,
2396       &quantization_params);
2397 
2398     ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2399 
2400     for (size_t m_index = 0; m_index < m(); m_index++) {
2401       for (size_t n_index = 0; n_index < n(); n_index++) {
2402         c_ref[m_index * n() + n_index] = requantize(
2403           acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2404       }
2405     }
2406 
2407     for (size_t i = 0; i < m(); i++) {
2408       for (size_t j = 0; j < n(); j++) {
2409         ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2410         ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2411         ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2412             << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
2413             << " (accumulator = " << acc[i * n() + j]
2414             << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
2415             << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2416             << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
2417       }
2418     }
2419   }
2420 }
2421 
2422 #endif  // XNN_PLATFORM_JIT
2423