// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
10*4bdc9457SAndroid Build Coastguard Worker #include <cfloat>
11*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
12*4bdc9457SAndroid Build Coastguard Worker #include <functional>
13*4bdc9457SAndroid Build Coastguard Worker #include <limits>
14*4bdc9457SAndroid Build Coastguard Worker #include <random>
15*4bdc9457SAndroid Build Coastguard Worker #include <vector>
16*4bdc9457SAndroid Build Coastguard Worker
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
18*4bdc9457SAndroid Build Coastguard Worker
19*4bdc9457SAndroid Build Coastguard Worker #include <benchmark/benchmark.h>
20*4bdc9457SAndroid Build Coastguard Worker #ifdef BENCHMARK_TENSORFLOW_LITE
21*4bdc9457SAndroid Build Coastguard Worker #include "flatbuffers/include/flatbuffers/flatbuffers.h"
22*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter.h"
23*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/register.h"
24*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
25*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/schema/schema_generated.h"
26*4bdc9457SAndroid Build Coastguard Worker #include "tensorflow/lite/version.h"
27*4bdc9457SAndroid Build Coastguard Worker #endif // BENCHMARK_TENSORFLOW_LITE
28*4bdc9457SAndroid Build Coastguard Worker #include "bench/utils.h"
29*4bdc9457SAndroid Build Coastguard Worker
30*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_NO_QU8_OPERATORS
xnnpack_average_pooling_qu8(benchmark::State & state,const char * net)31*4bdc9457SAndroid Build Coastguard Worker static void xnnpack_average_pooling_qu8(benchmark::State& state, const char* net) {
32*4bdc9457SAndroid Build Coastguard Worker const size_t batch_size = state.range(0);
33*4bdc9457SAndroid Build Coastguard Worker const size_t input_height = state.range(1);
34*4bdc9457SAndroid Build Coastguard Worker const size_t input_width = state.range(2);
35*4bdc9457SAndroid Build Coastguard Worker const size_t pooling_size = state.range(3);
36*4bdc9457SAndroid Build Coastguard Worker const size_t padding_size = state.range(4);
37*4bdc9457SAndroid Build Coastguard Worker const size_t stride = state.range(5);
38*4bdc9457SAndroid Build Coastguard Worker const size_t channels = state.range(6);
39*4bdc9457SAndroid Build Coastguard Worker
40*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
41*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
42*4bdc9457SAndroid Build Coastguard Worker auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));
43*4bdc9457SAndroid Build Coastguard Worker
44*4bdc9457SAndroid Build Coastguard Worker const size_t output_height = (2 * padding_size + input_height - pooling_size) / stride + 1;
45*4bdc9457SAndroid Build Coastguard Worker const size_t output_width = (2 * padding_size + input_width - pooling_size) / stride + 1;
46*4bdc9457SAndroid Build Coastguard Worker
47*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(batch_size * input_height * input_width * channels + XNN_EXTRA_BYTES / sizeof(uint8_t));
48*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), std::ref(u8rng));
49*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(batch_size * output_height * output_width * channels);
50*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5);
51*4bdc9457SAndroid Build Coastguard Worker
52*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_initialize(nullptr /* allocator */);
53*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
54*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to initialize XNNPACK");
55*4bdc9457SAndroid Build Coastguard Worker return;
56*4bdc9457SAndroid Build Coastguard Worker }
57*4bdc9457SAndroid Build Coastguard Worker
58*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t pooling_op = nullptr;
59*4bdc9457SAndroid Build Coastguard Worker status = xnn_create_average_pooling2d_nhwc_qu8(
60*4bdc9457SAndroid Build Coastguard Worker padding_size, padding_size, padding_size, padding_size,
61*4bdc9457SAndroid Build Coastguard Worker pooling_size, pooling_size,
62*4bdc9457SAndroid Build Coastguard Worker stride, stride,
63*4bdc9457SAndroid Build Coastguard Worker channels, channels /* input pixel stride */, channels /* output pixel stride */,
64*4bdc9457SAndroid Build Coastguard Worker 127 /* input zero point */, 0.75f /* input scale */,
65*4bdc9457SAndroid Build Coastguard Worker 127 /* output zero point */, 1.25f /* output scale */,
66*4bdc9457SAndroid Build Coastguard Worker 0, 255,
67*4bdc9457SAndroid Build Coastguard Worker 0 /* flags */, &pooling_op);
68*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
69*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create Average Pooling operator");
70*4bdc9457SAndroid Build Coastguard Worker return;
71*4bdc9457SAndroid Build Coastguard Worker }
72*4bdc9457SAndroid Build Coastguard Worker
73*4bdc9457SAndroid Build Coastguard Worker status = xnn_setup_average_pooling2d_nhwc_qu8(
74*4bdc9457SAndroid Build Coastguard Worker pooling_op,
75*4bdc9457SAndroid Build Coastguard Worker batch_size, input_height, input_width,
76*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
77*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */);
78*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
79*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to setup Average Pooling operator");
80*4bdc9457SAndroid Build Coastguard Worker return;
81*4bdc9457SAndroid Build Coastguard Worker }
82*4bdc9457SAndroid Build Coastguard Worker
83*4bdc9457SAndroid Build Coastguard Worker for (auto _ : state) {
84*4bdc9457SAndroid Build Coastguard Worker status = xnn_run_operator(pooling_op, nullptr /* thread pool */);
85*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
86*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to run Average Pooling operator");
87*4bdc9457SAndroid Build Coastguard Worker return;
88*4bdc9457SAndroid Build Coastguard Worker }
89*4bdc9457SAndroid Build Coastguard Worker }
90*4bdc9457SAndroid Build Coastguard Worker
91*4bdc9457SAndroid Build Coastguard Worker status = xnn_delete_operator(pooling_op);
92*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
93*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to delete Average Pooling operator");
94*4bdc9457SAndroid Build Coastguard Worker return;
95*4bdc9457SAndroid Build Coastguard Worker }
96*4bdc9457SAndroid Build Coastguard Worker pooling_op = nullptr;
97*4bdc9457SAndroid Build Coastguard Worker
98*4bdc9457SAndroid Build Coastguard Worker const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
99*4bdc9457SAndroid Build Coastguard Worker if (cpu_frequency != 0) {
100*4bdc9457SAndroid Build Coastguard Worker state.counters["cpufreq"] = cpu_frequency;
101*4bdc9457SAndroid Build Coastguard Worker }
102*4bdc9457SAndroid Build Coastguard Worker
103*4bdc9457SAndroid Build Coastguard Worker state.counters["bytes"] = benchmark::Counter(
104*4bdc9457SAndroid Build Coastguard Worker uint64_t(state.iterations()) *
105*4bdc9457SAndroid Build Coastguard Worker batch_size * (input_height * input_width + output_height * output_width) * channels * sizeof(uint8_t),
106*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter::kIsRate);
107*4bdc9457SAndroid Build Coastguard Worker }
108*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_NO_QU8_OPERATORS
109*4bdc9457SAndroid Build Coastguard Worker
xnnpack_average_pooling_f32(benchmark::State & state,const char * net)110*4bdc9457SAndroid Build Coastguard Worker static void xnnpack_average_pooling_f32(benchmark::State& state, const char* net) {
111*4bdc9457SAndroid Build Coastguard Worker const size_t batch_size = state.range(0);
112*4bdc9457SAndroid Build Coastguard Worker const size_t input_height = state.range(1);
113*4bdc9457SAndroid Build Coastguard Worker const size_t input_width = state.range(2);
114*4bdc9457SAndroid Build Coastguard Worker const size_t pooling_size = state.range(3);
115*4bdc9457SAndroid Build Coastguard Worker const size_t padding_size = state.range(4);
116*4bdc9457SAndroid Build Coastguard Worker const size_t stride = state.range(5);
117*4bdc9457SAndroid Build Coastguard Worker const size_t channels = state.range(6);
118*4bdc9457SAndroid Build Coastguard Worker
119*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
120*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
121*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
122*4bdc9457SAndroid Build Coastguard Worker
123*4bdc9457SAndroid Build Coastguard Worker const size_t output_height = (2 * padding_size + input_height - pooling_size) / stride + 1;
124*4bdc9457SAndroid Build Coastguard Worker const size_t output_width = (2 * padding_size + input_width - pooling_size) / stride + 1;
125*4bdc9457SAndroid Build Coastguard Worker
126*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(batch_size * input_height * input_width * channels + XNN_EXTRA_BYTES / sizeof(float));
127*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), std::ref(f32rng));
128*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(batch_size * output_height * output_width * channels);
129*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf(""));
130*4bdc9457SAndroid Build Coastguard Worker
131*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_initialize(nullptr /* allocator */);
132*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
133*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to initialize XNNPACK");
134*4bdc9457SAndroid Build Coastguard Worker return;
135*4bdc9457SAndroid Build Coastguard Worker }
136*4bdc9457SAndroid Build Coastguard Worker
137*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t pooling_op = nullptr;
138*4bdc9457SAndroid Build Coastguard Worker status = xnn_create_average_pooling2d_nhwc_f32(
139*4bdc9457SAndroid Build Coastguard Worker padding_size, padding_size, padding_size, padding_size,
140*4bdc9457SAndroid Build Coastguard Worker pooling_size, pooling_size,
141*4bdc9457SAndroid Build Coastguard Worker stride, stride,
142*4bdc9457SAndroid Build Coastguard Worker channels, channels /* input pixel stride */, channels /* output pixel stride */,
143*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(),
144*4bdc9457SAndroid Build Coastguard Worker 0 /* flags */, &pooling_op);
145*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
146*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to create Average Pooling operator");
147*4bdc9457SAndroid Build Coastguard Worker return;
148*4bdc9457SAndroid Build Coastguard Worker }
149*4bdc9457SAndroid Build Coastguard Worker
150*4bdc9457SAndroid Build Coastguard Worker status = xnn_setup_average_pooling2d_nhwc_f32(
151*4bdc9457SAndroid Build Coastguard Worker pooling_op,
152*4bdc9457SAndroid Build Coastguard Worker batch_size, input_height, input_width,
153*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
154*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */);
155*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
156*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to setup Average Pooling operator");
157*4bdc9457SAndroid Build Coastguard Worker return;
158*4bdc9457SAndroid Build Coastguard Worker }
159*4bdc9457SAndroid Build Coastguard Worker
160*4bdc9457SAndroid Build Coastguard Worker for (auto _ : state) {
161*4bdc9457SAndroid Build Coastguard Worker status = xnn_run_operator(pooling_op, nullptr /* thread pool */);
162*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
163*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to run Average Pooling operator");
164*4bdc9457SAndroid Build Coastguard Worker return;
165*4bdc9457SAndroid Build Coastguard Worker }
166*4bdc9457SAndroid Build Coastguard Worker }
167*4bdc9457SAndroid Build Coastguard Worker
168*4bdc9457SAndroid Build Coastguard Worker status = xnn_delete_operator(pooling_op);
169*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
170*4bdc9457SAndroid Build Coastguard Worker state.SkipWithError("failed to delete Average Pooling operator");
171*4bdc9457SAndroid Build Coastguard Worker return;
172*4bdc9457SAndroid Build Coastguard Worker }
173*4bdc9457SAndroid Build Coastguard Worker pooling_op = nullptr;
174*4bdc9457SAndroid Build Coastguard Worker
175*4bdc9457SAndroid Build Coastguard Worker const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
176*4bdc9457SAndroid Build Coastguard Worker if (cpu_frequency != 0) {
177*4bdc9457SAndroid Build Coastguard Worker state.counters["cpufreq"] = cpu_frequency;
178*4bdc9457SAndroid Build Coastguard Worker }
179*4bdc9457SAndroid Build Coastguard Worker
180*4bdc9457SAndroid Build Coastguard Worker state.counters["bytes"] = benchmark::Counter(
181*4bdc9457SAndroid Build Coastguard Worker uint64_t(state.iterations()) *
182*4bdc9457SAndroid Build Coastguard Worker batch_size * (input_height * input_width + output_height * output_width) * channels * sizeof(float),
183*4bdc9457SAndroid Build Coastguard Worker benchmark::Counter::kIsRate);
184*4bdc9457SAndroid Build Coastguard Worker }
185*4bdc9457SAndroid Build Coastguard Worker
#ifdef BENCHMARK_TENSORFLOW_LITE
// Benchmarks TensorFlow Lite's builtin AVERAGE_POOL_2D (FP32, one interpreter
// thread) on the same argument layout as the XNNPACK benchmarks.
//
// A single-operator model is built in memory with flatbuffers, loaded into an
// interpreter, and Invoke()d once per benchmark iteration.
//
// TFLite cannot express arbitrary padding, so only two cases are accepted:
// P == 0 maps to VALID and 2 * P == K - 1 maps to SAME; anything else is
// skipped with "unsupported padding".
//
// Benchmark arguments (state.range):
//   0 = batch size N, 1 = input height H, 2 = input width W,
//   3 = square pooling window size K, 4 = symmetric padding P,
//   5 = stride S (both dimensions), 6 = channels C.
// `net` is the label supplied by BENCHMARK_CAPTURE; it is not used here.
void tflite_average_pooling_f32(benchmark::State& state, const char* net) {
  const size_t batch_size = state.range(0);
  const size_t input_height = state.range(1);
  const size_t input_width = state.range(2);
  const size_t pooling_size = state.range(3);
  const size_t padding_size = state.range(4);
  const size_t stride = state.range(5);
  const size_t channels = state.range(6);

  std::random_device random_device;
  auto rng = std::mt19937(random_device());
  auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));

  tflite::Padding padding = tflite::Padding_VALID;
  if (2 * padding_size == (pooling_size - 1)) {
    padding = tflite::Padding_SAME;
  } else if (padding_size == 0) {
    padding = tflite::Padding_VALID;
  } else {
    state.SkipWithError("unsupported padding");
    return;
  }

  // The output-size formula below uses unsigned arithmetic: a zero stride would
  // divide by zero and a window larger than the padded input would wrap around,
  // so reject both up front.
  if (stride == 0 ||
      pooling_size > 2 * padding_size + input_height ||
      pooling_size > 2 * padding_size + input_width) {
    state.SkipWithError("invalid pooling parameters");
    return;
  }

  // Output shape declared in the model. It is also used for the "bytes" counter.
  const size_t output_height = (2 * padding_size + input_height - pooling_size) / stride + 1;
  const size_t output_width = (2 * padding_size + input_width - pooling_size) / stride + 1;

  // No host-side input/output buffers are needed: the interpreter owns the
  // tensor storage, and the input tensor is filled in place below.

  flatbuffers::FlatBufferBuilder builder;
  flatbuffers::Offset<tflite::OperatorCode> operator_code =
      CreateOperatorCode(builder, tflite::BuiltinOperator_AVERAGE_POOL_2D);

  flatbuffers::Offset<tflite::Pool2DOptions> pool2d_options = CreatePool2DOptions(
      builder, padding,
      stride /* stride_w */, stride /* stride_h */,
      pooling_size /* filter_width */, pooling_size /* filter_height */,
      tflite::ActivationFunctionType_NONE);

  // A single empty buffer. Neither tensor below references constant data;
  // presumably this is the sentinel buffer 0 the TFLite schema expects.
  flatbuffers::Offset<tflite::Buffer> buffers[1] = {
      tflite::CreateBuffer(builder, builder.CreateVector({})),
  };

  const int32_t input_shape[4] = {
      static_cast<int32_t>(batch_size),
      static_cast<int32_t>(input_height),
      static_cast<int32_t>(input_width),
      static_cast<int32_t>(channels)
  };
  const int32_t output_shape[4] = {
      static_cast<int32_t>(batch_size),
      static_cast<int32_t>(output_height),
      static_cast<int32_t>(output_width),
      static_cast<int32_t>(channels)
  };

  // Tensor 0 is the NHWC input; tensor 1 is the NHWC output.
  flatbuffers::Offset<tflite::Tensor> tensors[2] = {
      tflite::CreateTensor(builder,
                           builder.CreateVector<int32_t>(input_shape, 4),
                           tflite::TensorType_FLOAT32),
      tflite::CreateTensor(builder,
                           builder.CreateVector<int32_t>(output_shape, 4),
                           tflite::TensorType_FLOAT32),
  };

  const int32_t op_inputs[1] = { 0 };
  const int32_t op_outputs[1] = { 1 };
  flatbuffers::Offset<tflite::Operator> op = CreateOperator(
      builder,
      0 /* opcode_index */,
      builder.CreateVector<int32_t>(op_inputs, 1),
      builder.CreateVector<int32_t>(op_outputs, 1),
      tflite::BuiltinOptions_Pool2DOptions,
      pool2d_options.Union());

  const int32_t graph_inputs[1] = { 0 };
  const int32_t graph_outputs[1] = { 1 };
  flatbuffers::Offset<tflite::SubGraph> subgraph = CreateSubGraph(
      builder,
      builder.CreateVector(tensors, 2),
      builder.CreateVector<int32_t>(graph_inputs, 1),
      builder.CreateVector<int32_t>(graph_outputs, 1),
      builder.CreateVector(&op, 1));

  flatbuffers::Offset<tflite::Model> model_buffer = tflite::CreateModel(builder,
      TFLITE_SCHEMA_VERSION,
      builder.CreateVector(&operator_code, 1),
      builder.CreateVector(&subgraph, 1),
      builder.CreateString("AVERAGE_POOL_2D model"),
      builder.CreateVector(buffers, 1));

  builder.Finish(model_buffer);

  const tflite::Model* model = tflite::GetModel(builder.GetBufferPointer());
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  tflite::InterpreterBuilder interpreterBuilder(model, resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (interpreterBuilder(&interpreter) != kTfLiteOk) {
    state.SkipWithError("failed to create TFLite interpreter");
    return;
  }
  if (interpreter == nullptr) {
    state.SkipWithError("TFLite interpreter is null");
    return;
  }
  interpreter->SetNumThreads(1);

  if (interpreter->AllocateTensors() != kTfLiteOk) {
    state.SkipWithError("failed to allocate tensors");
    return;
  }

  // Fill the interpreter-owned input tensor (tensor 0) with random data.
  float* input_data = interpreter->typed_tensor<float>(0);
  std::generate(
      input_data,
      input_data + batch_size * input_height * input_width * channels,
      std::ref(f32rng));

  for (auto _ : state) {
    if (interpreter->Invoke() != kTfLiteOk) {
      state.SkipWithError("failed to invoke TFLite interpreter");
      return;
    }
  }

  const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency();
  if (cpu_frequency != 0) {
    state.counters["cpufreq"] = cpu_frequency;
  }

  state.counters["bytes"] = benchmark::Counter(
      uint64_t(state.iterations()) *
          batch_size * (input_height * input_width + output_height * output_width) * channels * sizeof(float),
      benchmark::Counter::kIsRate);
}
#endif  // BENCHMARK_TENSORFLOW_LITE
324*4bdc9457SAndroid Build Coastguard Worker
325*4bdc9457SAndroid Build Coastguard Worker // Final global average pooling in ImageNet classification models.
ImageNet(benchmark::internal::Benchmark * b)326*4bdc9457SAndroid Build Coastguard Worker static void ImageNet(benchmark::internal::Benchmark* b) {
327*4bdc9457SAndroid Build Coastguard Worker b->ArgNames({"N", "H", "W", "K", "P", "S", "C"});
328*4bdc9457SAndroid Build Coastguard Worker
329*4bdc9457SAndroid Build Coastguard Worker /* N H W K P S C */
330*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 13, 13, 13, 0, 1, 1000});
331*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 7, 7, 7, 0, 1, 1000});
332*4bdc9457SAndroid Build Coastguard Worker }
333*4bdc9457SAndroid Build Coastguard Worker
334*4bdc9457SAndroid Build Coastguard Worker // ShuffleNet v1 with 1 group.
ShuffleNetV1G1(benchmark::internal::Benchmark * b)335*4bdc9457SAndroid Build Coastguard Worker static void ShuffleNetV1G1(benchmark::internal::Benchmark* b) {
336*4bdc9457SAndroid Build Coastguard Worker b->ArgNames({"N", "H", "W", "K", "P", "S", "C"});
337*4bdc9457SAndroid Build Coastguard Worker
338*4bdc9457SAndroid Build Coastguard Worker /* N H W K P S C */
339*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 56, 56, 3, 1, 2, 24});
340*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 28, 28, 3, 1, 2, 144});
341*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 14, 14, 3, 1, 2, 288});
342*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 7, 7, 3, 1, 2, 576});
343*4bdc9457SAndroid Build Coastguard Worker }
344*4bdc9457SAndroid Build Coastguard Worker
345*4bdc9457SAndroid Build Coastguard Worker // ShuffleNet v1 with 2 groups.
ShuffleNetV1G2(benchmark::internal::Benchmark * b)346*4bdc9457SAndroid Build Coastguard Worker static void ShuffleNetV1G2(benchmark::internal::Benchmark* b) {
347*4bdc9457SAndroid Build Coastguard Worker b->ArgNames({"N", "H", "W", "K", "P", "S", "C"});
348*4bdc9457SAndroid Build Coastguard Worker
349*4bdc9457SAndroid Build Coastguard Worker /* N H W K P S C */
350*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 56, 56, 3, 1, 2, 24});
351*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 28, 28, 3, 1, 2, 200});
352*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 14, 14, 3, 1, 2, 400});
353*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 7, 7, 3, 1, 2, 800});
354*4bdc9457SAndroid Build Coastguard Worker }
355*4bdc9457SAndroid Build Coastguard Worker
356*4bdc9457SAndroid Build Coastguard Worker // ShuffleNet v1 with 3 groups.
ShuffleNetV1G3(benchmark::internal::Benchmark * b)357*4bdc9457SAndroid Build Coastguard Worker static void ShuffleNetV1G3(benchmark::internal::Benchmark* b) {
358*4bdc9457SAndroid Build Coastguard Worker b->ArgNames({"N", "H", "W", "K", "P", "S", "C"});
359*4bdc9457SAndroid Build Coastguard Worker
360*4bdc9457SAndroid Build Coastguard Worker /* N H W K P S C */
361*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 56, 56, 3, 1, 2, 24});
362*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 28, 28, 3, 1, 2, 240});
363*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 14, 14, 3, 1, 2, 480});
364*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 7, 7, 3, 1, 2, 960});
365*4bdc9457SAndroid Build Coastguard Worker }
366*4bdc9457SAndroid Build Coastguard Worker
367*4bdc9457SAndroid Build Coastguard Worker // ShuffleNet v1 with 4 groups.
ShuffleNetV1G4(benchmark::internal::Benchmark * b)368*4bdc9457SAndroid Build Coastguard Worker static void ShuffleNetV1G4(benchmark::internal::Benchmark* b) {
369*4bdc9457SAndroid Build Coastguard Worker b->ArgNames({"N", "H", "W", "K", "P", "S", "C"});
370*4bdc9457SAndroid Build Coastguard Worker
371*4bdc9457SAndroid Build Coastguard Worker /* N H W K P S C */
372*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 56, 56, 3, 1, 2, 24});
373*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 28, 28, 3, 1, 2, 272});
374*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 14, 14, 3, 1, 2, 576});
375*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 7, 7, 3, 1, 2, 1088});
376*4bdc9457SAndroid Build Coastguard Worker }
377*4bdc9457SAndroid Build Coastguard Worker
378*4bdc9457SAndroid Build Coastguard Worker // ShuffleNet v1 with 8 groups.
ShuffleNetV1G8(benchmark::internal::Benchmark * b)379*4bdc9457SAndroid Build Coastguard Worker static void ShuffleNetV1G8(benchmark::internal::Benchmark* b) {
380*4bdc9457SAndroid Build Coastguard Worker b->ArgNames({"N", "H", "W", "K", "P", "S", "C"});
381*4bdc9457SAndroid Build Coastguard Worker
382*4bdc9457SAndroid Build Coastguard Worker /* N H W K P S C */
383*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 56, 56, 3, 1, 2, 24});
384*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 28, 28, 3, 1, 2, 384});
385*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 14, 14, 3, 1, 2, 768});
386*4bdc9457SAndroid Build Coastguard Worker b->Args({1, 7, 7, 3, 1, 2, 1536});
387*4bdc9457SAndroid Build Coastguard Worker }
388*4bdc9457SAndroid Build Coastguard Worker
389*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_f32, imagenet, "ImageNet")->Apply(ImageNet)->UseRealTime();
390*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_f32, shufflenet_v1_g1, "ShuffleNet v1 (1 group)")->Apply(ShuffleNetV1G1)->UseRealTime();
391*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_f32, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)")->Apply(ShuffleNetV1G2)->UseRealTime();
392*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_f32, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)")->Apply(ShuffleNetV1G3)->UseRealTime();
393*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_f32, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)")->Apply(ShuffleNetV1G4)->UseRealTime();
394*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_f32, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)")->Apply(ShuffleNetV1G8)->UseRealTime();
395*4bdc9457SAndroid Build Coastguard Worker
// TensorFlow Lite FP32 average pooling on the same shape sets, for comparison.
#if defined(BENCHMARK_TENSORFLOW_LITE)
BENCHMARK_CAPTURE(tflite_average_pooling_f32, imagenet, "ImageNet")
    ->Apply(ImageNet)
    ->UseRealTime();
BENCHMARK_CAPTURE(tflite_average_pooling_f32, shufflenet_v1_g1, "ShuffleNet v1 (1 group)")
    ->Apply(ShuffleNetV1G1)
    ->UseRealTime();
BENCHMARK_CAPTURE(tflite_average_pooling_f32, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)")
    ->Apply(ShuffleNetV1G2)
    ->UseRealTime();
BENCHMARK_CAPTURE(tflite_average_pooling_f32, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)")
    ->Apply(ShuffleNetV1G3)
    ->UseRealTime();
BENCHMARK_CAPTURE(tflite_average_pooling_f32, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)")
    ->Apply(ShuffleNetV1G4)
    ->UseRealTime();
BENCHMARK_CAPTURE(tflite_average_pooling_f32, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)")
    ->Apply(ShuffleNetV1G8)
    ->UseRealTime();
#endif  // defined(BENCHMARK_TENSORFLOW_LITE)
404*4bdc9457SAndroid Build Coastguard Worker
405*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_NO_QU8_OPERATORS
406*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_qu8, imagenet, "ImageNet")->Apply(ImageNet)->UseRealTime();
407*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_qu8, shufflenet_v1_g1, "ShuffleNet v1 (1 group)")->Apply(ShuffleNetV1G1)->UseRealTime();
408*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_qu8, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)")->Apply(ShuffleNetV1G2)->UseRealTime();
409*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_qu8, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)")->Apply(ShuffleNetV1G3)->UseRealTime();
410*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_qu8, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)")->Apply(ShuffleNetV1G4)->UseRealTime();
411*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_CAPTURE(xnnpack_average_pooling_qu8, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)")->Apply(ShuffleNetV1G8)->UseRealTime();
412*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_NO_QU8_OPERATORS
413*4bdc9457SAndroid Build Coastguard Worker
414*4bdc9457SAndroid Build Coastguard Worker #ifndef XNNPACK_BENCHMARK_NO_MAIN
415*4bdc9457SAndroid Build Coastguard Worker BENCHMARK_MAIN();
416*4bdc9457SAndroid Build Coastguard Worker #endif
417