1*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
2*4bdc9457SAndroid Build Coastguard Worker #include <cfloat>
3*4bdc9457SAndroid Build Coastguard Worker #include <chrono>
4*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
5*4bdc9457SAndroid Build Coastguard Worker #include <functional>
6*4bdc9457SAndroid Build Coastguard Worker #include <random>
7*4bdc9457SAndroid Build Coastguard Worker #include <vector>
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #include <benchmark/benchmark.h>
10*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_INTEL_DNNL
11*4bdc9457SAndroid Build Coastguard Worker #include <dnnl.h>
12*4bdc9457SAndroid Build Coastguard Worker #endif // BENCHMARK_INTEL_DNNL
13*4bdc9457SAndroid Build Coastguard Worker #include "bench/utils.h"
14*4bdc9457SAndroid Build Coastguard Worker
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/common.h>
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/raddexpminusmax.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/raddextexp.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/raddstoreexpminusmax.h>
22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/rmax.h>
23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/vbinary.h>
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/vscaleexpminusmax.h>
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/vscaleextexp.h>
26*4bdc9457SAndroid Build Coastguard Worker
27*4bdc9457SAndroid Build Coastguard Worker
28*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_INTEL_DNNL
DNNLSoftArgMax(benchmark::State & state)29*4bdc9457SAndroid Build Coastguard Worker static void DNNLSoftArgMax(
30*4bdc9457SAndroid Build Coastguard Worker benchmark::State& state)
31*4bdc9457SAndroid Build Coastguard Worker {
32*4bdc9457SAndroid Build Coastguard Worker const size_t elements = state.range(0);
33*4bdc9457SAndroid Build Coastguard Worker const size_t cache_line_size_max = 128;
34*4bdc9457SAndroid Build Coastguard Worker const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
35*4bdc9457SAndroid Build Coastguard Worker
36*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
37*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
38*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
39*4bdc9457SAndroid Build Coastguard Worker
40*4bdc9457SAndroid Build Coastguard Worker const size_t num_buffers = 1 +
41*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
42*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(elements);
43*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(packed_elements * num_buffers);
44*4bdc9457SAndroid Build Coastguard Worker
45*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(f32rng));
46*4bdc9457SAndroid Build Coastguard Worker
47*4bdc9457SAndroid Build Coastguard Worker dnnl_engine_t engine;
48*4bdc9457SAndroid Build Coastguard Worker if (dnnl_engine_create(&engine, dnnl_cpu, 0) != dnnl_success) {
49*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create CPU engine");
50*4bdc9457SAndroid Build Coastguard Worker return;
51*4bdc9457SAndroid Build Coastguard Worker }
52*4bdc9457SAndroid Build Coastguard Worker
53*4bdc9457SAndroid Build Coastguard Worker dnnl_dim_t input_output_shape[1] = { static_cast<int>(elements) };
54*4bdc9457SAndroid Build Coastguard Worker
55*4bdc9457SAndroid Build Coastguard Worker dnnl_memory_desc_t memory_descriptor = { 0 };
56*4bdc9457SAndroid Build Coastguard Worker if (dnnl_memory_desc_init_by_tag(
57*4bdc9457SAndroid Build Coastguard Worker &memory_descriptor, 1, input_output_shape, dnnl_f32, dnnl_x) != dnnl_success)
58*4bdc9457SAndroid Build Coastguard Worker {
59*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create input memory descriptor");
60*4bdc9457SAndroid Build Coastguard Worker return;
61*4bdc9457SAndroid Build Coastguard Worker }
62*4bdc9457SAndroid Build Coastguard Worker
63*4bdc9457SAndroid Build Coastguard Worker dnnl_memory_t input_memory = nullptr;
64*4bdc9457SAndroid Build Coastguard Worker if (dnnl_memory_create(
65*4bdc9457SAndroid Build Coastguard Worker &input_memory, &memory_descriptor, engine, x.data()) != dnnl_success)
66*4bdc9457SAndroid Build Coastguard Worker {
67*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create input memory");
68*4bdc9457SAndroid Build Coastguard Worker return;
69*4bdc9457SAndroid Build Coastguard Worker }
70*4bdc9457SAndroid Build Coastguard Worker
71*4bdc9457SAndroid Build Coastguard Worker dnnl_memory_t output_memory = nullptr;
72*4bdc9457SAndroid Build Coastguard Worker if (dnnl_memory_create(
73*4bdc9457SAndroid Build Coastguard Worker &output_memory, &memory_descriptor, engine, y.data()) != dnnl_success)
74*4bdc9457SAndroid Build Coastguard Worker {
75*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create output memory");
76*4bdc9457SAndroid Build Coastguard Worker return;
77*4bdc9457SAndroid Build Coastguard Worker }
78*4bdc9457SAndroid Build Coastguard Worker
79*4bdc9457SAndroid Build Coastguard Worker dnnl_softmax_desc_t softmax_forward_descriptor = {};
80*4bdc9457SAndroid Build Coastguard Worker if (dnnl_softmax_forward_desc_init(
81*4bdc9457SAndroid Build Coastguard Worker &softmax_forward_descriptor, dnnl_forward_inference,
82*4bdc9457SAndroid Build Coastguard Worker &memory_descriptor, 0) != dnnl_success)
83*4bdc9457SAndroid Build Coastguard Worker {
84*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create SoftMax forward descriptor");
85*4bdc9457SAndroid Build Coastguard Worker return;
86*4bdc9457SAndroid Build Coastguard Worker }
87*4bdc9457SAndroid Build Coastguard Worker
88*4bdc9457SAndroid Build Coastguard Worker dnnl_primitive_desc_t softmax_primitive_descriptor = nullptr;
89*4bdc9457SAndroid Build Coastguard Worker if (dnnl_primitive_desc_create(
90*4bdc9457SAndroid Build Coastguard Worker &softmax_primitive_descriptor, &softmax_forward_descriptor,
91*4bdc9457SAndroid Build Coastguard Worker nullptr /* primitive attributes */, engine, nullptr /* hint */) != dnnl_success)
92*4bdc9457SAndroid Build Coastguard Worker {
93*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create SoftMax primitive descriptor");
94*4bdc9457SAndroid Build Coastguard Worker return;
95*4bdc9457SAndroid Build Coastguard Worker }
96*4bdc9457SAndroid Build Coastguard Worker
97*4bdc9457SAndroid Build Coastguard Worker dnnl_primitive_t softmax_primitive = nullptr;
98*4bdc9457SAndroid Build Coastguard Worker if (dnnl_primitive_create(
99*4bdc9457SAndroid Build Coastguard Worker &softmax_primitive, softmax_primitive_descriptor) != dnnl_success)
100*4bdc9457SAndroid Build Coastguard Worker {
101*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create SoftMax primitive");
102*4bdc9457SAndroid Build Coastguard Worker return;
103*4bdc9457SAndroid Build Coastguard Worker }
104*4bdc9457SAndroid Build Coastguard Worker
105*4bdc9457SAndroid Build Coastguard Worker dnnl_exec_arg_t softmax_args[2] = {
106*4bdc9457SAndroid Build Coastguard Worker {DNNL_ARG_SRC, input_memory},
107*4bdc9457SAndroid Build Coastguard Worker {DNNL_ARG_DST, output_memory},
108*4bdc9457SAndroid Build Coastguard Worker };
109*4bdc9457SAndroid Build Coastguard Worker
110*4bdc9457SAndroid Build Coastguard Worker dnnl_stream_t stream = nullptr;
111*4bdc9457SAndroid Build Coastguard Worker if (dnnl_stream_create(&stream, engine, dnnl_stream_default_flags) != dnnl_success) {
112*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create stream");
113*4bdc9457SAndroid Build Coastguard Worker return;
114*4bdc9457SAndroid Build Coastguard Worker }
115*4bdc9457SAndroid Build Coastguard Worker
116*4bdc9457SAndroid Build Coastguard Worker size_t buffer_index = 0;
117*4bdc9457SAndroid Build Coastguard Worker for (auto _ : state) {
118*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
119*4bdc9457SAndroid Build Coastguard Worker if (++buffer_index == num_buffers) {
120*4bdc9457SAndroid Build Coastguard Worker buffer_index = 0;
121*4bdc9457SAndroid Build Coastguard Worker }
122*4bdc9457SAndroid Build Coastguard Worker
123*4bdc9457SAndroid Build Coastguard Worker const auto start = std::chrono::high_resolution_clock::now();
124*4bdc9457SAndroid Build Coastguard Worker if (dnnl_primitive_execute(
125*4bdc9457SAndroid Build Coastguard Worker softmax_primitive, stream, 2, softmax_args) != dnnl_success)
126*4bdc9457SAndroid Build Coastguard Worker {
127*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to execute SoftMax");
128*4bdc9457SAndroid Build Coastguard Worker return;
129*4bdc9457SAndroid Build Coastguard Worker }
130*4bdc9457SAndroid Build Coastguard Worker const auto end = std::chrono::high_resolution_clock::now();
131*4bdc9457SAndroid Build Coastguard Worker
132*4bdc9457SAndroid Build Coastguard Worker const auto elapsed_seconds =
133*4bdc9457SAndroid Build Coastguard Worker std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
134*4bdc9457SAndroid Build Coastguard Worker state.SetIterationTime(elapsed_seconds.count());
135*4bdc9457SAndroid Build Coastguard Worker }
136*4bdc9457SAndroid Build Coastguard Worker
137*4bdc9457SAndroid Build Coastguard Worker if (dnnl_stream_destroy(stream) != dnnl_success) {
138*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to destroy stream");
139*4bdc9457SAndroid Build Coastguard Worker return;
140*4bdc9457SAndroid Build Coastguard Worker }
141*4bdc9457SAndroid Build Coastguard Worker
142*4bdc9457SAndroid Build Coastguard Worker if (dnnl_primitive_desc_destroy(softmax_primitive_descriptor) != dnnl_success) {
143*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to destroy SoftMax primitive descriptor");
144*4bdc9457SAndroid Build Coastguard Worker return;
145*4bdc9457SAndroid Build Coastguard Worker }
146*4bdc9457SAndroid Build Coastguard Worker
147*4bdc9457SAndroid Build Coastguard Worker if (dnnl_primitive_destroy(softmax_primitive) != dnnl_success) {
148*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to destroy SoftMax primitive");
149*4bdc9457SAndroid Build Coastguard Worker return;
150*4bdc9457SAndroid Build Coastguard Worker }
151*4bdc9457SAndroid Build Coastguard Worker
152*4bdc9457SAndroid Build Coastguard Worker if (dnnl_memory_destroy(input_memory) != dnnl_success) {
153*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to destroy input memory");
154*4bdc9457SAndroid Build Coastguard Worker return;
155*4bdc9457SAndroid Build Coastguard Worker }
156*4bdc9457SAndroid Build Coastguard Worker
157*4bdc9457SAndroid Build Coastguard Worker if (dnnl_memory_destroy(output_memory) != dnnl_success) {
158*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to destroy output memory");
159*4bdc9457SAndroid Build Coastguard Worker return;
160*4bdc9457SAndroid Build Coastguard Worker }
161*4bdc9457SAndroid Build Coastguard Worker
162*4bdc9457SAndroid Build Coastguard Worker if (dnnl_engine_destroy(engine) != dnnl_success) {
163*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to destroy engine");
164*4bdc9457SAndroid Build Coastguard Worker return;
165*4bdc9457SAndroid Build Coastguard Worker }
166*4bdc9457SAndroid Build Coastguard Worker
167*4bdc9457SAndroid Build Coastguard Worker const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
168*4bdc9457SAndroid Build Coastguard Worker if (cpu_frequency != 0) {
169*4bdc9457SAndroid Build Coastguard Worker state.counters["cpufreq"] = cpu_frequency;
170*4bdc9457SAndroid Build Coastguard Worker }
171*4bdc9457SAndroid Build Coastguard Worker
172*4bdc9457SAndroid Build Coastguard Worker const size_t elements_per_iteration = elements;
173*4bdc9457SAndroid Build Coastguard Worker state.counters["elements"] =
174*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
175*4bdc9457SAndroid Build Coastguard Worker
176*4bdc9457SAndroid Build Coastguard Worker const size_t bytes_per_iteration = 2 * elements * sizeof(float);
177*4bdc9457SAndroid Build Coastguard Worker state.counters["bytes"] =
178*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
179*4bdc9457SAndroid Build Coastguard Worker }
180*4bdc9457SAndroid Build Coastguard Worker #endif // BENCHMARK_INTEL_DNNL
181*4bdc9457SAndroid Build Coastguard Worker
ThreePassSoftMaxWithRecomputing(benchmark::State & state,xnn_f32_rmax_ukernel_function rmax,xnn_f32_raddexpminusmax_ukernel_function raddexpminusmax,xnn_f32_vscaleexpminusmax_ukernel_function vscaleexpminusmax,benchmark::utils::IsaCheckFunction isa_check=nullptr)182*4bdc9457SAndroid Build Coastguard Worker static void ThreePassSoftMaxWithRecomputing(
183*4bdc9457SAndroid Build Coastguard Worker benchmark::State& state,
184*4bdc9457SAndroid Build Coastguard Worker xnn_f32_rmax_ukernel_function rmax,
185*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddexpminusmax_ukernel_function raddexpminusmax,
186*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vscaleexpminusmax_ukernel_function vscaleexpminusmax,
187*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::IsaCheckFunction isa_check = nullptr)
188*4bdc9457SAndroid Build Coastguard Worker {
189*4bdc9457SAndroid Build Coastguard Worker if (isa_check && !isa_check(state)) {
190*4bdc9457SAndroid Build Coastguard Worker return;
191*4bdc9457SAndroid Build Coastguard Worker }
192*4bdc9457SAndroid Build Coastguard Worker
193*4bdc9457SAndroid Build Coastguard Worker const size_t elements = state.range(0);
194*4bdc9457SAndroid Build Coastguard Worker const size_t cache_line_size_max = 128;
195*4bdc9457SAndroid Build Coastguard Worker const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
196*4bdc9457SAndroid Build Coastguard Worker
197*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
198*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
199*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
200*4bdc9457SAndroid Build Coastguard Worker
201*4bdc9457SAndroid Build Coastguard Worker const size_t num_buffers = 1 +
202*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
203*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(elements);
204*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(packed_elements * num_buffers);
205*4bdc9457SAndroid Build Coastguard Worker
206*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(f32rng));
207*4bdc9457SAndroid Build Coastguard Worker
208*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::DisableDenormals();
209*4bdc9457SAndroid Build Coastguard Worker
210*4bdc9457SAndroid Build Coastguard Worker size_t buffer_index = 0;
211*4bdc9457SAndroid Build Coastguard Worker for (auto _ : state) {
212*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
213*4bdc9457SAndroid Build Coastguard Worker if (++buffer_index == num_buffers) {
214*4bdc9457SAndroid Build Coastguard Worker buffer_index = 0;
215*4bdc9457SAndroid Build Coastguard Worker }
216*4bdc9457SAndroid Build Coastguard Worker
217*4bdc9457SAndroid Build Coastguard Worker const auto start = std::chrono::high_resolution_clock::now();
218*4bdc9457SAndroid Build Coastguard Worker float x_max = nanf("");
219*4bdc9457SAndroid Build Coastguard Worker rmax(elements * sizeof(float), x.data(), &x_max);
220*4bdc9457SAndroid Build Coastguard Worker float y_sum = nanf("");
221*4bdc9457SAndroid Build Coastguard Worker raddexpminusmax(elements * sizeof(float), x.data(), &y_sum, x_max);
222*4bdc9457SAndroid Build Coastguard Worker vscaleexpminusmax(elements * sizeof(float), x.data(), y.data() + packed_elements * buffer_index, x_max, 1.0f / y_sum);
223*4bdc9457SAndroid Build Coastguard Worker const auto end = std::chrono::high_resolution_clock::now();
224*4bdc9457SAndroid Build Coastguard Worker
225*4bdc9457SAndroid Build Coastguard Worker const auto elapsed_seconds =
226*4bdc9457SAndroid Build Coastguard Worker std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
227*4bdc9457SAndroid Build Coastguard Worker state.SetIterationTime(elapsed_seconds.count());
228*4bdc9457SAndroid Build Coastguard Worker }
229*4bdc9457SAndroid Build Coastguard Worker
230*4bdc9457SAndroid Build Coastguard Worker const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
231*4bdc9457SAndroid Build Coastguard Worker if (cpu_frequency != 0) {
232*4bdc9457SAndroid Build Coastguard Worker state.counters["cpufreq"] = cpu_frequency;
233*4bdc9457SAndroid Build Coastguard Worker }
234*4bdc9457SAndroid Build Coastguard Worker
235*4bdc9457SAndroid Build Coastguard Worker const size_t elements_per_iteration = elements;
236*4bdc9457SAndroid Build Coastguard Worker state.counters["elements"] =
237*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
238*4bdc9457SAndroid Build Coastguard Worker
239*4bdc9457SAndroid Build Coastguard Worker const size_t bytes_per_iteration = 2 * elements * sizeof(float);
240*4bdc9457SAndroid Build Coastguard Worker state.counters["bytes"] =
241*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
242*4bdc9457SAndroid Build Coastguard Worker }
243*4bdc9457SAndroid Build Coastguard Worker
ThreePassSoftMaxWithReloading(benchmark::State & state,xnn_f32_rmax_ukernel_function rmax,xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax,xnn_init_f32_expminus_params_fn init_expminus_params,xnn_f32_vbinary_minmax_ukernel_function vmulc,xnn_init_f32_minmax_params_fn init_minmax_params,benchmark::utils::IsaCheckFunction isa_check=nullptr)244*4bdc9457SAndroid Build Coastguard Worker static void ThreePassSoftMaxWithReloading(
245*4bdc9457SAndroid Build Coastguard Worker benchmark::State& state,
246*4bdc9457SAndroid Build Coastguard Worker xnn_f32_rmax_ukernel_function rmax,
247*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddstoreexpminusmax_ukernel_function raddstoreexpminusmax,
248*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_expminus_params_fn init_expminus_params,
249*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vbinary_minmax_ukernel_function vmulc,
250*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_minmax_params_fn init_minmax_params,
251*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::IsaCheckFunction isa_check = nullptr)
252*4bdc9457SAndroid Build Coastguard Worker {
253*4bdc9457SAndroid Build Coastguard Worker if (isa_check && !isa_check(state)) {
254*4bdc9457SAndroid Build Coastguard Worker return;
255*4bdc9457SAndroid Build Coastguard Worker }
256*4bdc9457SAndroid Build Coastguard Worker
257*4bdc9457SAndroid Build Coastguard Worker const size_t elements = state.range(0);
258*4bdc9457SAndroid Build Coastguard Worker const size_t cache_line_size_max = 128;
259*4bdc9457SAndroid Build Coastguard Worker const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
260*4bdc9457SAndroid Build Coastguard Worker
261*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
262*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
263*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
264*4bdc9457SAndroid Build Coastguard Worker
265*4bdc9457SAndroid Build Coastguard Worker const size_t num_buffers = 1 +
266*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
267*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(elements);
268*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(packed_elements * num_buffers);
269*4bdc9457SAndroid Build Coastguard Worker
270*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(f32rng));
271*4bdc9457SAndroid Build Coastguard Worker
272*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::DisableDenormals();
273*4bdc9457SAndroid Build Coastguard Worker
274*4bdc9457SAndroid Build Coastguard Worker xnn_f32_expminus_params expminus_params;
275*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params minmax_params;
276*4bdc9457SAndroid Build Coastguard Worker init_expminus_params(&expminus_params);
277*4bdc9457SAndroid Build Coastguard Worker init_minmax_params(&minmax_params, -INFINITY, INFINITY);
278*4bdc9457SAndroid Build Coastguard Worker
279*4bdc9457SAndroid Build Coastguard Worker size_t buffer_index = 0;
280*4bdc9457SAndroid Build Coastguard Worker for (auto _ : state) {
281*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
282*4bdc9457SAndroid Build Coastguard Worker if (++buffer_index == num_buffers) {
283*4bdc9457SAndroid Build Coastguard Worker buffer_index = 0;
284*4bdc9457SAndroid Build Coastguard Worker }
285*4bdc9457SAndroid Build Coastguard Worker
286*4bdc9457SAndroid Build Coastguard Worker const auto start = std::chrono::high_resolution_clock::now();
287*4bdc9457SAndroid Build Coastguard Worker float x_max = nanf("");
288*4bdc9457SAndroid Build Coastguard Worker rmax(elements * sizeof(float), x.data(), &x_max);
289*4bdc9457SAndroid Build Coastguard Worker float y_sum = nanf("");
290*4bdc9457SAndroid Build Coastguard Worker raddstoreexpminusmax(elements * sizeof(float), x.data(), &x_max, y.data() + packed_elements * buffer_index, &y_sum, &expminus_params);
291*4bdc9457SAndroid Build Coastguard Worker const float inv_y_sum = 1.0f / y_sum;
292*4bdc9457SAndroid Build Coastguard Worker vmulc(elements * sizeof(float), y.data() + packed_elements * buffer_index, &inv_y_sum, y.data() + packed_elements * buffer_index, &minmax_params);
293*4bdc9457SAndroid Build Coastguard Worker const auto end = std::chrono::high_resolution_clock::now();
294*4bdc9457SAndroid Build Coastguard Worker
295*4bdc9457SAndroid Build Coastguard Worker const auto elapsed_seconds =
296*4bdc9457SAndroid Build Coastguard Worker std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
297*4bdc9457SAndroid Build Coastguard Worker state.SetIterationTime(elapsed_seconds.count());
298*4bdc9457SAndroid Build Coastguard Worker }
299*4bdc9457SAndroid Build Coastguard Worker
300*4bdc9457SAndroid Build Coastguard Worker const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
301*4bdc9457SAndroid Build Coastguard Worker if (cpu_frequency != 0) {
302*4bdc9457SAndroid Build Coastguard Worker state.counters["cpufreq"] = cpu_frequency;
303*4bdc9457SAndroid Build Coastguard Worker }
304*4bdc9457SAndroid Build Coastguard Worker
305*4bdc9457SAndroid Build Coastguard Worker const size_t elements_per_iteration = elements;
306*4bdc9457SAndroid Build Coastguard Worker state.counters["elements"] =
307*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
308*4bdc9457SAndroid Build Coastguard Worker
309*4bdc9457SAndroid Build Coastguard Worker const size_t bytes_per_iteration = 2 * elements * sizeof(float);
310*4bdc9457SAndroid Build Coastguard Worker state.counters["bytes"] =
311*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
312*4bdc9457SAndroid Build Coastguard Worker }
313*4bdc9457SAndroid Build Coastguard Worker
TwoPassSoftMax(benchmark::State & state,xnn_f32_raddextexp_ukernel_function raddextexp,xnn_f32_vscaleextexp_ukernel_function vscaleextexp,benchmark::utils::IsaCheckFunction isa_check=nullptr)314*4bdc9457SAndroid Build Coastguard Worker static void TwoPassSoftMax(
315*4bdc9457SAndroid Build Coastguard Worker benchmark::State& state,
316*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddextexp_ukernel_function raddextexp,
317*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vscaleextexp_ukernel_function vscaleextexp,
318*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::IsaCheckFunction isa_check = nullptr)
319*4bdc9457SAndroid Build Coastguard Worker {
320*4bdc9457SAndroid Build Coastguard Worker if (isa_check && !isa_check(state)) {
321*4bdc9457SAndroid Build Coastguard Worker return;
322*4bdc9457SAndroid Build Coastguard Worker }
323*4bdc9457SAndroid Build Coastguard Worker
324*4bdc9457SAndroid Build Coastguard Worker const size_t elements = state.range(0);
325*4bdc9457SAndroid Build Coastguard Worker const size_t cache_line_size_max = 128;
326*4bdc9457SAndroid Build Coastguard Worker const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float));
327*4bdc9457SAndroid Build Coastguard Worker
328*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
329*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
330*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(-1000.0f, 1000.0f), std::ref(rng));
331*4bdc9457SAndroid Build Coastguard Worker
332*4bdc9457SAndroid Build Coastguard Worker const size_t num_buffers = 1 +
333*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::DivideRoundUp<size_t>(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float));
334*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(elements);
335*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(packed_elements * num_buffers);
336*4bdc9457SAndroid Build Coastguard Worker
337*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(f32rng));
338*4bdc9457SAndroid Build Coastguard Worker
339*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::DisableDenormals();
340*4bdc9457SAndroid Build Coastguard Worker
341*4bdc9457SAndroid Build Coastguard Worker size_t buffer_index = 0;
342*4bdc9457SAndroid Build Coastguard Worker for (auto _ : state) {
343*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::PrefetchToL1(x.data(), x.size() * sizeof(float));
344*4bdc9457SAndroid Build Coastguard Worker if (++buffer_index == num_buffers) {
345*4bdc9457SAndroid Build Coastguard Worker buffer_index = 0;
346*4bdc9457SAndroid Build Coastguard Worker }
347*4bdc9457SAndroid Build Coastguard Worker
348*4bdc9457SAndroid Build Coastguard Worker const auto start = std::chrono::high_resolution_clock::now();
349*4bdc9457SAndroid Build Coastguard Worker float scale[2];
350*4bdc9457SAndroid Build Coastguard Worker raddextexp(elements * sizeof(float), x.data(), scale);
351*4bdc9457SAndroid Build Coastguard Worker vscaleextexp(elements * sizeof(float), x.data(), y.data() + packed_elements * buffer_index, 1.0f / scale[0], -scale[1]);
352*4bdc9457SAndroid Build Coastguard Worker const auto end = std::chrono::high_resolution_clock::now();
353*4bdc9457SAndroid Build Coastguard Worker
354*4bdc9457SAndroid Build Coastguard Worker const auto elapsed_seconds =
355*4bdc9457SAndroid Build Coastguard Worker std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
356*4bdc9457SAndroid Build Coastguard Worker state.SetIterationTime(elapsed_seconds.count());
357*4bdc9457SAndroid Build Coastguard Worker }
358*4bdc9457SAndroid Build Coastguard Worker
359*4bdc9457SAndroid Build Coastguard Worker const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
360*4bdc9457SAndroid Build Coastguard Worker if (cpu_frequency != 0) {
361*4bdc9457SAndroid Build Coastguard Worker state.counters["cpufreq"] = cpu_frequency;
362*4bdc9457SAndroid Build Coastguard Worker }
363*4bdc9457SAndroid Build Coastguard Worker
364*4bdc9457SAndroid Build Coastguard Worker const size_t elements_per_iteration = elements;
365*4bdc9457SAndroid Build Coastguard Worker state.counters["elements"] =
366*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
367*4bdc9457SAndroid Build Coastguard Worker
368*4bdc9457SAndroid Build Coastguard Worker const size_t bytes_per_iteration = 2 * elements * sizeof(float);
369*4bdc9457SAndroid Build Coastguard Worker state.counters["bytes"] =
370*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
371*4bdc9457SAndroid Build Coastguard Worker }
372*4bdc9457SAndroid Build Coastguard Worker
CharacteristicArguments(benchmark::internal::Benchmark * b)373*4bdc9457SAndroid Build Coastguard Worker static void CharacteristicArguments(benchmark::internal::Benchmark* b) {
374*4bdc9457SAndroid Build Coastguard Worker for (int32_t n = 1000; n <= 100000000; n *= 10) {
375*4bdc9457SAndroid Build Coastguard Worker b->Arg(n);
376*4bdc9457SAndroid Build Coastguard Worker b->Arg(3 * n);
377*4bdc9457SAndroid Build Coastguard Worker }
378*4bdc9457SAndroid Build Coastguard Worker }
379*4bdc9457SAndroid Build Coastguard Worker
380*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_INTEL_DNNL
381*4bdc9457SAndroid Build Coastguard Worker BENCHMARK(DNNLSoftArgMax)->Apply(CharacteristicArguments)->UseManualTime();
382*4bdc9457SAndroid Build Coastguard Worker #endif
383*4bdc9457SAndroid Build Coastguard Worker
384*4bdc9457SAndroid Build Coastguard Worker #if XNN_ARCH_X86 || XNN_ARCH_X86_64
385*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(TwoPassSoftMax, avx2_p5,
386*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddextexp_ukernel__avx2_p5_x96,
387*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vscaleextexp_ukernel__avx2_p5_x40,
388*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime();
389*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx2_p5,
390*4bdc9457SAndroid Build Coastguard Worker xnn_f32_rmax_ukernel__avx,
391*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddexpminusmax_ukernel__avx2_p5_x96,
392*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vscaleexpminusmax_ukernel__avx2_p5_x24,
393*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime();
394*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx2_p5,
395*4bdc9457SAndroid Build Coastguard Worker xnn_f32_rmax_ukernel__avx,
396*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddstoreexpminusmax_ukernel__avx2_rr1_p5_x64_acc2,
397*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_expminus_avx2_rr1_p5_params,
398*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vmulc_minmax_ukernel__avx_x16,
399*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_minmax_avx_params,
400*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime();
401*4bdc9457SAndroid Build Coastguard Worker
402*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(TwoPassSoftMax, avx512f_p5_scalef,
403*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_x144_acc3,
404*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vscaleextexp_ukernel__avx512f_p5_scalef_x16,
405*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime();
406*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx512f_p5_scalef,
407*4bdc9457SAndroid Build Coastguard Worker xnn_f32_rmax_ukernel__avx512f,
408*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddexpminusmax_ukernel__avx512f_p5_scalef_x128_acc4,
409*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vscaleexpminusmax_ukernel__avx512f_p5_scalef_x16,
410*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime();
411*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx512f_p5_scalef,
412*4bdc9457SAndroid Build Coastguard Worker xnn_f32_rmax_ukernel__avx512f,
413*4bdc9457SAndroid Build Coastguard Worker xnn_f32_raddstoreexpminusmax_ukernel__avx512f_rr1_p5_scalef_x128_acc2,
414*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_expminus_avx512_rr1_p5_params,
415*4bdc9457SAndroid Build Coastguard Worker xnn_f32_vmulc_minmax_ukernel__avx512f_x32,
416*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_minmax_scalar_params,
417*4bdc9457SAndroid Build Coastguard Worker benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime();
418*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
419*4bdc9457SAndroid Build Coastguard Worker
420*4bdc9457SAndroid Build Coastguard Worker #ifndef XNNPACK_BENCHMARK_NO_MAIN
421*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_MAIN();
422*4bdc9457SAndroid Build Coastguard Worker #endif
423