1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #include <algorithm>
10 #include <cfloat>
11 #include <cmath>
12 #include <functional>
13 #include <random>
14 #include <vector>
15
16 #include <cpuinfo.h>
17
18 #include <benchmark/benchmark.h>
19 #include <fp16/fp16.h>
20 #include "bench/gemm.h"
21 #include "bench/utils.h"
22
23 #include <xnnpack.h>
24 #include <xnnpack/aligned-allocator.h>
25 #include <xnnpack/common.h>
26 #include <xnnpack/gemm.h>
27 #include <xnnpack/math.h>
28 #include <xnnpack/pack.h>
29 #include <xnnpack/microfnptr.h>
30 #include <xnnpack/microparams-init.h>
31
32
f16_gemm(benchmark::State & state,xnn_f16_gemm_minmax_ukernel_function gemm,size_t mr,size_t nr,size_t kr,size_t sr,xnn_init_f16_minmax_params_fn init_params,benchmark::utils::IsaCheckFunction isa_check=nullptr)33 static void f16_gemm(benchmark::State& state,
34 xnn_f16_gemm_minmax_ukernel_function gemm,
35 size_t mr, size_t nr, size_t kr, size_t sr,
36 xnn_init_f16_minmax_params_fn init_params,
37 benchmark::utils::IsaCheckFunction isa_check = nullptr)
38 {
39 if (isa_check && !isa_check(state)) {
40 return;
41 }
42
43 const size_t mc = state.range(0);
44 const size_t nc = state.range(1);
45 const size_t kc = state.range(2);
46
47 const size_t nc_stride = benchmark::utils::RoundUp(nc, nr);
48 const size_t kc_stride = benchmark::utils::RoundUp(kc, kr * sr);
49
50 std::random_device random_device;
51 auto rng = std::mt19937(random_device());
52 auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
53 auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
54
55 std::vector<uint16_t> a(mc * kc + XNN_EXTRA_BYTES / sizeof(uint16_t));
56 std::generate(a.begin(), a.end(), std::ref(f16rng));
57 std::vector<uint16_t> k(nc * kc);
58 std::generate(k.begin(), k.end(), std::ref(f16rng));
59 std::vector<uint16_t> b(nc);
60 std::generate(b.begin(), b.end(), std::ref(f16rng));
61
62 const size_t w_elements = nc_stride * kc_stride + nc_stride;
63 const size_t c_elements = mc * nc;
64 const size_t num_buffers = 1 +
65 benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(),
66 sizeof(uint16_t) * (w_elements + c_elements));
67
68 std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> w(w_elements * num_buffers);
69 std::fill(w.begin(), w.end(), 0);
70 xnn_pack_f16_gemm_goi_w(1 /* groups */, nc, kc, nr, kr, sr, k.data(), b.data(), w.data(), 0, nullptr);
71 std::vector<uint16_t> c(c_elements * num_buffers);
72 std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
73
74 // Prepare minmax parameters.
75 xnn_f16_minmax_params params;
76 init_params(¶ms,
77 UINT16_C(0xFC00) /* -inf */, UINT16_C(0x7C00) /* inf */);
78
79 size_t buffer_index = 0;
80 for (auto _ : state) {
81 // Use circular buffers (exceeding cache size) and prefetch to control cache state:
82 // - A is always in L1 cache (if fits, otherwise L2, L3, etc)
83 // - W is not in cache (for any cache level)
84 // - C is not in cache (for any cache level)
85 state.PauseTiming();
86 benchmark::utils::PrefetchToL1(a.data(), a.size() * sizeof(uint16_t));
87 buffer_index = (buffer_index + 1) % num_buffers;
88 state.ResumeTiming();
89
90 for (uint32_t m = 0; m < mc; m += mr) {
91 const uint32_t mb = min(mc - m, mr);
92 for (uint32_t n = 0; n < nc; n += nr) {
93 const uint32_t nb = min(nc - n, nr);
94 gemm(
95 mb, nb, kc * sizeof(uint16_t),
96 a.data() + m * kc, kc * sizeof(uint16_t),
97 w.data() + (nc_stride * buffer_index + n) * (kc_stride + 1),
98 c.data() + (mc * buffer_index + m) * nc + n, nc * sizeof(uint16_t), nr * sizeof(uint16_t),
99 ¶ms);
100 }
101 }
102 }
103
104 const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
105 if (cpu_frequency != 0) {
106 state.counters["cpufreq"] = cpu_frequency;
107 }
108
109 state.counters["FLOPS"] = benchmark::Counter(
110 uint64_t(state.iterations()) * 2 * mc * nc * kc, benchmark::Counter::kIsRate);
111 }
112
113
#if XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
// Shared driver for the AArch64 assembly F16 GEMM kernels below: all of them
// use kr = sr = 1, NEON min/max params, and are gated by
// benchmark::utils::CheckNEONFP16ARITH.
static void f16_gemm_aarch64_asm(benchmark::State& state,
  xnn_f16_gemm_minmax_ukernel_function ukernel, size_t mr, size_t nr)
{
  f16_gemm(state, ukernel, mr, nr, /*kr=*/1, /*sr=*/1,
    xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
}

// Per-kernel entry points registered with BENCHMARK_GEMM; `net` is not used.
static void f16_gemm_1x16__aarch64_neonfp16arith_ld32(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld32, 1, 16);
}
static void f16_gemm_1x16__aarch64_neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld64, 1, 16);
}
static void f16_gemm_4x16__aarch64_neonfp16arith_ld32(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld32, 4, 16);
}
static void f16_gemm_4x16__aarch64_neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64, 4, 16);
}
static void f16_gemm_6x16__aarch64_neonfp16arith_cortex_a55(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a55, 6, 16);
}
static void f16_gemm_6x16__aarch64_neonfp16arith_cortex_a55r0(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a55r0, 6, 16);
}
static void f16_gemm_6x16__aarch64_neonfp16arith_cortex_a75(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_cortex_a75, 6, 16);
}
static void f16_gemm_6x16__aarch64_neonfp16arith_ld32(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_ld32, 6, 16);
}
static void f16_gemm_6x16__aarch64_neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_6x16__aarch64_neonfp16arith_ld64, 6, 16);
}
static void f16_gemm_1x8__aarch64_neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_1x8__aarch64_neonfp16arith_ld64, 1, 8);
}
static void f16_gemm_4x8__aarch64_neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_4x8__aarch64_neonfp16arith_ld64, 4, 8);
}
static void f16_gemm_6x8__aarch64_neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_6x8__aarch64_neonfp16arith_ld64, 6, 8);
}
static void f16_gemm_8x8__aarch64_neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_aarch64_asm(state, xnn_f16_gemm_minmax_ukernel_8x8__aarch64_neonfp16arith_ld64, 8, 8);
}

BENCHMARK_GEMM(f16_gemm_1x16__aarch64_neonfp16arith_ld32)
BENCHMARK_GEMM(f16_gemm_1x16__aarch64_neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_4x16__aarch64_neonfp16arith_ld32)
BENCHMARK_GEMM(f16_gemm_4x16__aarch64_neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_6x16__aarch64_neonfp16arith_cortex_a55)
BENCHMARK_GEMM(f16_gemm_6x16__aarch64_neonfp16arith_cortex_a55r0)
BENCHMARK_GEMM(f16_gemm_6x16__aarch64_neonfp16arith_cortex_a75)
BENCHMARK_GEMM(f16_gemm_6x16__aarch64_neonfp16arith_ld32)
BENCHMARK_GEMM(f16_gemm_6x16__aarch64_neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_1x8__aarch64_neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_4x8__aarch64_neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_6x8__aarch64_neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_8x8__aarch64_neonfp16arith_ld64)
#endif  // XNN_ARCH_ARM64 && XNN_ENABLE_ASSEMBLY
182
#if XNN_ENABLE_ARM_FP16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64)
// Shared driver for the NEON FP16-arithmetic intrinsic kernels below: all of
// them use kr = sr = 1, NEON min/max params, and are gated by
// benchmark::utils::CheckNEONFP16ARITH.
static void f16_gemm_neonfp16arith(benchmark::State& state,
  xnn_f16_gemm_minmax_ukernel_function ukernel, size_t mr, size_t nr)
{
  f16_gemm(state, ukernel, mr, nr, /*kr=*/1, /*sr=*/1,
    xnn_init_f16_minmax_neon_params, benchmark::utils::CheckNEONFP16ARITH);
}

// Per-kernel entry points registered with BENCHMARK_GEMM; `net` is not used.
static void f16_gemm_1x8__neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_neonfp16arith(state, xnn_f16_gemm_minmax_ukernel_1x8__neonfp16arith_ld64, 1, 8);
}
static void f16_gemm_4x8__neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_neonfp16arith(state, xnn_f16_gemm_minmax_ukernel_4x8__neonfp16arith_ld64, 4, 8);
}
static void f16_gemm_6x8__neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_neonfp16arith(state, xnn_f16_gemm_minmax_ukernel_6x8__neonfp16arith_ld64, 6, 8);
}
static void f16_gemm_8x8__neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_neonfp16arith(state, xnn_f16_gemm_minmax_ukernel_8x8__neonfp16arith_ld64, 8, 8);
}
static void f16_gemm_1x16__neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_neonfp16arith(state, xnn_f16_gemm_minmax_ukernel_1x16__neonfp16arith_ld64, 1, 16);
}
static void f16_gemm_4x16__neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_neonfp16arith(state, xnn_f16_gemm_minmax_ukernel_4x16__neonfp16arith_ld64, 4, 16);
}
static void f16_gemm_6x16__neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_neonfp16arith(state, xnn_f16_gemm_minmax_ukernel_6x16__neonfp16arith_ld64, 6, 16);
}
static void f16_gemm_8x16__neonfp16arith_ld64(benchmark::State& state, const char* net) {
  f16_gemm_neonfp16arith(state, xnn_f16_gemm_minmax_ukernel_8x16__neonfp16arith_ld64, 8, 16);
}

BENCHMARK_GEMM(f16_gemm_1x8__neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_4x8__neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_6x8__neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_8x8__neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_1x16__neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_4x16__neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_6x16__neonfp16arith_ld64)
BENCHMARK_GEMM(f16_gemm_8x16__neonfp16arith_ld64)
#endif  // XNN_ENABLE_ARM_FP16 && (XNN_ARCH_ARM || XNN_ARCH_ARM64)
226
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
// Shared driver for the AVX2 "broadcast" F16 GEMM kernels below: all of them
// use kr = sr = 1, AVX min/max params, and are gated by
// benchmark::utils::CheckAVX2.
static void f16_gemm_avx2(benchmark::State& state,
  xnn_f16_gemm_minmax_ukernel_function ukernel, size_t mr, size_t nr)
{
  f16_gemm(state, ukernel, mr, nr, /*kr=*/1, /*sr=*/1,
    xnn_init_f16_minmax_avx_params, benchmark::utils::CheckAVX2);
}

// Per-kernel entry points registered with BENCHMARK_GEMM; `net` is not used.
static void f16_gemm_1x8__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_1x8__avx2_broadcast, 1, 8);
}
static void f16_gemm_4x8__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_4x8__avx2_broadcast, 4, 8);
}
static void f16_gemm_5x8__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_5x8__avx2_broadcast, 5, 8);
}
static void f16_gemm_6x8__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_6x8__avx2_broadcast, 6, 8);
}
static void f16_gemm_7x8__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_7x8__avx2_broadcast, 7, 8);
}
static void f16_gemm_1x16__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_1x16__avx2_broadcast, 1, 16);
}
static void f16_gemm_3x16__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_3x16__avx2_broadcast, 3, 16);
}
static void f16_gemm_4x16__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_4x16__avx2_broadcast, 4, 16);
}
static void f16_gemm_5x16__avx2_broadcast(benchmark::State& state, const char* net) {
  f16_gemm_avx2(state, xnn_f16_gemm_minmax_ukernel_5x16__avx2_broadcast, 5, 16);
}

BENCHMARK_GEMM(f16_gemm_1x8__avx2_broadcast)
BENCHMARK_GEMM(f16_gemm_4x8__avx2_broadcast)
BENCHMARK_GEMM(f16_gemm_5x8__avx2_broadcast)
BENCHMARK_GEMM(f16_gemm_6x8__avx2_broadcast)
BENCHMARK_GEMM(f16_gemm_7x8__avx2_broadcast)
BENCHMARK_GEMM(f16_gemm_1x16__avx2_broadcast)
BENCHMARK_GEMM(f16_gemm_3x16__avx2_broadcast)
BENCHMARK_GEMM(f16_gemm_4x16__avx2_broadcast)
BENCHMARK_GEMM(f16_gemm_5x16__avx2_broadcast)
#endif  // XNN_ARCH_X86 || XNN_ARCH_X86_64
275
276 #ifndef XNNPACK_BENCHMARK_NO_MAIN
277 BENCHMARK_MAIN();
278 #endif
279