1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2021 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
7*4bdc9457SAndroid Build Coastguard Worker #include <array>
8*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
9*4bdc9457SAndroid Build Coastguard Worker #include <functional>
10*4bdc9457SAndroid Build Coastguard Worker #include <limits>
11*4bdc9457SAndroid Build Coastguard Worker #include <random>
12*4bdc9457SAndroid Build Coastguard Worker #include <vector>
13*4bdc9457SAndroid Build Coastguard Worker
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
15*4bdc9457SAndroid Build Coastguard Worker
16*4bdc9457SAndroid Build Coastguard Worker #include <benchmark/benchmark.h>
17*4bdc9457SAndroid Build Coastguard Worker #include "bench/utils.h"
18*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_TENSORFLOW_LITE
19*4bdc9457SAndroid Build Coastguard Worker #include "flatbuffers/include/flatbuffers/flatbuffers.h"
20*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter.h"
21*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/register.h"
22*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
23*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/schema/schema_generated.h"
24*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/version.h"
25*4bdc9457SAndroid Build Coastguard Worker #endif // BENCHMARK_TENSORFLOW_LITE
26*4bdc9457SAndroid Build Coastguard Worker
27*4bdc9457SAndroid Build Coastguard Worker
xnnpack_square_f32(benchmark::State & state)28*4bdc9457SAndroid Build Coastguard Worker static void xnnpack_square_f32(benchmark::State& state) {
29*4bdc9457SAndroid Build Coastguard Worker const size_t batch_size = state.range(0);
30*4bdc9457SAndroid Build Coastguard Worker
31*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
32*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
33*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(-10.0f, 10.0f), std::ref(rng));
34*4bdc9457SAndroid Build Coastguard Worker
35*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(batch_size + XNN_EXTRA_BYTES / sizeof(float));
36*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(batch_size);
37*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), std::ref(f32rng));
38*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf(""));
39*4bdc9457SAndroid Build Coastguard Worker
40*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_initialize(nullptr /* allocator */);
41*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
42*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to initialize XNNPACK");
43*4bdc9457SAndroid Build Coastguard Worker return;
44*4bdc9457SAndroid Build Coastguard Worker }
45*4bdc9457SAndroid Build Coastguard Worker
46*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t square_op = nullptr;
47*4bdc9457SAndroid Build Coastguard Worker status = xnn_create_square_nc_f32(
48*4bdc9457SAndroid Build Coastguard Worker 1 /* channels */, 1 /* input stride */, 1 /* output stride */,
49*4bdc9457SAndroid Build Coastguard Worker 0 /* flags */, &square_op);
50*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success || square_op == nullptr) {
51*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create Square operator");
52*4bdc9457SAndroid Build Coastguard Worker return;
53*4bdc9457SAndroid Build Coastguard Worker }
54*4bdc9457SAndroid Build Coastguard Worker
55*4bdc9457SAndroid Build Coastguard Worker status = xnn_setup_square_nc_f32(
56*4bdc9457SAndroid Build Coastguard Worker square_op, batch_size,
57*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
58*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */);
59*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
60*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to setup Square operator");
61*4bdc9457SAndroid Build Coastguard Worker return;
62*4bdc9457SAndroid Build Coastguard Worker }
63*4bdc9457SAndroid Build Coastguard Worker
64*4bdc9457SAndroid Build Coastguard Worker for (auto _ : state) {
65*4bdc9457SAndroid Build Coastguard Worker status = xnn_run_operator(square_op, nullptr /* thread pool */);
66*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
67*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to run Square operator");
68*4bdc9457SAndroid Build Coastguard Worker return;
69*4bdc9457SAndroid Build Coastguard Worker }
70*4bdc9457SAndroid Build Coastguard Worker }
71*4bdc9457SAndroid Build Coastguard Worker
72*4bdc9457SAndroid Build Coastguard Worker status = xnn_delete_operator(square_op);
73*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
74*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to delete Square operator");
75*4bdc9457SAndroid Build Coastguard Worker return;
76*4bdc9457SAndroid Build Coastguard Worker }
77*4bdc9457SAndroid Build Coastguard Worker
78*4bdc9457SAndroid Build Coastguard Worker const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
79*4bdc9457SAndroid Build Coastguard Worker if (cpu_frequency != 0) {
80*4bdc9457SAndroid Build Coastguard Worker state.counters["cpufreq"] = cpu_frequency;
81*4bdc9457SAndroid Build Coastguard Worker }
82*4bdc9457SAndroid Build Coastguard Worker
83*4bdc9457SAndroid Build Coastguard Worker state.counters["elements"] =
84*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * batch_size, benchmark::Counter::kIsRate);
85*4bdc9457SAndroid Build Coastguard Worker
86*4bdc9457SAndroid Build Coastguard Worker const size_t bytes_per_iteration = 2 * batch_size * sizeof(float);
87*4bdc9457SAndroid Build Coastguard Worker state.counters["bytes"] =
88*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
89*4bdc9457SAndroid Build Coastguard Worker }
90*4bdc9457SAndroid Build Coastguard Worker
91*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_TENSORFLOW_LITE
tflite_square_f32(benchmark::State & state)92*4bdc9457SAndroid Build Coastguard Worker static void tflite_square_f32(benchmark::State& state) {
93*4bdc9457SAndroid Build Coastguard Worker const size_t batch_size = state.range(0);
94*4bdc9457SAndroid Build Coastguard Worker
95*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
96*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
97*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(-10.0f, 10.0f), std::ref(rng));
98*4bdc9457SAndroid Build Coastguard Worker
99*4bdc9457SAndroid Build Coastguard Worker flatbuffers::FlatBufferBuilder builder;
100*4bdc9457SAndroid Build Coastguard Worker const flatbuffers::Offset<tflite::OperatorCode> operator_code =
101*4bdc9457SAndroid Build Coastguard Worker CreateOperatorCode(builder, tflite::BuiltinOperator_SQUARE);
102*4bdc9457SAndroid Build Coastguard Worker
103*4bdc9457SAndroid Build Coastguard Worker const std::array<flatbuffers::Offset<tflite::Buffer>, 1> buffers{{
104*4bdc9457SAndroid Build Coastguard Worker tflite::CreateBuffer(builder, builder.CreateVector({})),
105*4bdc9457SAndroid Build Coastguard Worker }};
106*4bdc9457SAndroid Build Coastguard Worker
107*4bdc9457SAndroid Build Coastguard Worker const std::array<int32_t, 1> shape{{
108*4bdc9457SAndroid Build Coastguard Worker static_cast<int32_t>(batch_size)
109*4bdc9457SAndroid Build Coastguard Worker }};
110*4bdc9457SAndroid Build Coastguard Worker
111*4bdc9457SAndroid Build Coastguard Worker const std::array<flatbuffers::Offset<tflite::Tensor>, 2> tensors{{
112*4bdc9457SAndroid Build Coastguard Worker tflite::CreateTensor(builder,
113*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector<int32_t>(shape.data(), shape.size()),
114*4bdc9457SAndroid Build Coastguard Worker tflite::TensorType_FLOAT32),
115*4bdc9457SAndroid Build Coastguard Worker tflite::CreateTensor(builder,
116*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector<int32_t>(shape.data(), shape.size()),
117*4bdc9457SAndroid Build Coastguard Worker tflite::TensorType_FLOAT32),
118*4bdc9457SAndroid Build Coastguard Worker }};
119*4bdc9457SAndroid Build Coastguard Worker
120*4bdc9457SAndroid Build Coastguard Worker const std::array<int32_t, 1> op_inputs{{ 0 }};
121*4bdc9457SAndroid Build Coastguard Worker const std::array<int32_t, 1> op_outputs{{ 1 }};
122*4bdc9457SAndroid Build Coastguard Worker flatbuffers::Offset<tflite::Operator> op = tflite::CreateOperator(
123*4bdc9457SAndroid Build Coastguard Worker builder,
124*4bdc9457SAndroid Build Coastguard Worker 0 /* opcode_index */,
125*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
126*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size()));
127*4bdc9457SAndroid Build Coastguard Worker
128*4bdc9457SAndroid Build Coastguard Worker const std::array<int32_t, 1> graph_inputs{{ 0 }};
129*4bdc9457SAndroid Build Coastguard Worker const std::array<int32_t, 1> graph_outputs{{ 1 }};
130*4bdc9457SAndroid Build Coastguard Worker const flatbuffers::Offset<tflite::SubGraph> subgraph = tflite::CreateSubGraph(
131*4bdc9457SAndroid Build Coastguard Worker builder,
132*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector(tensors.data(), tensors.size()),
133*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector<int32_t>(graph_inputs.data(), graph_inputs.size()),
134*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector<int32_t>(graph_outputs.data(), graph_outputs.size()),
135*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector(&op, 1));
136*4bdc9457SAndroid Build Coastguard Worker
137*4bdc9457SAndroid Build Coastguard Worker const flatbuffers::Offset<tflite::Model> model_buffer = tflite::CreateModel(builder,
138*4bdc9457SAndroid Build Coastguard Worker TFLITE_SCHEMA_VERSION,
139*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector(&operator_code, 1),
140*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector(&subgraph, 1),
141*4bdc9457SAndroid Build Coastguard Worker builder.CreateString("Square model"),
142*4bdc9457SAndroid Build Coastguard Worker builder.CreateVector(buffers.data(), buffers.size()));
143*4bdc9457SAndroid Build Coastguard Worker
144*4bdc9457SAndroid Build Coastguard Worker builder.Finish(model_buffer);
145*4bdc9457SAndroid Build Coastguard Worker
146*4bdc9457SAndroid Build Coastguard Worker const tflite::Model* model = tflite::GetModel(builder.GetBufferPointer());
147*4bdc9457SAndroid Build Coastguard Worker tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
148*4bdc9457SAndroid Build Coastguard Worker tflite::InterpreterBuilder interpreterBuilder(model, resolver);
149*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<tflite::Interpreter> interpreter;
150*4bdc9457SAndroid Build Coastguard Worker if (interpreterBuilder(&interpreter) != kTfLiteOk || interpreter == nullptr) {
151*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create TFLite interpreter");
152*4bdc9457SAndroid Build Coastguard Worker return;
153*4bdc9457SAndroid Build Coastguard Worker }
154*4bdc9457SAndroid Build Coastguard Worker interpreter->SetNumThreads(1);
155*4bdc9457SAndroid Build Coastguard Worker
156*4bdc9457SAndroid Build Coastguard Worker if (interpreter->AllocateTensors() != kTfLiteOk) {
157*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to allocate tensors");
158*4bdc9457SAndroid Build Coastguard Worker return;
159*4bdc9457SAndroid Build Coastguard Worker }
160*4bdc9457SAndroid Build Coastguard Worker
161*4bdc9457SAndroid Build Coastguard Worker std::generate(
162*4bdc9457SAndroid Build Coastguard Worker interpreter->typed_tensor<float>(0),
163*4bdc9457SAndroid Build Coastguard Worker interpreter->typed_tensor<float>(0) + batch_size,
164*4bdc9457SAndroid Build Coastguard Worker std::ref(f32rng));
165*4bdc9457SAndroid Build Coastguard Worker
166*4bdc9457SAndroid Build Coastguard Worker for (auto _ : state) {
167*4bdc9457SAndroid Build Coastguard Worker if (interpreter->Invoke() != kTfLiteOk) {
168*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to invoke TFLite interpreter");
169*4bdc9457SAndroid Build Coastguard Worker return;
170*4bdc9457SAndroid Build Coastguard Worker }
171*4bdc9457SAndroid Build Coastguard Worker }
172*4bdc9457SAndroid Build Coastguard Worker
173*4bdc9457SAndroid Build Coastguard Worker const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
174*4bdc9457SAndroid Build Coastguard Worker if (cpu_frequency != 0) {
175*4bdc9457SAndroid Build Coastguard Worker state.counters["cpufreq"] = cpu_frequency;
176*4bdc9457SAndroid Build Coastguard Worker }
177*4bdc9457SAndroid Build Coastguard Worker
178*4bdc9457SAndroid Build Coastguard Worker state.counters["elements"] =
179*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * batch_size, benchmark::Counter::kIsRate);
180*4bdc9457SAndroid Build Coastguard Worker
181*4bdc9457SAndroid Build Coastguard Worker const size_t bytes_per_iteration = 2 * batch_size * sizeof(float);
182*4bdc9457SAndroid Build Coastguard Worker state.counters["bytes"] =
183*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate);
184*4bdc9457SAndroid Build Coastguard Worker
185*4bdc9457SAndroid Build Coastguard Worker interpreter.reset();
186*4bdc9457SAndroid Build Coastguard Worker }
187*4bdc9457SAndroid Build Coastguard Worker #endif // BENCHMARK_TENSORFLOW_LITE
188*4bdc9457SAndroid Build Coastguard Worker
189*4bdc9457SAndroid Build Coastguard Worker BENCHMARK(xnnpack_square_f32)
190*4bdc9457SAndroid Build Coastguard Worker ->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
191*4bdc9457SAndroid Build Coastguard Worker ->UseRealTime();
192*4bdc9457SAndroid Build Coastguard Worker
193*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_TENSORFLOW_LITE
194*4bdc9457SAndroid Build Coastguard Worker BENCHMARK(tflite_square_f32)
195*4bdc9457SAndroid Build Coastguard Worker ->Apply(benchmark::utils::UnaryElementwiseParameters<float, float>)
196*4bdc9457SAndroid Build Coastguard Worker ->UseRealTime();
197*4bdc9457SAndroid Build Coastguard Worker #endif // BENCHMARK_TENSORFLOW_LITE
198*4bdc9457SAndroid Build Coastguard Worker
199*4bdc9457SAndroid Build Coastguard Worker #ifndef XNNPACK_BENCHMARK_NO_MAIN
200*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_MAIN();
201*4bdc9457SAndroid Build Coastguard Worker #endif
202