// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/aligned-allocator.h>
#include <xnnpack/microfnptr.h>
#include <xnnpack/microparams-init.h>

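// Checks whether an IEEE fp16 value (stored as raw uint16_t bits) is zero.
// Doubling the bit pattern shifts the sign bit out of the low 16 bits, so
// both +0.0 (0x0000) and -0.0 (0x8000) wrap around to 0 and compare equal.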
static inline bool is_fp16_zero(uint16_t x) {
  const uint16_t two_x = x + x;
  return two_x == 0;
}

class SpMMMicrokernelTester {
 public:
  inline SpMMMicrokernelTester& mr(size_t mr) {
    this->mr_ = mr;
    return *this;
  }

  inline size_t mr() const {
    return this->mr_;
  }

  inline SpMMMicrokernelTester& nr(size_t nr) {
    this->nr_ = nr;
    return *this;
  }

  inline size_t nr() const {
    return this->nr_;
  }

  inline SpMMMicrokernelTester& m(size_t m) {
    this->m_ = m;
    return *this;
  }

  inline size_t m() const {
    return this->m_;
  }

  inline SpMMMicrokernelTester& n(size_t n) {
    this->n_ = n;
    return *this;
  }

  inline size_t n() const {
    return this->n_;
  }

  inline SpMMMicrokernelTester& k(size_t k) {
    this->k_ = k;
    return *this;
  }

  inline size_t k() const {
    return this->k_;
  }

  inline SpMMMicrokernelTester& output_stride(size_t output_stride) {
    assert(output_stride != 0);
    this->output_stride_ = output_stride;
    return *this;
  }

  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return m();
    } else {
      assert(this->output_stride_ >= m());
      return this->output_stride_;
    }
  }

  inline SpMMMicrokernelTester& sparsity(float sparsity) {
    this->sparsity_ = sparsity;
    return *this;
  }

  inline float sparsity() const {
    return this->sparsity_;
  }

  inline SpMMMicrokernelTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline SpMMMicrokernelTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline SpMMMicrokernelTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

  void Test(xnn_f32_spmm_minmax_ukernel_function spmm, xnn_init_f32_minmax_params_fn init_params) const {
    ASSERT_GE(m(), 1);
    ASSERT_GE(n(), 1);
    ASSERT_GE(k(), 1);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_real_distribution<float> f32dist;
    std::uniform_real_distribution<float> pdist;

    std::vector<float, AlignedAllocator<float, 64>> input(k() * m());
    // Think of b as an ncols x k matrix with ncols = n/nr + n%nr;
    // expansion to the full n x k matrix happens later.
    const size_t ncols = n() / nr() + n() % nr();
    std::vector<float> b(ncols * k());
    std::vector<float> bias(n());
    // Number of non-zero weights per N (output channel).
    std::vector<uint32_t> nmap(n());
    // Mapping from the index of a non-zero weight to the byte increment of K
    // (input channel) that follows it.
    std::vector<int32_t> dmap(n() * k());
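    // For example (a sketch with nr() == 1, m() == 2, k() == 4, fp32 data,
    // and non-zeros at kk == 0 and kk == 3): w holds {bias[0], b[0], b[3]},
    // nmap[0] == 2, and dmap holds {(3-0)*2*sizeof(float), (0-3)*2*sizeof(float)}
    // == {24, -24}. Each entry is the byte offset applied to the input pointer
    // after the corresponding non-zero is consumed; the final negative entry
    // rewinds the pointer to the first non-zero's row.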
    std::vector<float> w(n() * k() + n());
    std::vector<float> output((n() - 1) * output_stride() + m());
    std::vector<float> output_ref(n() * m());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
      std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
      std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
      std::fill(output.begin(), output.end(), nanf(""));
      std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      std::fill(nmap.begin(), nmap.end(), 0);
      std::fill(dmap.begin(), dmap.end(), 0);
      std::fill(w.begin(), w.end(), 0.0f);

      // Zero out weights at random to reach the requested sparsity.
      for (float& b_value : b) {
        if (pdist(rng) <= sparsity()) {
          b_value = 0.0f;
        }
      }

      uint32_t nnz = 0;
      uint32_t wcnt = 0;
      size_t last_kk = 0;
      bool first_nzz = true;
      size_t first_kk = 0;
      for (size_t nn = 0; nn < n() / nr(); nn++) {
        for (size_t i = 0; i < nr(); ++i)
          w[wcnt++] = bias[nr() * nn + i];
        for (size_t kk = 0; kk < k(); kk++) {
          if (b[nn * k() + kk] != 0.0f) {
            // Every non-zero actually corresponds to nr adjacent non-zeros.
            for (size_t i = 0; i < nr(); ++i)
              w[wcnt++] = b[nn * k() + kk] + static_cast<float>(i);
            // Skip the very first non-zero weight, as we record only
            // differences between consecutive non-zeros.
            if (first_nzz) {
              first_kk = kk;
            } else {
              const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(float));
              dmap[nnz++] = increment;
            }
            last_kk = kk;
            first_nzz = false;
            nmap[nn] += 1;
          }
        }
      }

      // Now that the blocked part of the matrix is constructed, handle the
      // leftover columns, which are always processed with nr == 1.
      for (size_t nn = n() / nr(); nn < ncols; nn++) {
        w[wcnt++] = bias[(n() / nr()) * nr() + (nn - n() / nr())];
        for (size_t kk = 0; kk < k(); kk++) {
          if (b[nn * k() + kk] != 0.0f) {
            w[wcnt++] = b[nn * k() + kk];
            // Skip the very first non-zero weight, as we record only
            // differences between consecutive non-zeros.
            if (first_nzz) {
              first_kk = kk;
            } else {
              const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(float));
              dmap[nnz++] = increment;
            }
            last_kk = kk;
            first_nzz = false;
            nmap[nn] += 1;
          }
        }
      }
      // The final increment rewinds the input pointer to its initial position.
      const int32_t increment = int32_t(first_kk - last_kk) * int32_t(m() * sizeof(float));
      dmap[nnz++] = increment;

      // Generate the expanded b used in the reference calculation: wherever
      // the original has a non-zero, copy it and add nr-1 adjacent non-zeros
      // with incremented weight values.
      std::vector<float> b_full(n() * k());
      if (nr() == 1) {
        b_full = b;
      } else {
        for (size_t nn = 0; nn < n() / nr(); nn++) {
          for (size_t kk = 0; kk < k(); kk++) {
            if (b[nn * k() + kk] != 0.0f) {
              for (size_t i = 0; i < nr(); ++i)
                b_full[nr() * nn * k() + i * k() + kk] = b[nn * k() + kk] + static_cast<float>(i);
            }
          }
        }
        for (size_t nn = n() / nr(); nn < ncols; nn++) {
          for (size_t kk = 0; kk < k(); kk++) {
            if (b[nn * k() + kk] != 0.0f) {
              b_full[nr() * (n() / nr()) * k() + (nn - n() / nr()) * k() + kk] = b[nn * k() + kk];
            }
          }
        }
      }

      // Dense reference computation over the expanded weights.
      for (size_t oc = 0; oc < n(); oc++) {
        for (size_t pxb = 0; pxb < m(); pxb++) {
          output_ref[oc * m() + pxb] = bias[oc];
          for (size_t ic = 0; ic < k(); ic++) {
            output_ref[oc * m() + pxb] += input[ic * m() + pxb] * b_full[oc * k() + ic];
          }
        }
      }

      // The micro-kernel may access one element beyond w and dmap for software pipelining.
      w.resize(wcnt + 1);
      dmap.resize(nnz + 1);

      // Compute clamping parameters.
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
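      // E.g. with the defaults qmin() == 0 and qmax() == 255, the limits span
      // the full accumulated range and nothing is clamped; qmin() == 128 would
      // clamp away roughly the lower half of that range.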

      // Clamp reference results.
      for (float& output_value : output_ref) {
        output_value = std::min(std::max(output_value, output_min), output_max);
      }

      // Prepare parameters.
      xnn_f32_minmax_params params;
      init_params(&params, output_min, output_max);

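      // Invoke the micro-kernel under test. The first argument (the M extent)
      // is given in bytes, and the input pointer is pre-advanced to the row of
      // the first non-zero weight, matching the offsets encoded in dmap.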
      spmm(m() * sizeof(float), n(),
        input.data() + first_kk * m(),
        w.data(), dmap.data(), nmap.data(),
        output.data(), output_stride() * sizeof(float),
        &params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              output[j * output_stride() + i],
              output_ref[j * m() + i],
              std::abs(output_ref[j * m() + i]) * 1.0e-6f)
            << "at M index " << i << " / " << m() << " (tile " << mr() << ")"
            << ", N index " << j << " / " << n() << " (tile " << nr() << ")"
            << ", K = " << k();
        }
      }
    }
  }

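  // Same flow as the f32 variant above, but with fp16 (uint16_t) storage, a
  // looser comparison tolerance, and 0xC000 (-2.0 in fp16) as the output
  // canary value in place of NaN.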
  void Test(xnn_f16_spmm_minmax_ukernel_function spmm, xnn_init_f16_minmax_params_fn init_params) const {
    ASSERT_GE(m(), 1);
    ASSERT_GE(n(), 1);
    ASSERT_GE(k(), 1);

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_real_distribution<float> f32dist;
    std::uniform_real_distribution<float> pdist;

    std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> input(k() * m());
    // Think of b as an ncols x k matrix with ncols = n/nr + n%nr;
    // expansion to the full n x k matrix happens later.
    const size_t ncols = n() / nr() + n() % nr();
    std::vector<uint16_t> b(ncols * k());
    std::vector<uint16_t> bias(n());
    // Number of non-zero weights per N (output channel).
    std::vector<uint32_t> nmap(n());
    // Mapping from the index of a non-zero weight to the byte increment of K
    // (input channel) that follows it.
    std::vector<int32_t> dmap(n() * k());
    std::vector<uint16_t> w(n() * k() + n());
    std::vector<uint16_t> output((n() - 1) * output_stride() + m());
    std::vector<float> output_ref(n() * m());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::generate(b.begin(), b.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
      // 0xC000 is -2.0 in fp16, used as a canary value in place of NaN.
      std::fill(output.begin(), output.end(), 0xC000);
      std::fill(output_ref.begin(), output_ref.end(), 0.0f);
      std::fill(nmap.begin(), nmap.end(), 0);
      std::fill(dmap.begin(), dmap.end(), 0);
      std::fill(w.begin(), w.end(), 0);

      // Zero out weights at random to reach the requested sparsity.
      for (uint16_t& b_value : b) {
        if (pdist(rng) <= sparsity()) {
          b_value = 0;
        }
      }

      uint32_t nnz = 0;
      uint32_t wcnt = 0;
      size_t last_kk = 0;
      bool first_nzz = true;
      size_t first_kk = 0;
      for (size_t nn = 0; nn < n() / nr(); nn++) {
        for (size_t i = 0; i < nr(); ++i)
          w[wcnt++] = bias[nr() * nn + i];
        for (size_t kk = 0; kk < k(); kk++) {
          if (!is_fp16_zero(b[nn * k() + kk])) {
            // Every non-zero actually corresponds to nr adjacent non-zeros.
            for (size_t i = 0; i < nr(); ++i)
              w[wcnt++] = fp16_ieee_from_fp32_value(fp16_ieee_to_fp32_value(b[nn * k() + kk]) + static_cast<float>(i));
            // Skip the very first non-zero weight, as we record only
            // differences between consecutive non-zeros.
            if (first_nzz) {
              first_kk = kk;
            } else {
              const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(uint16_t));
              dmap[nnz++] = increment;
            }
            last_kk = kk;
            first_nzz = false;
            nmap[nn] += 1;
          }
        }
      }

      // Now that the blocked part of the matrix is constructed, handle the
      // leftover columns, which are always processed with nr == 1.
      for (size_t nn = n() / nr(); nn < ncols; nn++) {
        w[wcnt++] = bias[(n() / nr()) * nr() + (nn - n() / nr())];
        for (size_t kk = 0; kk < k(); kk++) {
          if (!is_fp16_zero(b[nn * k() + kk])) {
            w[wcnt++] = b[nn * k() + kk];
            // Skip the very first non-zero weight, as we record only
            // differences between consecutive non-zeros.
            if (first_nzz) {
              first_kk = kk;
            } else {
              const int32_t increment = int32_t(kk - last_kk) * int32_t(m() * sizeof(uint16_t));
              dmap[nnz++] = increment;
            }
            last_kk = kk;
            first_nzz = false;
            nmap[nn] += 1;
          }
        }
      }
      // The final increment rewinds the input pointer to its initial position.
      const int32_t increment = int32_t(first_kk - last_kk) * int32_t(m() * sizeof(uint16_t));
      dmap[nnz++] = increment;

      // Generate the expanded b used in the reference calculation: wherever
      // the original has a non-zero, copy it and add nr-1 adjacent non-zeros
      // with incremented weight values.
      std::vector<uint16_t> b_full(n() * k());
      if (nr() == 1) {
        b_full = b;
      } else {
        for (size_t nn = 0; nn < n() / nr(); nn++) {
          for (size_t kk = 0; kk < k(); kk++) {
            if (!is_fp16_zero(b[nn * k() + kk])) {
              for (size_t i = 0; i < nr(); ++i)
                b_full[nr() * nn * k() + i * k() + kk] = fp16_ieee_from_fp32_value(
                  fp16_ieee_to_fp32_value(b[nn * k() + kk]) + static_cast<float>(i));
            }
          }
        }
        for (size_t nn = n() / nr(); nn < ncols; nn++) {
          for (size_t kk = 0; kk < k(); kk++) {
            if (!is_fp16_zero(b[nn * k() + kk])) {
              b_full[nr() * (n() / nr()) * k() + (nn - n() / nr()) * k() + kk] = b[nn * k() + kk];
            }
          }
        }
      }

      // Dense reference computation over the expanded weights, accumulated in fp32.
      for (size_t oc = 0; oc < n(); oc++) {
        for (size_t pxb = 0; pxb < m(); pxb++) {
          output_ref[oc * m() + pxb] = fp16_ieee_to_fp32_value(bias[oc]);
          for (size_t ic = 0; ic < k(); ic++) {
            output_ref[oc * m() + pxb] += fp16_ieee_to_fp32_value(input[ic * m() + pxb]) * fp16_ieee_to_fp32_value(b_full[oc * k() + ic]);
          }
        }
      }

      // The micro-kernel may access one element beyond w and dmap for software pipelining.
      w.resize(wcnt + 1);
      dmap.resize(nnz + 1);

      // Compute clamping parameters.
      const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
      const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
      const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Clamp reference results.
      for (float& output_value : output_ref) {
        output_value = std::min(std::max(output_value, output_min), output_max);
      }

      // Prepare parameters.
      xnn_f16_minmax_params params;
      init_params(&params,
        fp16_ieee_from_fp32_value(output_min), fp16_ieee_from_fp32_value(output_max));

      spmm(m() * sizeof(uint16_t), n(),
        input.data() + first_kk * m(),
        w.data(), dmap.data(), nmap.data(),
        output.data(), output_stride() * sizeof(uint16_t),
        &params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              fp16_ieee_to_fp32_value(output[j * output_stride() + i]),
              output_ref[j * m() + i],
              std::max(1.0e-4f, std::abs(output_ref[j * m() + i]) * 1.0e-2f))
            << "at M index " << i << " / " << m() << " (tile " << mr() << ")"
            << ", N index " << j << " / " << n() << " (tile " << nr() << ")"
            << ", K = " << k();
        }
      }
    }
  }

 private:
  size_t mr_{1};
  size_t nr_{1};
  size_t m_{1};
  size_t n_{1};
  size_t k_{1};
  size_t output_stride_{0};
  float sparsity_{0.5f};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  size_t iterations_{1};
};
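
// Example usage (a sketch; the ukernel and params-init symbols below are
// assumed to match the generated scalar variants in this XNNPACK revision —
// substitute the kernel actually under test):
//
//   TEST(F32_SPMM_MINMAX_8X1__SCALAR, half_sparse) {
//     SpMMMicrokernelTester()
//       .mr(8)
//       .nr(1)
//       .m(8)
//       .n(1)
//       .k(16)
//       .sparsity(0.5f)
//       .Test(xnn_f32_spmm_minmax_ukernel_8x1__scalar,
//             xnn_init_f32_minmax_scalar_params);
//   }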