// xref: /aosp_15_r20/external/XNNPACK/bench/bf16-gemm.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <algorithm>
7 #include <cfloat>
8 #include <cmath>
9 #include <functional>
10 #include <random>
11 #include <vector>
12 
13 #include <cpuinfo.h>
14 
15 #include <benchmark/benchmark.h>
16 #include <fp16/fp16.h>
17 #include "bench/gemm.h"
18 #include "bench/utils.h"
19 
20 #include <xnnpack.h>
21 #include <xnnpack/aligned-allocator.h>
22 #include <xnnpack/common.h>
23 #include <xnnpack/gemm.h>
24 #include <xnnpack/math.h>
25 #include <xnnpack/pack.h>
26 #include <xnnpack/microfnptr.h>
27 #include <xnnpack/microparams-init.h>
28 
29 
bf16_gemm(benchmark::State & state,xnn_bf16_gemm_minmax_ukernel_function gemm,size_t mr,size_t nr,size_t kr,size_t sr,xnn_init_bf16_minmax_params_fn init_params,benchmark::utils::IsaCheckFunction isa_check=nullptr)30 static void bf16_gemm(benchmark::State& state,
31   xnn_bf16_gemm_minmax_ukernel_function gemm,
32   size_t mr, size_t nr, size_t kr, size_t sr,
33   xnn_init_bf16_minmax_params_fn init_params,
34   benchmark::utils::IsaCheckFunction isa_check = nullptr)
35 {
36   if (isa_check && !isa_check(state)) {
37     return;
38   }
39 
40   const size_t mc = state.range(0);
41   const size_t nc = state.range(1);
42   const size_t kc = state.range(2);
43 
44   const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
45   const size_t kc_stride = benchmark::utils::RoundUp(kc, kr * sr);
46 
47   std::random_device random_device;
48   auto rng = std::mt19937(random_device());
49   auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
50 
51   std::vector<uint16_t> a(mc * kc + XNN_EXTRA_BYTES / sizeof(uint16_t));
52   std::generate(a.begin(), a.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
53   std::vector<uint16_t> k(nc * kc);
54   std::generate(k.begin(), k.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
55   std::vector<uint16_t> b(nc);
56   std::generate(b.begin(), b.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
57 
58   const size_t w_elements = nc_stride * kc_stride + nc_stride;
59   const size_t c_elements = mc * nc;
60   const size_t num_buffers = 1 +
61     benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
62       sizeof(uint16_t) * (w_elements + c_elements));
63 
64   std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> w(w_elements * num_buffers);
65   std::fill(w.begin(), w.end(), 0);
66   xnn_pack_f16_gemm_goi_w(1 /* groups */, nc, kc, nr, kr, sr, k.data(), b.data(), w.data(), 0, nullptr);
67   std::vector<uint16_t> c(c_elements * num_buffers);
68   std::fill(c.begin(), c.end(), UINT16_C(0x7FC0) /* NaN */);
69 
70   // Prepare minmax parameters.
71   xnn_bf16_minmax_params params;
72   init_params(&params,
73     UINT16_C(0xFF80)  /* -inf */, UINT16_C(0x7F80)  /* inf */);
74 
75   size_t buffer_index = 0;
76   for (auto _ : state) {
77     // Use circular buffers (exceeding cache size) and prefetch to control cache state:
78     // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
79     // - W is not in cache (for any cache level)
80     // - C is not in cache (for any cache level)
81     state.PauseTiming();
82     benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(uint16_t));
83     buffer_index = (buffer_index + 1) % num_buffers;
84     state.ResumeTiming();
85 
86     for (uint32_t m = 0; m < mc; m += mr) {
87       const uint32_t mb = min(mc - m, mr);
88       for (uint32_t n = 0; n < nc; n += nr) {
89         const uint32_t nb = min(nc - n, nr);
90         gemm(
91           mb, nb, kc * sizeof(uint16_t),
92           a.data() + m * kc, kc * sizeof(uint16_t),
93           w.data() + (nc_stride * buffer_index + n) * (kc_stride + 1),
94           c.data() + (mc * buffer_index + m) * nc + n, nc * sizeof(uint16_t), nr * sizeof(uint16_t),
95           &params);
96       }
97     }
98   }
99 
100   const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
101   if (cpu_frequency != 0) {
102     state.counters["cpufreq"] = cpu_frequency;
103   }
104 
105   state.counters["FLOPS"] = benchmark::Counter(
106     uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
107 }
108 
109 
110 #if XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64)
bf16_gemm_1x8c2__neonbf16_bfdot_lane_ld128(benchmark::State & state,const char * net)111   static void bf16_gemm_1x8c2__neonbf16_bfdot_lane_ld128(benchmark::State& state, const char* net) {
112     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_1x8c2__neonbf16_bfdot_lane_ld128, 1, 8, 2, 1,
113       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
114   }
bf16_gemm_4x8c2__neonbf16_bfdot_lane_ld128(benchmark::State & state,const char * net)115   static void bf16_gemm_4x8c2__neonbf16_bfdot_lane_ld128(benchmark::State& state, const char* net) {
116     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_4x8c2__neonbf16_bfdot_lane_ld128, 4, 8, 2, 1,
117       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
118   }
bf16_gemm_5x8c2__neonbf16_bfdot_lane_ld128(benchmark::State & state,const char * net)119   static void bf16_gemm_5x8c2__neonbf16_bfdot_lane_ld128(benchmark::State& state, const char* net) {
120     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_5x8c2__neonbf16_bfdot_lane_ld128, 5, 8, 2, 1,
121       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
122   }
bf16_gemm_6x8c2__neonbf16_bfdot_lane_ld128(benchmark::State & state,const char * net)123   static void bf16_gemm_6x8c2__neonbf16_bfdot_lane_ld128(benchmark::State& state, const char* net) {
124     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_6x8c2__neonbf16_bfdot_lane_ld128, 6, 8, 2, 1,
125       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
126   }
127 
bf16_gemm_1x4c8__neonbf16_bfdot(benchmark::State & state,const char * net)128   static void bf16_gemm_1x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) {
129     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfdot, 1, 4, 8, 1,
130       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
131   }
bf16_gemm_2x4c8__neonbf16_bfdot(benchmark::State & state,const char * net)132   static void bf16_gemm_2x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) {
133     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfdot, 2, 4, 8, 1,
134       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
135   }
bf16_gemm_3x4c8__neonbf16_bfdot(benchmark::State & state,const char * net)136   static void bf16_gemm_3x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) {
137     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfdot, 3, 4, 8, 1,
138       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
139   }
bf16_gemm_4x4c8__neonbf16_bfdot(benchmark::State & state,const char * net)140   static void bf16_gemm_4x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) {
141     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfdot, 4, 4, 8, 1,
142       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
143   }
bf16_gemm_5x4c8__neonbf16_bfdot(benchmark::State & state,const char * net)144   static void bf16_gemm_5x4c8__neonbf16_bfdot(benchmark::State& state, const char* net) {
145     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfdot, 5, 4, 8, 1,
146       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
147   }
148 
bf16_gemm_1x4c8__neonbf16_bfmlal(benchmark::State & state,const char * net)149   static void bf16_gemm_1x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) {
150     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_1x4c8__neonbf16_bfmlal, 1, 4, 8, 1,
151       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
152   }
bf16_gemm_2x4c8__neonbf16_bfmlal(benchmark::State & state,const char * net)153   static void bf16_gemm_2x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) {
154     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_2x4c8__neonbf16_bfmlal, 2, 4, 8, 1,
155       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
156   }
bf16_gemm_3x4c8__neonbf16_bfmlal(benchmark::State & state,const char * net)157   static void bf16_gemm_3x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) {
158     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_3x4c8__neonbf16_bfmlal, 3, 4, 8, 1,
159       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
160   }
bf16_gemm_4x4c8__neonbf16_bfmlal(benchmark::State & state,const char * net)161   static void bf16_gemm_4x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) {
162     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_4x4c8__neonbf16_bfmlal, 4, 4, 8, 1,
163       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
164   }
bf16_gemm_5x4c8__neonbf16_bfmlal(benchmark::State & state,const char * net)165   static void bf16_gemm_5x4c8__neonbf16_bfmlal(benchmark::State& state, const char* net) {
166     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_5x4c8__neonbf16_bfmlal, 5, 4, 8, 1,
167       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONBF16);
168   }
169 
170   BENCHMARK_GEMM(bf16_gemm_1x8c2__neonbf16_bfdot_lane_ld128)
BENCHMARK_GEMM(bf16_gemm_4x8c2__neonbf16_bfdot_lane_ld128)171   BENCHMARK_GEMM(bf16_gemm_4x8c2__neonbf16_bfdot_lane_ld128)
172   BENCHMARK_GEMM(bf16_gemm_5x8c2__neonbf16_bfdot_lane_ld128)
173   BENCHMARK_GEMM(bf16_gemm_6x8c2__neonbf16_bfdot_lane_ld128)
174 
175   BENCHMARK_GEMM(bf16_gemm_1x4c8__neonbf16_bfdot)
176   BENCHMARK_GEMM(bf16_gemm_2x4c8__neonbf16_bfdot)
177   BENCHMARK_GEMM(bf16_gemm_3x4c8__neonbf16_bfdot)
178   BENCHMARK_GEMM(bf16_gemm_4x4c8__neonbf16_bfdot)
179   BENCHMARK_GEMM(bf16_gemm_5x4c8__neonbf16_bfdot)
180 
181   BENCHMARK_GEMM(bf16_gemm_1x4c8__neonbf16_bfmlal)
182   BENCHMARK_GEMM(bf16_gemm_2x4c8__neonbf16_bfmlal)
183   BENCHMARK_GEMM(bf16_gemm_3x4c8__neonbf16_bfmlal)
184   BENCHMARK_GEMM(bf16_gemm_4x4c8__neonbf16_bfmlal)
185   BENCHMARK_GEMM(bf16_gemm_5x4c8__neonbf16_bfmlal)
#endif  // XNN_ENABLE_ARM_BF16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64)
187 
188 #if XNN_ARCH_ARM64
189   static void bf16_gemm_1x4c8__neonfma_zip(benchmark::State& state, const char* net) {
190     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_zip, 1, 4, 8, 1,
191       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
192   }
bf16_gemm_2x4c8__neonfma_zip(benchmark::State & state,const char * net)193   static void bf16_gemm_2x4c8__neonfma_zip(benchmark::State& state, const char* net) {
194     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_zip, 2, 4, 8, 1,
195       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
196   }
bf16_gemm_3x4c8__neonfma_zip(benchmark::State & state,const char * net)197   static void bf16_gemm_3x4c8__neonfma_zip(benchmark::State& state, const char* net) {
198     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_zip, 3, 4, 8, 1,
199       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
200   }
bf16_gemm_4x4c8__neonfma_zip(benchmark::State & state,const char * net)201   static void bf16_gemm_4x4c8__neonfma_zip(benchmark::State& state, const char* net) {
202     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_zip, 4, 4, 8, 1,
203       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
204   }
bf16_gemm_5x4c8__neonfma_zip(benchmark::State & state,const char * net)205   static void bf16_gemm_5x4c8__neonfma_zip(benchmark::State& state, const char* net) {
206     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_zip, 5, 4, 8, 1,
207       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
208   }
209 
210   BENCHMARK_GEMM(bf16_gemm_1x4c8__neonfma_zip)
BENCHMARK_GEMM(bf16_gemm_2x4c8__neonfma_zip)211   BENCHMARK_GEMM(bf16_gemm_2x4c8__neonfma_zip)
212   BENCHMARK_GEMM(bf16_gemm_3x4c8__neonfma_zip)
213   BENCHMARK_GEMM(bf16_gemm_4x4c8__neonfma_zip)
214   BENCHMARK_GEMM(bf16_gemm_5x4c8__neonfma_zip)
215 #endif  // XNN_ARCH_ARM64
216 
217 #if XNN_ARCH_ARM || XNN_ARCH_ARM64
218   static void bf16_gemm_1x4c8__neonfma_shland(benchmark::State& state, const char* net) {
219     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_1x4c8__neonfma_shland, 1, 4, 8, 1,
220       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
221   }
bf16_gemm_2x4c8__neonfma_shland(benchmark::State & state,const char * net)222   static void bf16_gemm_2x4c8__neonfma_shland(benchmark::State& state, const char* net) {
223     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_2x4c8__neonfma_shland, 2, 4, 8, 1,
224       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
225   }
bf16_gemm_3x4c8__neonfma_shland(benchmark::State & state,const char * net)226   static void bf16_gemm_3x4c8__neonfma_shland(benchmark::State& state, const char* net) {
227     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_3x4c8__neonfma_shland, 3, 4, 8, 1,
228       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
229   }
bf16_gemm_4x4c8__neonfma_shland(benchmark::State & state,const char * net)230   static void bf16_gemm_4x4c8__neonfma_shland(benchmark::State& state, const char* net) {
231     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_4x4c8__neonfma_shland, 4, 4, 8, 1,
232       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
233   }
bf16_gemm_5x4c8__neonfma_shland(benchmark::State & state,const char * net)234   static void bf16_gemm_5x4c8__neonfma_shland(benchmark::State& state, const char* net) {
235     bf16_gemm(state, xnn_bf16_gemm_minmax_ukernel_5x4c8__neonfma_shland, 5, 4, 8, 1,
236       xnn_init_bf16_minmax_scalar_params, benchmark::utils::CheckNEONFMA);
237   }
238 
239   BENCHMARK_GEMM(bf16_gemm_1x4c8__neonfma_shland)
240   BENCHMARK_GEMM(bf16_gemm_2x4c8__neonfma_shland)
241   BENCHMARK_GEMM(bf16_gemm_3x4c8__neonfma_shland)
242   BENCHMARK_GEMM(bf16_gemm_4x4c8__neonfma_shland)
243   BENCHMARK_GEMM(bf16_gemm_5x4c8__neonfma_shland)
244 #endif  // XNN_ARCH_ARM || XNN_ARCH_ARM64
245 
246 #ifndef XNNPACK_BENCHMARK_NO_MAIN
247 BENCHMARK_MAIN();
248 #endif
249