xref: /aosp_15_r20/external/XNNPACK/bench/softmax.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
5*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
6*4bdc9457SAndroid Build Coastguard Worker 
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
14*4bdc9457SAndroid Build Coastguard Worker 
15*4bdc9457SAndroid Build Coastguard Worker #include <benchmark/benchmark.h>
16*4bdc9457SAndroid Build Coastguard Worker #include "bench/utils.h"
17*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_TENSORFLOW_LITE
18*4bdc9457SAndroid Build Coastguard Worker #include "flatbuffers/include/flatbuffers/flatbuffers.h"
19*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter.h"
20*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/register.h"
21*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
22*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/schema/schema_generated.h"
23*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/version.h"
24*4bdc9457SAndroid Build Coastguard Worker #endif  // BENCHMARK_TENSORFLOW_LITE
25*4bdc9457SAndroid Build Coastguard Worker 
26*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_NO_QU8_OPERATORS
xnnpack_softmax_qu8(benchmark::State & state)27*4bdc9457SAndroid Build Coastguard Worker static void xnnpack_softmax_qu8(benchmark::State& state) {
28*4bdc9457SAndroid Build Coastguard Worker   const size_t batch_size = static_cast<size_t>(state.range(0));
29*4bdc9457SAndroid Build Coastguard Worker   const size_t channels = static_cast<size_t>(state.range(1));
30*4bdc9457SAndroid Build Coastguard Worker 
31*4bdc9457SAndroid Build Coastguard Worker   std::random_device random_device;
32*4bdc9457SAndroid Build Coastguard Worker   auto rng = std::mt19937(random_device());
33*4bdc9457SAndroid Build Coastguard Worker   auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));
34*4bdc9457SAndroid Build Coastguard Worker 
35*4bdc9457SAndroid Build Coastguard Worker   std::vector<uint8_t> input(batch_size * channels);
36*4bdc9457SAndroid Build Coastguard Worker   std::vector<uint8_t> output(batch_size * channels);
37*4bdc9457SAndroid Build Coastguard Worker   std::generate(input.begin(), input.end(), std::ref(u8rng));
38*4bdc9457SAndroid Build Coastguard Worker   std::fill(output.begin(), output.end(), 0xA5);
39*4bdc9457SAndroid Build Coastguard Worker 
40*4bdc9457SAndroid Build Coastguard Worker   xnn_status status = xnn_initialize(nullptr /* allocator */);
41*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
42*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to initialize XNNPACK");
43*4bdc9457SAndroid Build Coastguard Worker     return;
44*4bdc9457SAndroid Build Coastguard Worker   }
45*4bdc9457SAndroid Build Coastguard Worker 
46*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t softmax_op = nullptr;
47*4bdc9457SAndroid Build Coastguard Worker   status = xnn_create_softmax_nc_qu8(
48*4bdc9457SAndroid Build Coastguard Worker     channels, channels /* input stride */, channels /* output stride */,
49*4bdc9457SAndroid Build Coastguard Worker     1.0f /* input scale */,
50*4bdc9457SAndroid Build Coastguard Worker     0 /* output zero point */, 1.0f / 256.0f /* output scale */,
51*4bdc9457SAndroid Build Coastguard Worker     0 /* flags */, &softmax_op);
52*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success || softmax_op == nullptr) {
53*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to create SoftMax operator");
54*4bdc9457SAndroid Build Coastguard Worker     return;
55*4bdc9457SAndroid Build Coastguard Worker   }
56*4bdc9457SAndroid Build Coastguard Worker 
57*4bdc9457SAndroid Build Coastguard Worker   status = xnn_setup_softmax_nc_qu8(
58*4bdc9457SAndroid Build Coastguard Worker     softmax_op,
59*4bdc9457SAndroid Build Coastguard Worker     batch_size,
60*4bdc9457SAndroid Build Coastguard Worker     input.data(), output.data(),
61*4bdc9457SAndroid Build Coastguard Worker     nullptr /* thread pool */);
62*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
63*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to setup SoftMax operator");
64*4bdc9457SAndroid Build Coastguard Worker     return;
65*4bdc9457SAndroid Build Coastguard Worker   }
66*4bdc9457SAndroid Build Coastguard Worker 
67*4bdc9457SAndroid Build Coastguard Worker   for (auto _ : state) {
68*4bdc9457SAndroid Build Coastguard Worker     status = xnn_run_operator(softmax_op, nullptr /* thread pool */);
69*4bdc9457SAndroid Build Coastguard Worker     if (status != xnn_status_success) {
70*4bdc9457SAndroid Build Coastguard Worker       state.SkipWithError("failed to run SoftMax operator");
71*4bdc9457SAndroid Build Coastguard Worker       return;
72*4bdc9457SAndroid Build Coastguard Worker     }
73*4bdc9457SAndroid Build Coastguard Worker   }
74*4bdc9457SAndroid Build Coastguard Worker 
75*4bdc9457SAndroid Build Coastguard Worker   status = xnn_delete_operator(softmax_op);
76*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
77*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to delete SoftMax operator");
78*4bdc9457SAndroid Build Coastguard Worker     return;
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker 
81*4bdc9457SAndroid Build Coastguard Worker   const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
82*4bdc9457SAndroid Build Coastguard Worker   if (cpu_frequency != 0) {
83*4bdc9457SAndroid Build Coastguard Worker     state.counters["cpufreq"] = cpu_frequency;
84*4bdc9457SAndroid Build Coastguard Worker   }
85*4bdc9457SAndroid Build Coastguard Worker 
86*4bdc9457SAndroid Build Coastguard Worker   const size_t elements_per_iteration = batch_size * channels;
87*4bdc9457SAndroid Build Coastguard Worker   state.counters["elements"] =
88*4bdc9457SAndroid Build Coastguard Worker     benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
89*4bdc9457SAndroid Build Coastguard Worker 
90*4bdc9457SAndroid Build Coastguard Worker   const size_t bytes_per_iteration = 2 * elements_per_iteration * sizeof(uint8_t);
91*4bdc9457SAndroid Build Coastguard Worker   state.counters["bytes"] =
92*4bdc9457SAndroid Build Coastguard Worker     benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
93*4bdc9457SAndroid Build Coastguard Worker }
94*4bdc9457SAndroid Build Coastguard Worker #endif  // XNN_NO_QU8_OPERATORS
95*4bdc9457SAndroid Build Coastguard Worker 
xnnpack_softmax_f32(benchmark::State & state)96*4bdc9457SAndroid Build Coastguard Worker static void xnnpack_softmax_f32(benchmark::State& state) {
97*4bdc9457SAndroid Build Coastguard Worker   const size_t batch_size = static_cast<size_t>(state.range(0));
98*4bdc9457SAndroid Build Coastguard Worker   const size_t channels = static_cast<size_t>(state.range(1));
99*4bdc9457SAndroid Build Coastguard Worker 
100*4bdc9457SAndroid Build Coastguard Worker   std::random_device random_device;
101*4bdc9457SAndroid Build Coastguard Worker   auto rng = std::mt19937(random_device());
102*4bdc9457SAndroid Build Coastguard Worker   auto f32rng = std::bind(std::uniform_real_distribution<float>(-100.0f, 100.0f), std::ref(rng));
103*4bdc9457SAndroid Build Coastguard Worker 
104*4bdc9457SAndroid Build Coastguard Worker   std::vector<float> input(batch_size * channels + XNN_EXTRA_BYTES / sizeof(float));
105*4bdc9457SAndroid Build Coastguard Worker   std::vector<float> output(batch_size * channels);
106*4bdc9457SAndroid Build Coastguard Worker   std::generate(input.begin(), input.end(), std::ref(f32rng));
107*4bdc9457SAndroid Build Coastguard Worker   std::fill(output.begin(), output.end(), std::nanf(""));
108*4bdc9457SAndroid Build Coastguard Worker 
109*4bdc9457SAndroid Build Coastguard Worker   xnn_status status = xnn_initialize(nullptr /* allocator */);
110*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
111*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to initialize XNNPACK");
112*4bdc9457SAndroid Build Coastguard Worker     return;
113*4bdc9457SAndroid Build Coastguard Worker   }
114*4bdc9457SAndroid Build Coastguard Worker 
115*4bdc9457SAndroid Build Coastguard Worker   xnn_operator_t softmax_op = nullptr;
116*4bdc9457SAndroid Build Coastguard Worker   status = xnn_create_softmax_nc_f32(
117*4bdc9457SAndroid Build Coastguard Worker     channels, channels /* input stride */, channels /* output stride */,
118*4bdc9457SAndroid Build Coastguard Worker     0 /* flags */, &softmax_op);
119*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success || softmax_op == nullptr) {
120*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to create SoftMax operator");
121*4bdc9457SAndroid Build Coastguard Worker     return;
122*4bdc9457SAndroid Build Coastguard Worker   }
123*4bdc9457SAndroid Build Coastguard Worker 
124*4bdc9457SAndroid Build Coastguard Worker   status = xnn_setup_softmax_nc_f32(
125*4bdc9457SAndroid Build Coastguard Worker     softmax_op,
126*4bdc9457SAndroid Build Coastguard Worker     batch_size,
127*4bdc9457SAndroid Build Coastguard Worker     input.data(), output.data(),
128*4bdc9457SAndroid Build Coastguard Worker     nullptr /* thread pool */);
129*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
130*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to setup SoftMax operator");
131*4bdc9457SAndroid Build Coastguard Worker     return;
132*4bdc9457SAndroid Build Coastguard Worker   }
133*4bdc9457SAndroid Build Coastguard Worker 
134*4bdc9457SAndroid Build Coastguard Worker   for (auto _ : state) {
135*4bdc9457SAndroid Build Coastguard Worker     status = xnn_run_operator(softmax_op, nullptr /* thread pool */);
136*4bdc9457SAndroid Build Coastguard Worker     if (status != xnn_status_success) {
137*4bdc9457SAndroid Build Coastguard Worker       state.SkipWithError("failed to run SoftMax operator");
138*4bdc9457SAndroid Build Coastguard Worker       return;
139*4bdc9457SAndroid Build Coastguard Worker     }
140*4bdc9457SAndroid Build Coastguard Worker   }
141*4bdc9457SAndroid Build Coastguard Worker 
142*4bdc9457SAndroid Build Coastguard Worker   status = xnn_delete_operator(softmax_op);
143*4bdc9457SAndroid Build Coastguard Worker   if (status != xnn_status_success) {
144*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to delete SoftMax operator");
145*4bdc9457SAndroid Build Coastguard Worker     return;
146*4bdc9457SAndroid Build Coastguard Worker   }
147*4bdc9457SAndroid Build Coastguard Worker 
148*4bdc9457SAndroid Build Coastguard Worker   const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
149*4bdc9457SAndroid Build Coastguard Worker   if (cpu_frequency != 0) {
150*4bdc9457SAndroid Build Coastguard Worker     state.counters["cpufreq"] = cpu_frequency;
151*4bdc9457SAndroid Build Coastguard Worker   }
152*4bdc9457SAndroid Build Coastguard Worker 
153*4bdc9457SAndroid Build Coastguard Worker   const size_t elements_per_iteration = batch_size * channels;
154*4bdc9457SAndroid Build Coastguard Worker   state.counters["elements"] =
155*4bdc9457SAndroid Build Coastguard Worker     benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
156*4bdc9457SAndroid Build Coastguard Worker 
157*4bdc9457SAndroid Build Coastguard Worker   const size_t bytes_per_iteration = 2 * elements_per_iteration * sizeof(float);
158*4bdc9457SAndroid Build Coastguard Worker   state.counters["bytes"] =
159*4bdc9457SAndroid Build Coastguard Worker     benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
160*4bdc9457SAndroid Build Coastguard Worker }
161*4bdc9457SAndroid Build Coastguard Worker 
162*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_TENSORFLOW_LITE
tflite_softmax_f32(benchmark::State & state)163*4bdc9457SAndroid Build Coastguard Worker static void tflite_softmax_f32(benchmark::State& state) {
164*4bdc9457SAndroid Build Coastguard Worker   const size_t batch_size = state.range(0);
165*4bdc9457SAndroid Build Coastguard Worker   const size_t channels = state.range(1);
166*4bdc9457SAndroid Build Coastguard Worker 
167*4bdc9457SAndroid Build Coastguard Worker   std::random_device random_device;
168*4bdc9457SAndroid Build Coastguard Worker   auto rng = std::mt19937(random_device());
169*4bdc9457SAndroid Build Coastguard Worker   auto f32rng = std::bind(std::uniform_real_distribution<float>(-100.0f, 100.0f), std::ref(rng));
170*4bdc9457SAndroid Build Coastguard Worker 
171*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
172*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::Offset<tflite::OperatorCode> operator_code =
173*4bdc9457SAndroid Build Coastguard Worker     tflite::CreateOperatorCode(builder, tflite::BuiltinOperator_SOFTMAX);
174*4bdc9457SAndroid Build Coastguard Worker 
175*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::Offset<tflite::SoftmaxOptions> softmax_options =
176*4bdc9457SAndroid Build Coastguard Worker     tflite::CreateSoftmaxOptions(builder, 1.0f /* beta */);
177*4bdc9457SAndroid Build Coastguard Worker 
178*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::Offset<tflite::Buffer> buffers[1] = {
179*4bdc9457SAndroid Build Coastguard Worker     tflite::CreateBuffer(builder, builder.CreateVector({})),
180*4bdc9457SAndroid Build Coastguard Worker   };
181*4bdc9457SAndroid Build Coastguard Worker 
182*4bdc9457SAndroid Build Coastguard Worker   const int32_t input_shape[4] = {
183*4bdc9457SAndroid Build Coastguard Worker     static_cast<int32_t>(batch_size),
184*4bdc9457SAndroid Build Coastguard Worker     static_cast<int32_t>(1 /* height */),
185*4bdc9457SAndroid Build Coastguard Worker     static_cast<int32_t>(1 /* width */),
186*4bdc9457SAndroid Build Coastguard Worker     static_cast<int32_t>(channels)
187*4bdc9457SAndroid Build Coastguard Worker   };
188*4bdc9457SAndroid Build Coastguard Worker   const int32_t output_shape[4] = {
189*4bdc9457SAndroid Build Coastguard Worker     static_cast<int32_t>(batch_size),
190*4bdc9457SAndroid Build Coastguard Worker     static_cast<int32_t>(1 /* height */),
191*4bdc9457SAndroid Build Coastguard Worker     static_cast<int32_t>(1 /* width */),
192*4bdc9457SAndroid Build Coastguard Worker     static_cast<int32_t>(channels)
193*4bdc9457SAndroid Build Coastguard Worker   };
194*4bdc9457SAndroid Build Coastguard Worker 
195*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::Offset<tflite::Tensor> tensors[2] = {
196*4bdc9457SAndroid Build Coastguard Worker     tflite::CreateTensor(builder,
197*4bdc9457SAndroid Build Coastguard Worker                          builder.CreateVector<int32_t>(input_shape, 4),
198*4bdc9457SAndroid Build Coastguard Worker                          tflite::TensorType_FLOAT32),
199*4bdc9457SAndroid Build Coastguard Worker     tflite::CreateTensor(builder,
200*4bdc9457SAndroid Build Coastguard Worker                          builder.CreateVector<int32_t>(output_shape, 4),
201*4bdc9457SAndroid Build Coastguard Worker                          tflite::TensorType_FLOAT32),
202*4bdc9457SAndroid Build Coastguard Worker   };
203*4bdc9457SAndroid Build Coastguard Worker 
204*4bdc9457SAndroid Build Coastguard Worker   const int32_t op_inputs[1] = { 0 };
205*4bdc9457SAndroid Build Coastguard Worker   const int32_t op_outputs[1] = { 1 };
206*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::Offset<tflite::Operator> op = tflite::CreateOperator(
207*4bdc9457SAndroid Build Coastguard Worker       builder,
208*4bdc9457SAndroid Build Coastguard Worker       0 /* opcode_index */,
209*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector<int32_t>(op_inputs, 1),
210*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector<int32_t>(op_outputs, 1),
211*4bdc9457SAndroid Build Coastguard Worker       tflite::BuiltinOptions_SoftmaxOptions, softmax_options.Union());
212*4bdc9457SAndroid Build Coastguard Worker 
213*4bdc9457SAndroid Build Coastguard Worker   const int32_t graph_inputs[1] = { 0 };
214*4bdc9457SAndroid Build Coastguard Worker   const int32_t graph_outputs[1] = { 1 };
215*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::Offset<tflite::SubGraph> subgraph = tflite::CreateSubGraph(
216*4bdc9457SAndroid Build Coastguard Worker       builder,
217*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector(tensors, 2),
218*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector<int32_t>(graph_inputs, 1),
219*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector<int32_t>(graph_outputs, 1),
220*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector(&op, 1));
221*4bdc9457SAndroid Build Coastguard Worker 
222*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::Offset<flatbuffers::String> description = builder.CreateString("Softmax model");
223*4bdc9457SAndroid Build Coastguard Worker 
224*4bdc9457SAndroid Build Coastguard Worker   flatbuffers::Offset<tflite::Model> model_buffer = tflite::CreateModel(builder,
225*4bdc9457SAndroid Build Coastguard Worker       TFLITE_SCHEMA_VERSION,
226*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector(&operator_code, 1),
227*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector(&subgraph, 1),
228*4bdc9457SAndroid Build Coastguard Worker       description,
229*4bdc9457SAndroid Build Coastguard Worker       builder.CreateVector(buffers, 1));
230*4bdc9457SAndroid Build Coastguard Worker 
231*4bdc9457SAndroid Build Coastguard Worker   builder.Finish(model_buffer);
232*4bdc9457SAndroid Build Coastguard Worker 
233*4bdc9457SAndroid Build Coastguard Worker   const tflite::Model* model = tflite::GetModel(builder.GetBufferPointer());
234*4bdc9457SAndroid Build Coastguard Worker   tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
235*4bdc9457SAndroid Build Coastguard Worker   tflite::InterpreterBuilder interpreterBuilder(model, resolver);
236*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<tflite::Interpreter> interpreter;
237*4bdc9457SAndroid Build Coastguard Worker   if (interpreterBuilder(&interpreter) != kTfLiteOk) {
238*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to create TFLite interpreter");
239*4bdc9457SAndroid Build Coastguard Worker     return;
240*4bdc9457SAndroid Build Coastguard Worker   }
241*4bdc9457SAndroid Build Coastguard Worker   if (interpreter == nullptr) {
242*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("TFLite interpreter is null");
243*4bdc9457SAndroid Build Coastguard Worker     return;
244*4bdc9457SAndroid Build Coastguard Worker   }
245*4bdc9457SAndroid Build Coastguard Worker   interpreter->SetNumThreads(1);
246*4bdc9457SAndroid Build Coastguard Worker 
247*4bdc9457SAndroid Build Coastguard Worker   if (interpreter->AllocateTensors() != kTfLiteOk) {
248*4bdc9457SAndroid Build Coastguard Worker     state.SkipWithError("failed to allocate tensors");
249*4bdc9457SAndroid Build Coastguard Worker     return;
250*4bdc9457SAndroid Build Coastguard Worker   }
251*4bdc9457SAndroid Build Coastguard Worker 
252*4bdc9457SAndroid Build Coastguard Worker   std::generate(
253*4bdc9457SAndroid Build Coastguard Worker     interpreter->typed_tensor<float>(0),
254*4bdc9457SAndroid Build Coastguard Worker     interpreter->typed_tensor<float>(0) + batch_size * channels,
255*4bdc9457SAndroid Build Coastguard Worker     std::ref(f32rng));
256*4bdc9457SAndroid Build Coastguard Worker 
257*4bdc9457SAndroid Build Coastguard Worker   for (auto _ : state) {
258*4bdc9457SAndroid Build Coastguard Worker     if (interpreter->Invoke() != kTfLiteOk) {
259*4bdc9457SAndroid Build Coastguard Worker       state.SkipWithError("failed to invoke TFLite interpreter");
260*4bdc9457SAndroid Build Coastguard Worker       return;
261*4bdc9457SAndroid Build Coastguard Worker     }
262*4bdc9457SAndroid Build Coastguard Worker   }
263*4bdc9457SAndroid Build Coastguard Worker 
264*4bdc9457SAndroid Build Coastguard Worker   const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
265*4bdc9457SAndroid Build Coastguard Worker   if (cpu_frequency != 0) {
266*4bdc9457SAndroid Build Coastguard Worker     state.counters["cpufreq"] = cpu_frequency;
267*4bdc9457SAndroid Build Coastguard Worker   }
268*4bdc9457SAndroid Build Coastguard Worker 
269*4bdc9457SAndroid Build Coastguard Worker   const size_t elements_per_iteration = batch_size * channels;
270*4bdc9457SAndroid Build Coastguard Worker   state.counters["elements"] =
271*4bdc9457SAndroid Build Coastguard Worker     benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate);
272*4bdc9457SAndroid Build Coastguard Worker 
273*4bdc9457SAndroid Build Coastguard Worker   const size_t bytes_per_iteration = 2 * elements_per_iteration * sizeof(float);
274*4bdc9457SAndroid Build Coastguard Worker   state.counters["bytes"] =
275*4bdc9457SAndroid Build Coastguard Worker     benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
276*4bdc9457SAndroid Build Coastguard Worker 
277*4bdc9457SAndroid Build Coastguard Worker   interpreter.reset();
278*4bdc9457SAndroid Build Coastguard Worker }
279*4bdc9457SAndroid Build Coastguard Worker #endif  // BENCHMARK_TENSORFLOW_LITE
280*4bdc9457SAndroid Build Coastguard Worker 
CharacteristicArguments(benchmark::internal::Benchmark * b)281*4bdc9457SAndroid Build Coastguard Worker static void CharacteristicArguments(benchmark::internal::Benchmark* b)
282*4bdc9457SAndroid Build Coastguard Worker {
283*4bdc9457SAndroid Build Coastguard Worker   b->ArgNames({"N", "C"});
284*4bdc9457SAndroid Build Coastguard Worker 
285*4bdc9457SAndroid Build Coastguard Worker   // CIFAR-10
286*4bdc9457SAndroid Build Coastguard Worker   b->Args({1, 10});
287*4bdc9457SAndroid Build Coastguard Worker   // CIFAR-100 */
288*4bdc9457SAndroid Build Coastguard Worker   b->Args({1, 100});
289*4bdc9457SAndroid Build Coastguard Worker   // ImageNet-1K
290*4bdc9457SAndroid Build Coastguard Worker   b->Args({1, 1000});
291*4bdc9457SAndroid Build Coastguard Worker   // ImageNet-1K+1
292*4bdc9457SAndroid Build Coastguard Worker   b->Args({1, 1001});
293*4bdc9457SAndroid Build Coastguard Worker   // ImageNet-22K
294*4bdc9457SAndroid Build Coastguard Worker   b->Args({1, 21841});
295*4bdc9457SAndroid Build Coastguard Worker   // ADE20K
296*4bdc9457SAndroid Build Coastguard Worker   b->Args({257 * 257, 151});
297*4bdc9457SAndroid Build Coastguard Worker }
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_NO_QU8_OPERATORS
300*4bdc9457SAndroid Build Coastguard Worker BENCHMARK(xnnpack_softmax_qu8)->Apply(CharacteristicArguments)->UseRealTime();
301*4bdc9457SAndroid Build Coastguard Worker #endif  // XNN_NO_QU8_OPERATORS
302*4bdc9457SAndroid Build Coastguard Worker 
303*4bdc9457SAndroid Build Coastguard Worker BENCHMARK(xnnpack_softmax_f32)->Apply(CharacteristicArguments)->UseRealTime();
304*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_TENSORFLOW_LITE
305*4bdc9457SAndroid Build Coastguard Worker BENCHMARK(tflite_softmax_f32)->Apply(CharacteristicArguments)->UseRealTime();
306*4bdc9457SAndroid Build Coastguard Worker #endif  // BENCHMARK_TENSORFLOW_LITE
307*4bdc9457SAndroid Build Coastguard Worker 
308*4bdc9457SAndroid Build Coastguard Worker #ifndef XNNPACK_BENCHMARK_NO_MAIN
309*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_MAIN();
310*4bdc9457SAndroid Build Coastguard Worker #endif
311