// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
#include <random>
#include <vector>

#include <cpuinfo.h>

#include <benchmark/benchmark.h>
#include <fp16/fp16.h>
#include "bench/spmm.h"
#include "bench/utils.h"

#include <xnnpack.h>
#include <xnnpack/aligned-allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/microfnptr.h>
#include <xnnpack/microparams-init.h>
#include <xnnpack/spmm.h>

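// Benchmarks one f16 sparse-matrix times dense-matrix (SpMM) micro-kernel.
// mr x nr is the micro-kernel tile (per XNNPACK naming); sparsity is the
// fraction of weights forced to zero; isa_check skips the benchmark when the
// required ISA extension is unavailable on the current CPU.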
static void f16_spmm(benchmark::State& state,
  xnn_f16_spmm_minmax_ukernel_function spmm, uint32_t mr, uint32_t nr, float sparsity,
  xnn_init_f16_minmax_params_fn init_params,
  benchmark::utils::IsaCheckFunction isa_check = nullptr)
{
  if (isa_check && !isa_check(state)) {
    return;
  }
  const size_t mc = state.range(0);
  const size_t nc = state.range(1);
  const size_t kc = state.range(2);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
  auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

  // If using output-channel blocks (nr > 1), generate the reduced matrix
  // first, then extrude it along the block dimension to get the full matrix.
  size_t ncols = nc / nr + nc % nr;
  std::vector<uint16_t> b(ncols * kc);
  std::vector<uint16_t> bias(nc);
  std::vector<uint16_t> w;
  std::vector<uint32_t> nmap;
  std::vector<int32_t> dmap;
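  // Compressed weight representation consumed by the micro-kernel (as built
  // by the packing loop below):
  // - w packs, per output-channel block, nr bias values followed by nr copies
  //   of each non-zero weight;
  // - nmap[i] counts the non-zero weights in output-channel block i;
  // - dmap holds delta-coded offsets (in bytes) into the input A between
  //   consecutive non-zero weights.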
  const size_t sparse_end = std::min(size_t(float(b.size()) * sparsity), b.size());
  const size_t num_nonzeroes = nr * (b.size() - sparse_end);

  const size_t w_elements = num_nonzeroes + nc;
  const size_t c_elements = mc * nc;
  const size_t dmap_elements = num_nonzeroes / nr;
  const size_t nmap_elements = nc;
  const size_t num_buffers = 1 +
    benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
      sizeof(uint16_t) * (w_elements + c_elements) + sizeof(uint32_t) * (dmap_elements + nmap_elements));
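  // One extra buffer copy beyond what fills the last-level cache, so cycling
  // through the copies evicts W, dmap, nmap, and C between iterations.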

  // Micro-kernel can access one element beyond w and dmap for software pipelining.
  w.reserve(num_buffers * w_elements + 1);
  dmap.reserve(num_buffers * dmap_elements + 1);
  nmap.resize(num_buffers * nmap_elements);

  std::vector<size_t> a_offsets(num_buffers);

  for (size_t buffer_index = 0; buffer_index < num_buffers; buffer_index++) {
    // Re-generate weights. Note: each re-generation produces the same number of non-zeroes.
    std::fill(b.begin(), b.begin() + sparse_end, 0);
    std::generate(b.begin() + sparse_end, b.end(), std::ref(f16rng));
    std::shuffle(b.begin(), b.end(), rng);
    std::generate(bias.begin(), bias.end(), std::ref(f16rng));

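    // Pack the sparse weights: for each block of nr output channels, emit the
    // nr biases, then nr copies of each non-zero weight, recording per-block
    // non-zero counts in nmap and byte increments into A in dmap.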
    uint32_t first_j = 0, last_j = 0;
    bool is_first_nonzero = true;
    for (uint32_t i = 0; i < nc / nr; i++) {
      for (uint32_t n = 0; n < nr; n++)
        w.push_back(bias[nr * i + n]);
      for (uint32_t j = 0; j < kc; j++) {
        if ((b[i * kc + j] & 0x7FFF) != 0) {
          for (size_t l = 0; l < nr; l++)
            w.push_back(fp16_ieee_from_fp32_value(fp16_ieee_to_fp32_value(b[i * kc + j]) + static_cast<float>(i)));
          if (is_first_nonzero) {
            first_j = j;
          } else {
            const ptrdiff_t increment = int32_t(j - last_j) * int32_t(mc) * int32_t(sizeof(uint16_t));
            dmap.push_back(increment);
          }
          last_j = j;
          is_first_nonzero = false;
          nmap[buffer_index * nmap_elements + i] += 1;
        }
      }
    }
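    // Handle the nc % nr leftover output channels individually (block size 1).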
    for (uint32_t i = nc / nr; i < ncols; i++) {
      w.push_back(bias[i]);
      for (uint32_t j = 0; j < kc; j++) {
        if ((b[i * kc + j] & 0x7FFF) != 0) {
          w.push_back(b[i * kc + j]);
          if (is_first_nonzero) {
            first_j = j;
          } else {
            const ptrdiff_t increment = int32_t(j - last_j) * int32_t(mc) * int32_t(sizeof(uint16_t));
            dmap.push_back(increment);
          }
          last_j = j;
          is_first_nonzero = false;
          nmap[buffer_index * nmap_elements + i] += 1;
        }
      }
    }
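    // The final dmap entry wraps the input pointer from the last non-zero row
    // back to the first, so the same A buffer can be streamed on every call.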
    {
      const ptrdiff_t increment = int32_t(first_j - last_j) * int32_t(mc) * int32_t(sizeof(uint16_t));
      dmap.push_back(increment);
    }

    a_offsets[buffer_index] = first_j * mc;
  }

  // Micro-kernel can access one element beyond w and dmap for software pipelining.
  w.resize(w.size() + 1);
  dmap.resize(dmap.size() + 1);

  std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> a(kc * mc);
  std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> c(num_buffers * c_elements);

  std::generate(a.begin(), a.end(), std::ref(f16rng));
  std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);

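  // Clamping limits are set to +/-infinity (f16 bit patterns 0x7C00/0xFC00),
  // so the minmax micro-kernel runs effectively unclamped.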
  xnn_f16_minmax_params params;
  init_params(&params, 0x7C00 /* inf */, 0xFC00 /* -inf */);

  size_t buffer_index = 0;
  for (auto _ : state) {
    // Use circular buffers (exceeding cache size) and prefetch to control cache state:
    // - A is always in L1 cache (if it fits; otherwise L2, L3, etc.)
    // - W, dmap, and nmap are not in cache (at any level)
    // - C is not in cache (at any level)
    state.PauseTiming();
    benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(uint16_t));
    buffer_index = (buffer_index + 1) % num_buffers;
    state.ResumeTiming();

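    // Note: mc and the output stride are passed in bytes (mc * sizeof(uint16_t));
    // the per-buffer offsets select a cache-cold copy of W, dmap, nmap, and C.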
    spmm(mc * sizeof(uint16_t), nc,
      a.data() + a_offsets[buffer_index],
      w.data() + buffer_index * w_elements,
      dmap.data() + buffer_index * dmap_elements,
      nmap.data() + buffer_index * nmap_elements,
      c.data() + buffer_index * c_elements, mc * sizeof(uint16_t),
      &params);
  }

  const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
  if (cpu_frequency != 0) {
    state.counters["cpufreq"] = cpu_frequency;
  }

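  // FLOPS counts the work actually performed (2 ops per non-zero weight per
  // output row); EffFLOPS counts the dense-equivalent work (2 * mc * nc * kc),
  // so comparing the two shows the benefit of skipping zero weights.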
  state.counters["FLOPS"] = benchmark::Counter(
    uint64_t(state.iterations()) * 2 * mc * num_nonzeroes, benchmark::Counter::kIsRate);

  state.counters["EffFLOPS"] = benchmark::Counter(
    uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
}


#if XNN_ENABLE_ARM_FP16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64)
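  // Wrappers for the NEONFP16ARITH micro-kernels. spmm80 denotes 80% sparsity
  // (sparsity = 0.8f), MxN the micro-kernel tile, and the _x2 suffix appears
  // to mark a 2x-unrolled variant of the same tile (per XNNPACK naming).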
  static void spmm80_8x1__neonfp16arith(benchmark::State& state, const char* net) {
    f16_spmm(state, xnn_f16_spmm_minmax_ukernel_8x1__neonfp16arith, 8, 1, 0.8f,
      xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
  }
  static void spmm80_8x1__neonfp16arith_x2(benchmark::State& state, const char* net) {
    f16_spmm(state, xnn_f16_spmm_minmax_ukernel_8x1__neonfp16arith_x2, 8, 1, 0.8f,
      xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
  }
  static void spmm80_16x1__neonfp16arith(benchmark::State& state, const char* net) {
    f16_spmm(state, xnn_f16_spmm_minmax_ukernel_16x1__neonfp16arith, 16, 1, 0.8f,
      xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
  }
  static void spmm80_16x1__neonfp16arith_x2(benchmark::State& state, const char* net) {
    f16_spmm(state, xnn_f16_spmm_minmax_ukernel_16x1__neonfp16arith_x2, 16, 1, 0.8f,
      xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
  }
  static void spmm80_24x1__neonfp16arith(benchmark::State& state, const char* net) {
    f16_spmm(state, xnn_f16_spmm_minmax_ukernel_24x1__neonfp16arith, 24, 1, 0.8f,
      xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
  }
  static void spmm80_24x1__neonfp16arith_x2(benchmark::State& state, const char* net) {
    f16_spmm(state, xnn_f16_spmm_minmax_ukernel_24x1__neonfp16arith_x2, 24, 1, 0.8f,
      xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
  }
  static void spmm80_32x1__neonfp16arith(benchmark::State& state, const char* net) {
    f16_spmm(state, xnn_f16_spmm_minmax_ukernel_32x1__neonfp16arith, 32, 1, 0.8f,
      xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
  }
  static void spmm80_32x1__neonfp16arith_x2(benchmark::State& state, const char* net) {
    f16_spmm(state, xnn_f16_spmm_minmax_ukernel_32x1__neonfp16arith_x2, 32, 1, 0.8f,
      xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
  }

  BENCHMARK_SPMM(spmm80_8x1__neonfp16arith)
  BENCHMARK_SPMM(spmm80_8x1__neonfp16arith_x2)
  BENCHMARK_SPMM(spmm80_16x1__neonfp16arith)
  BENCHMARK_SPMM(spmm80_16x1__neonfp16arith_x2)
  BENCHMARK_SPMM(spmm80_24x1__neonfp16arith)
  BENCHMARK_SPMM(spmm80_24x1__neonfp16arith_x2)
  BENCHMARK_SPMM(spmm80_32x1__neonfp16arith)
  BENCHMARK_SPMM(spmm80_32x1__neonfp16arith_x2)
#endif  // XNN_ENABLE_ARM_FP16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64)

#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif