// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker
10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
11*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
12*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
15*4bdc9457SAndroid Build Coastguard Worker #include <random>
16*4bdc9457SAndroid Build Coastguard Worker #include <vector>
17*4bdc9457SAndroid Build Coastguard Worker
18*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
19*4bdc9457SAndroid Build Coastguard Worker
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h>
22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
24*4bdc9457SAndroid Build Coastguard Worker
25*4bdc9457SAndroid Build Coastguard Worker
// Returns true when the 16-bit pattern `x` has every bit other than the top
// (sign) bit cleared, i.e. it is 0x0000 or 0x8000. Callers in this file pass
// raw IEEE half-precision bit patterns, for which those two values are +0 and -0.
static inline bool is_fp16_zero(uint16_t x) {
  // Mask off the sign bit instead of doubling and relying on 16-bit truncation.
  return (x & 0x7FFFu) == 0;
}
30*4bdc9457SAndroid Build Coastguard Worker
31*4bdc9457SAndroid Build Coastguard Worker class SpMMMicrokernelTester {
32*4bdc9457SAndroid Build Coastguard Worker public:
mr(size_t mr)33*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& mr(size_t mr) {
34*4bdc9457SAndroid Build Coastguard Worker this->mr_ = mr;
35*4bdc9457SAndroid Build Coastguard Worker return *this;
36*4bdc9457SAndroid Build Coastguard Worker }
37*4bdc9457SAndroid Build Coastguard Worker
mr()38*4bdc9457SAndroid Build Coastguard Worker inline size_t mr() const {
39*4bdc9457SAndroid Build Coastguard Worker return this->mr_;
40*4bdc9457SAndroid Build Coastguard Worker }
41*4bdc9457SAndroid Build Coastguard Worker
nr(size_t nr)42*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& nr(size_t nr) {
43*4bdc9457SAndroid Build Coastguard Worker this->nr_ = nr;
44*4bdc9457SAndroid Build Coastguard Worker return *this;
45*4bdc9457SAndroid Build Coastguard Worker }
46*4bdc9457SAndroid Build Coastguard Worker
nr()47*4bdc9457SAndroid Build Coastguard Worker inline size_t nr() const {
48*4bdc9457SAndroid Build Coastguard Worker return this->nr_;
49*4bdc9457SAndroid Build Coastguard Worker }
50*4bdc9457SAndroid Build Coastguard Worker
m(size_t m)51*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& m(size_t m) {
52*4bdc9457SAndroid Build Coastguard Worker this->m_ = m;
53*4bdc9457SAndroid Build Coastguard Worker return *this;
54*4bdc9457SAndroid Build Coastguard Worker }
55*4bdc9457SAndroid Build Coastguard Worker
m()56*4bdc9457SAndroid Build Coastguard Worker inline size_t m() const {
57*4bdc9457SAndroid Build Coastguard Worker return this->m_;
58*4bdc9457SAndroid Build Coastguard Worker }
59*4bdc9457SAndroid Build Coastguard Worker
n(size_t n)60*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& n(size_t n) {
61*4bdc9457SAndroid Build Coastguard Worker this->n_ = n;
62*4bdc9457SAndroid Build Coastguard Worker return *this;
63*4bdc9457SAndroid Build Coastguard Worker }
64*4bdc9457SAndroid Build Coastguard Worker
n()65*4bdc9457SAndroid Build Coastguard Worker inline size_t n() const {
66*4bdc9457SAndroid Build Coastguard Worker return this->n_;
67*4bdc9457SAndroid Build Coastguard Worker }
68*4bdc9457SAndroid Build Coastguard Worker
k(size_t k)69*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& k(size_t k) {
70*4bdc9457SAndroid Build Coastguard Worker this->k_ = k;
71*4bdc9457SAndroid Build Coastguard Worker return *this;
72*4bdc9457SAndroid Build Coastguard Worker }
73*4bdc9457SAndroid Build Coastguard Worker
k()74*4bdc9457SAndroid Build Coastguard Worker inline size_t k() const {
75*4bdc9457SAndroid Build Coastguard Worker return this->k_;
76*4bdc9457SAndroid Build Coastguard Worker }
77*4bdc9457SAndroid Build Coastguard Worker
output_stride(size_t output_stride)78*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& output_stride(size_t output_stride) {
79*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0);
80*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride;
81*4bdc9457SAndroid Build Coastguard Worker return *this;
82*4bdc9457SAndroid Build Coastguard Worker }
83*4bdc9457SAndroid Build Coastguard Worker
output_stride()84*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const {
85*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) {
86*4bdc9457SAndroid Build Coastguard Worker return m();
87*4bdc9457SAndroid Build Coastguard Worker } else {
88*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= m());
89*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_;
90*4bdc9457SAndroid Build Coastguard Worker }
91*4bdc9457SAndroid Build Coastguard Worker }
92*4bdc9457SAndroid Build Coastguard Worker
sparsity(float sparsity)93*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& sparsity(float sparsity) {
94*4bdc9457SAndroid Build Coastguard Worker this->sparsity_ = sparsity;
95*4bdc9457SAndroid Build Coastguard Worker return *this;
96*4bdc9457SAndroid Build Coastguard Worker }
97*4bdc9457SAndroid Build Coastguard Worker
sparsity()98*4bdc9457SAndroid Build Coastguard Worker inline float sparsity() const {
99*4bdc9457SAndroid Build Coastguard Worker return this->sparsity_;
100*4bdc9457SAndroid Build Coastguard Worker }
101*4bdc9457SAndroid Build Coastguard Worker
qmin(uint8_t qmin)102*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& qmin(uint8_t qmin) {
103*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin;
104*4bdc9457SAndroid Build Coastguard Worker return *this;
105*4bdc9457SAndroid Build Coastguard Worker }
106*4bdc9457SAndroid Build Coastguard Worker
qmin()107*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const {
108*4bdc9457SAndroid Build Coastguard Worker return this->qmin_;
109*4bdc9457SAndroid Build Coastguard Worker }
110*4bdc9457SAndroid Build Coastguard Worker
qmax(uint8_t qmax)111*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& qmax(uint8_t qmax) {
112*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax;
113*4bdc9457SAndroid Build Coastguard Worker return *this;
114*4bdc9457SAndroid Build Coastguard Worker }
115*4bdc9457SAndroid Build Coastguard Worker
qmax()116*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const {
117*4bdc9457SAndroid Build Coastguard Worker return this->qmax_;
118*4bdc9457SAndroid Build Coastguard Worker }
119*4bdc9457SAndroid Build Coastguard Worker
iterations(size_t iterations)120*4bdc9457SAndroid Build Coastguard Worker inline SpMMMicrokernelTester& iterations(size_t iterations) {
121*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations;
122*4bdc9457SAndroid Build Coastguard Worker return *this;
123*4bdc9457SAndroid Build Coastguard Worker }
124*4bdc9457SAndroid Build Coastguard Worker
iterations()125*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const {
126*4bdc9457SAndroid Build Coastguard Worker return this->iterations_;
127*4bdc9457SAndroid Build Coastguard Worker }
128*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_spmm_minmax_ukernel_function spmm,xnn_init_f32_minmax_params_fn init_params)129*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_spmm_minmax_ukernel_function spmm, xnn_init_f32_minmax_params_fn init_params) const {
130*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(m(), 1);
131*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(n(), 1);
132*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(k(), 1);
133*4bdc9457SAndroid Build Coastguard Worker
134*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
135*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
136*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
137*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> pdist;
138*4bdc9457SAndroid Build Coastguard Worker
139*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> input(k() * m());
140*4bdc9457SAndroid Build Coastguard Worker // Think of b as (n/nr + n % nr) x k, expansion happens later.
141*4bdc9457SAndroid Build Coastguard Worker const size_t ncols = n() / nr() + n() % nr();
142*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(ncols * k());
143*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
144*4bdc9457SAndroid Build Coastguard Worker // Number of non-zero weights per N (output channel).
145*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> nmap(n());
146*4bdc9457SAndroid Build Coastguard Worker // Mapping from index of non-zero weight to increment of K (input channel) following this index.
147*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> dmap(n() * k());
148*4bdc9457SAndroid Build Coastguard Worker std::vector<float> w(n() * k() + n());
149*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((n() - 1) * output_stride() + m());
150*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(n() * m());
151*4bdc9457SAndroid Build Coastguard Worker
152*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
153*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
154*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
155*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
156*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf(""));
157*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f);
158*4bdc9457SAndroid Build Coastguard Worker std::fill(nmap.begin(), nmap.end(), 0);
159*4bdc9457SAndroid Build Coastguard Worker std::fill(dmap.begin(), dmap.end(), 0);
160*4bdc9457SAndroid Build Coastguard Worker std::fill(w.begin(), w.end(), 0.0f);
161*4bdc9457SAndroid Build Coastguard Worker
162*4bdc9457SAndroid Build Coastguard Worker for (float& b_value : b) {
163*4bdc9457SAndroid Build Coastguard Worker if (pdist(rng) <= sparsity()) {
164*4bdc9457SAndroid Build Coastguard Worker b_value = 0.0f;
165*4bdc9457SAndroid Build Coastguard Worker }
166*4bdc9457SAndroid Build Coastguard Worker }
167*4bdc9457SAndroid Build Coastguard Worker
168*4bdc9457SAndroid Build Coastguard Worker uint32_t nnz = 0;
169*4bdc9457SAndroid Build Coastguard Worker uint32_t wcnt = 0;
170*4bdc9457SAndroid Build Coastguard Worker size_t last_kk = 0;
171*4bdc9457SAndroid Build Coastguard Worker bool first_nzz = true;
172*4bdc9457SAndroid Build Coastguard Worker size_t first_kk = 0;
173*4bdc9457SAndroid Build Coastguard Worker for (size_t nn = 0; nn < n() / nr(); nn++) {
174*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < nr(); ++i)
175*4bdc9457SAndroid Build Coastguard Worker w[wcnt++] = bias[nr() * nn + i];
176*4bdc9457SAndroid Build Coastguard Worker for (size_t kk = 0; kk < k(); kk++) {
177*4bdc9457SAndroid Build Coastguard Worker if (b[nn * k() + kk] != 0.0f) {
178*4bdc9457SAndroid Build Coastguard Worker // Every non-zero actually corresponds to nr adjacent non-zeros.
179*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < nr(); ++i)
180*4bdc9457SAndroid Build Coastguard Worker w[wcnt++] = b[nn * k() + kk] + static_cast<float>(i);
181*4bdc9457SAndroid Build Coastguard Worker // Skip the very first non-zero weight as we record only the difference.
182*4bdc9457SAndroid Build Coastguard Worker if (first_nzz) {
183*4bdc9457SAndroid Build Coastguard Worker first_kk = kk;
184*4bdc9457SAndroid Build Coastguard Worker } else {
185*4bdc9457SAndroid Build Coastguard Worker const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(float));
186*4bdc9457SAndroid Build Coastguard Worker dmap[nnz++] = increment;
187*4bdc9457SAndroid Build Coastguard Worker }
188*4bdc9457SAndroid Build Coastguard Worker last_kk = kk;
189*4bdc9457SAndroid Build Coastguard Worker first_nzz = false;
190*4bdc9457SAndroid Build Coastguard Worker nmap[nn] += 1;
191*4bdc9457SAndroid Build Coastguard Worker }
192*4bdc9457SAndroid Build Coastguard Worker }
193*4bdc9457SAndroid Build Coastguard Worker }
194*4bdc9457SAndroid Build Coastguard Worker
195*4bdc9457SAndroid Build Coastguard Worker // now we've constructed the matrix for the blocked part and switch to the
196*4bdc9457SAndroid Build Coastguard Worker // leftovers, which we do as nr=1 always.
197*4bdc9457SAndroid Build Coastguard Worker for (size_t nn = n() / nr(); nn < ncols; nn++) {
198*4bdc9457SAndroid Build Coastguard Worker w[wcnt++] = bias[(n() / nr()) * nr() + (nn - n() / nr())];
199*4bdc9457SAndroid Build Coastguard Worker for (size_t kk = 0; kk < k(); kk++) {
200*4bdc9457SAndroid Build Coastguard Worker if (b[nn * k() + kk] != 0.0f) {
201*4bdc9457SAndroid Build Coastguard Worker // Every non-zero actually corresponds to nr adjacent non-zeros.
202*4bdc9457SAndroid Build Coastguard Worker w[wcnt++] = b[nn * k() + kk];
203*4bdc9457SAndroid Build Coastguard Worker // Skip the very first non-zero weight as we record only the difference.
204*4bdc9457SAndroid Build Coastguard Worker if (first_nzz) {
205*4bdc9457SAndroid Build Coastguard Worker first_kk = kk;
206*4bdc9457SAndroid Build Coastguard Worker } else {
207*4bdc9457SAndroid Build Coastguard Worker const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(float));
208*4bdc9457SAndroid Build Coastguard Worker dmap[nnz++] = increment;
209*4bdc9457SAndroid Build Coastguard Worker }
210*4bdc9457SAndroid Build Coastguard Worker last_kk = kk;
211*4bdc9457SAndroid Build Coastguard Worker first_nzz = false;
212*4bdc9457SAndroid Build Coastguard Worker nmap[nn] += 1;
213*4bdc9457SAndroid Build Coastguard Worker }
214*4bdc9457SAndroid Build Coastguard Worker }
215*4bdc9457SAndroid Build Coastguard Worker }
216*4bdc9457SAndroid Build Coastguard Worker // In the end, we must return input pointer to the initial value.
217*4bdc9457SAndroid Build Coastguard Worker const int64_t increment = int32_t(first_kk - last_kk) * int32_t(m() * sizeof(float));
218*4bdc9457SAndroid Build Coastguard Worker dmap[nnz++] = increment;
219*4bdc9457SAndroid Build Coastguard Worker
220*4bdc9457SAndroid Build Coastguard Worker // Generate expanded b which will be used in reference calculation.
221*4bdc9457SAndroid Build Coastguard Worker // Everywhere there is input non-zero in the original we copy it and add an
222*4bdc9457SAndroid Build Coastguard Worker // adjacent non-zero with incremented weight value.
223*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b_full(n() * k());
224*4bdc9457SAndroid Build Coastguard Worker if (nr() == 1) {
225*4bdc9457SAndroid Build Coastguard Worker b_full = b;
226*4bdc9457SAndroid Build Coastguard Worker }
227*4bdc9457SAndroid Build Coastguard Worker else {
228*4bdc9457SAndroid Build Coastguard Worker for (size_t nn = 0; nn < n() / nr(); nn++) {
229*4bdc9457SAndroid Build Coastguard Worker for (size_t kk = 0; kk < k(); kk++) {
230*4bdc9457SAndroid Build Coastguard Worker if (b[nn * k() + kk] != 0.0f) {
231*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < nr(); ++i)
232*4bdc9457SAndroid Build Coastguard Worker b_full[nr() * nn * k() + i * k() + kk] = b[nn * k() + kk] + static_cast<float>(i);
233*4bdc9457SAndroid Build Coastguard Worker }
234*4bdc9457SAndroid Build Coastguard Worker }
235*4bdc9457SAndroid Build Coastguard Worker }
236*4bdc9457SAndroid Build Coastguard Worker for (size_t nn = n() / nr(); nn < ncols; nn++) {
237*4bdc9457SAndroid Build Coastguard Worker for (size_t kk = 0; kk < k(); kk++) {
238*4bdc9457SAndroid Build Coastguard Worker if (b[nn * k() + kk] != 0.0f) {
239*4bdc9457SAndroid Build Coastguard Worker b_full[nr() * (n() / nr()) * k() + (nn - n() / nr()) * k() + kk] = b[nn * k() + kk];
240*4bdc9457SAndroid Build Coastguard Worker }
241*4bdc9457SAndroid Build Coastguard Worker }
242*4bdc9457SAndroid Build Coastguard Worker }
243*4bdc9457SAndroid Build Coastguard Worker }
244*4bdc9457SAndroid Build Coastguard Worker
245*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < n(); oc++) {
246*4bdc9457SAndroid Build Coastguard Worker for (size_t pxb = 0; pxb < m(); pxb++) {
247*4bdc9457SAndroid Build Coastguard Worker output_ref[oc * m() + pxb] = bias[oc];
248*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < k(); ic++) {
249*4bdc9457SAndroid Build Coastguard Worker output_ref[oc * m() + pxb] += input[ic * m() + pxb] * b_full[oc * k() + ic];
250*4bdc9457SAndroid Build Coastguard Worker }
251*4bdc9457SAndroid Build Coastguard Worker }
252*4bdc9457SAndroid Build Coastguard Worker }
253*4bdc9457SAndroid Build Coastguard Worker
254*4bdc9457SAndroid Build Coastguard Worker // Micro-kernel can access one element beyond w and dmap for software pipelining.
255*4bdc9457SAndroid Build Coastguard Worker w.resize(wcnt + 1);
256*4bdc9457SAndroid Build Coastguard Worker dmap.resize(nnz + 1);
257*4bdc9457SAndroid Build Coastguard Worker
258*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
259*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
260*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
261*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
262*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
263*4bdc9457SAndroid Build Coastguard Worker
264*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
265*4bdc9457SAndroid Build Coastguard Worker for (float& output_value : output_ref) {
266*4bdc9457SAndroid Build Coastguard Worker output_value = std::min(std::max(output_value, output_min), output_max);
267*4bdc9457SAndroid Build Coastguard Worker }
268*4bdc9457SAndroid Build Coastguard Worker
269*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
270*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params;
271*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, output_min, output_max);
272*4bdc9457SAndroid Build Coastguard Worker
273*4bdc9457SAndroid Build Coastguard Worker spmm(m() * sizeof(float), n(),
274*4bdc9457SAndroid Build Coastguard Worker input.data() + first_kk * m(),
275*4bdc9457SAndroid Build Coastguard Worker w.data(), dmap.data(), nmap.data(),
276*4bdc9457SAndroid Build Coastguard Worker output.data(), output_stride() * sizeof(float),
277*4bdc9457SAndroid Build Coastguard Worker ¶ms);
278*4bdc9457SAndroid Build Coastguard Worker
279*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
280*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
281*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
282*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
283*4bdc9457SAndroid Build Coastguard Worker output[j * output_stride() + i],
284*4bdc9457SAndroid Build Coastguard Worker output_ref[j * m() + i],
285*4bdc9457SAndroid Build Coastguard Worker std::abs(output_ref[j * m() + i]) * 1.0e-6f)
286*4bdc9457SAndroid Build Coastguard Worker << "at M index " << i << " / " << m() << " (tile " << mr() << ")"
287*4bdc9457SAndroid Build Coastguard Worker << ", N index " << j << " / " << n() << " (tile " << nr() << ")"
288*4bdc9457SAndroid Build Coastguard Worker << ", K = " << k();
289*4bdc9457SAndroid Build Coastguard Worker }
290*4bdc9457SAndroid Build Coastguard Worker }
291*4bdc9457SAndroid Build Coastguard Worker }
292*4bdc9457SAndroid Build Coastguard Worker }
293*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f16_spmm_minmax_ukernel_function spmm,xnn_init_f16_minmax_params_fn init_params)294*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_spmm_minmax_ukernel_function spmm, xnn_init_f16_minmax_params_fn init_params) const {
295*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(m(), 1);
296*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(n(), 1);
297*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(k(), 1);
298*4bdc9457SAndroid Build Coastguard Worker
299*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
300*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
301*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
302*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> pdist;
303*4bdc9457SAndroid Build Coastguard Worker
304*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> input(k() * m());
305*4bdc9457SAndroid Build Coastguard Worker // Think of b as (n/nr + n % nr) x k, expansion happens later.
306*4bdc9457SAndroid Build Coastguard Worker const size_t ncols = n() / nr() + n() % nr();
307*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(ncols * k());
308*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(n());
309*4bdc9457SAndroid Build Coastguard Worker // Number of non-zero weights per N (output channel).
310*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> nmap(n());
311*4bdc9457SAndroid Build Coastguard Worker // Mapping from index of non-zero weight to increment of K (input channel) following this index.
312*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> dmap(n() * k());
313*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> w(n() * k() + n());
314*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((n() - 1) * output_stride() + m());
315*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(n() * m());
316*4bdc9457SAndroid Build Coastguard Worker
317*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
318*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
319*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
320*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
321*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xC000);
322*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f);
323*4bdc9457SAndroid Build Coastguard Worker std::fill(nmap.begin(), nmap.end(), 0);
324*4bdc9457SAndroid Build Coastguard Worker std::fill(dmap.begin(), dmap.end(), 0);
325*4bdc9457SAndroid Build Coastguard Worker std::fill(w.begin(), w.end(), 0);
326*4bdc9457SAndroid Build Coastguard Worker
327*4bdc9457SAndroid Build Coastguard Worker for (uint16_t& b_value : b) {
328*4bdc9457SAndroid Build Coastguard Worker if (pdist(rng) <= sparsity()) {
329*4bdc9457SAndroid Build Coastguard Worker b_value = 0;
330*4bdc9457SAndroid Build Coastguard Worker }
331*4bdc9457SAndroid Build Coastguard Worker }
332*4bdc9457SAndroid Build Coastguard Worker
333*4bdc9457SAndroid Build Coastguard Worker uint32_t nnz = 0;
334*4bdc9457SAndroid Build Coastguard Worker uint32_t wcnt = 0;
335*4bdc9457SAndroid Build Coastguard Worker size_t last_kk = 0;
336*4bdc9457SAndroid Build Coastguard Worker bool first_nzz = true;
337*4bdc9457SAndroid Build Coastguard Worker size_t first_kk = 0;
338*4bdc9457SAndroid Build Coastguard Worker for (size_t nn = 0; nn < n() / nr(); nn++) {
339*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < nr(); ++i)
340*4bdc9457SAndroid Build Coastguard Worker w[wcnt++] = bias[nr() * nn + i];
341*4bdc9457SAndroid Build Coastguard Worker for (size_t kk = 0; kk < k(); kk++) {
342*4bdc9457SAndroid Build Coastguard Worker if (!is_fp16_zero(b[nn * k() + kk])) {
343*4bdc9457SAndroid Build Coastguard Worker // Every non-zero actually corresponds to nr adjacent non-zeros.
344*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < nr(); ++i)
345*4bdc9457SAndroid Build Coastguard Worker w[wcnt++] = fp16_ieee_from_fp32_value(fp16_ieee_to_fp32_value(b[nn * k() + kk]) + static_cast<float>(i));
346*4bdc9457SAndroid Build Coastguard Worker // Skip the very first non-zero weight as we record only the difference.
347*4bdc9457SAndroid Build Coastguard Worker if (first_nzz) {
348*4bdc9457SAndroid Build Coastguard Worker first_kk = kk;
349*4bdc9457SAndroid Build Coastguard Worker } else {
350*4bdc9457SAndroid Build Coastguard Worker const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(uint16_t));
351*4bdc9457SAndroid Build Coastguard Worker dmap[nnz++] = increment;
352*4bdc9457SAndroid Build Coastguard Worker }
353*4bdc9457SAndroid Build Coastguard Worker last_kk = kk;
354*4bdc9457SAndroid Build Coastguard Worker first_nzz = false;
355*4bdc9457SAndroid Build Coastguard Worker nmap[nn] += 1;
356*4bdc9457SAndroid Build Coastguard Worker }
357*4bdc9457SAndroid Build Coastguard Worker }
358*4bdc9457SAndroid Build Coastguard Worker }
359*4bdc9457SAndroid Build Coastguard Worker
360*4bdc9457SAndroid Build Coastguard Worker // now we've constructed the matrix for the blocked part and switch to the
361*4bdc9457SAndroid Build Coastguard Worker // leftovers, which we do as nr=1 always.
362*4bdc9457SAndroid Build Coastguard Worker for (size_t nn = n() / nr(); nn < ncols; nn++) {
363*4bdc9457SAndroid Build Coastguard Worker w[wcnt++] = bias[(n() / nr()) * nr() + (nn - n() / nr())];
364*4bdc9457SAndroid Build Coastguard Worker for (size_t kk = 0; kk < k(); kk++) {
365*4bdc9457SAndroid Build Coastguard Worker if (!is_fp16_zero(b[nn * k() + kk])) {
366*4bdc9457SAndroid Build Coastguard Worker // Every non-zero actually corresponds to nr adjacent non-zeros.
367*4bdc9457SAndroid Build Coastguard Worker w[wcnt++] = b[nn * k() + kk];
368*4bdc9457SAndroid Build Coastguard Worker // Skip the very first non-zero weight as we record only the difference.
369*4bdc9457SAndroid Build Coastguard Worker if (first_nzz) {
370*4bdc9457SAndroid Build Coastguard Worker first_kk = kk;
371*4bdc9457SAndroid Build Coastguard Worker } else {
372*4bdc9457SAndroid Build Coastguard Worker const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(uint16_t));
373*4bdc9457SAndroid Build Coastguard Worker dmap[nnz++] = increment;
374*4bdc9457SAndroid Build Coastguard Worker }
375*4bdc9457SAndroid Build Coastguard Worker last_kk = kk;
376*4bdc9457SAndroid Build Coastguard Worker first_nzz = false;
377*4bdc9457SAndroid Build Coastguard Worker nmap[nn] += 1;
378*4bdc9457SAndroid Build Coastguard Worker }
379*4bdc9457SAndroid Build Coastguard Worker }
380*4bdc9457SAndroid Build Coastguard Worker }
381*4bdc9457SAndroid Build Coastguard Worker // In the end, we must return input pointer to the initial value.
382*4bdc9457SAndroid Build Coastguard Worker const int64_t increment = int32_t(first_kk - last_kk) * int32_t(m() * sizeof(uint16_t));
383*4bdc9457SAndroid Build Coastguard Worker dmap[nnz++] = increment;
384*4bdc9457SAndroid Build Coastguard Worker
385*4bdc9457SAndroid Build Coastguard Worker // Generate expanded b which will be used in reference calculation.
386*4bdc9457SAndroid Build Coastguard Worker // Everywhere there is input non-zero in the original we copy it and add an
387*4bdc9457SAndroid Build Coastguard Worker // adjacent non-zero with incremented weight value.
388*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b_full(n() * k());
389*4bdc9457SAndroid Build Coastguard Worker if (nr() == 1) {
390*4bdc9457SAndroid Build Coastguard Worker b_full = b;
391*4bdc9457SAndroid Build Coastguard Worker }
392*4bdc9457SAndroid Build Coastguard Worker else {
393*4bdc9457SAndroid Build Coastguard Worker for (size_t nn = 0; nn < n() / nr(); nn++) {
394*4bdc9457SAndroid Build Coastguard Worker for (size_t kk = 0; kk < k(); kk++) {
395*4bdc9457SAndroid Build Coastguard Worker if (b[nn * k() + kk] != 0.0f) {
396*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < nr(); ++i)
397*4bdc9457SAndroid Build Coastguard Worker b_full[nr() * nn * k() + i * k() + kk] = fp16_ieee_from_fp32_value(
398*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(b[nn * k() + kk]) + static_cast<float>(i));
399*4bdc9457SAndroid Build Coastguard Worker }
400*4bdc9457SAndroid Build Coastguard Worker }
401*4bdc9457SAndroid Build Coastguard Worker }
402*4bdc9457SAndroid Build Coastguard Worker for (size_t nn = n() / nr(); nn < ncols; nn++) {
403*4bdc9457SAndroid Build Coastguard Worker for (size_t kk = 0; kk < k(); kk++) {
404*4bdc9457SAndroid Build Coastguard Worker if (b[nn * k() + kk] != 0.0f) {
405*4bdc9457SAndroid Build Coastguard Worker b_full[nr() * (n() / nr()) * k() + (nn - n() / nr()) * k() + kk] = b[nn * k() + kk];
406*4bdc9457SAndroid Build Coastguard Worker }
407*4bdc9457SAndroid Build Coastguard Worker }
408*4bdc9457SAndroid Build Coastguard Worker }
409*4bdc9457SAndroid Build Coastguard Worker }
410*4bdc9457SAndroid Build Coastguard Worker
411*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < n(); oc++) {
412*4bdc9457SAndroid Build Coastguard Worker for (size_t pxb = 0; pxb < m(); pxb++) {
413*4bdc9457SAndroid Build Coastguard Worker output_ref[oc * m() + pxb] = fp16_ieee_to_fp32_value(bias[oc]);
414*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < k(); ic++) {
415*4bdc9457SAndroid Build Coastguard Worker output_ref[oc * m() + pxb] += fp16_ieee_to_fp32_value(input[ic * m() + pxb]) * fp16_ieee_to_fp32_value(b_full[oc * k() + ic]);
416*4bdc9457SAndroid Build Coastguard Worker }
417*4bdc9457SAndroid Build Coastguard Worker }
418*4bdc9457SAndroid Build Coastguard Worker }
419*4bdc9457SAndroid Build Coastguard Worker
420*4bdc9457SAndroid Build Coastguard Worker // Micro-kernel can access one element beyond w and dmap for software pipelining.
421*4bdc9457SAndroid Build Coastguard Worker w.resize(wcnt + 1);
422*4bdc9457SAndroid Build Coastguard Worker dmap.resize(nnz + 1);
423*4bdc9457SAndroid Build Coastguard Worker
424*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
425*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
426*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
427*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
428*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
429*4bdc9457SAndroid Build Coastguard Worker
430*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
431*4bdc9457SAndroid Build Coastguard Worker for (float& output_value : output_ref) {
432*4bdc9457SAndroid Build Coastguard Worker output_value = std::min(std::max(output_value, output_min), output_max);
433*4bdc9457SAndroid Build Coastguard Worker }
434*4bdc9457SAndroid Build Coastguard Worker
435*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
436*4bdc9457SAndroid Build Coastguard Worker xnn_f16_minmax_params params;
437*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms,
438*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(output_min), fp16_ieee_from_fp32_value(output_max));
439*4bdc9457SAndroid Build Coastguard Worker
440*4bdc9457SAndroid Build Coastguard Worker spmm(m() * sizeof(uint16_t), n(),
441*4bdc9457SAndroid Build Coastguard Worker input.data() + first_kk * m(),
442*4bdc9457SAndroid Build Coastguard Worker w.data(), dmap.data(), nmap.data(),
443*4bdc9457SAndroid Build Coastguard Worker output.data(), output_stride() * sizeof(uint16_t),
444*4bdc9457SAndroid Build Coastguard Worker ¶ms);
445*4bdc9457SAndroid Build Coastguard Worker
446*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
447*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
448*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
449*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
450*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[j * output_stride() + i]),
451*4bdc9457SAndroid Build Coastguard Worker output_ref[j * m() + i],
452*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(output_ref[j * m() + i]) * 1.0e-2f))
453*4bdc9457SAndroid Build Coastguard Worker << "at M index " << i << " / " << m() << " (tile " << mr() << ")"
454*4bdc9457SAndroid Build Coastguard Worker << ", N index " << j << " / " << n() << " (tile " << nr() << ")"
455*4bdc9457SAndroid Build Coastguard Worker << ", K = " << k();
456*4bdc9457SAndroid Build Coastguard Worker }
457*4bdc9457SAndroid Build Coastguard Worker }
458*4bdc9457SAndroid Build Coastguard Worker }
459*4bdc9457SAndroid Build Coastguard Worker }
460*4bdc9457SAndroid Build Coastguard Worker
461*4bdc9457SAndroid Build Coastguard Worker private:
462*4bdc9457SAndroid Build Coastguard Worker size_t mr_{1};
463*4bdc9457SAndroid Build Coastguard Worker size_t nr_{1};
464*4bdc9457SAndroid Build Coastguard Worker size_t m_{1};
465*4bdc9457SAndroid Build Coastguard Worker size_t n_{1};
466*4bdc9457SAndroid Build Coastguard Worker size_t k_{1};
467*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0};
468*4bdc9457SAndroid Build Coastguard Worker float sparsity_{0.5f};
469*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0};
470*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255};
471*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1};
472*4bdc9457SAndroid Build Coastguard Worker };
473