xref: /aosp_15_r20/external/armnn/tests/TfLiteBenchmark-Armnn/TfLiteBenchmark-Armnn.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 STMicroelectronics and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <getopt.h>
#include <iostream>
#include <numeric>
#include <signal.h>
#include <stdexcept>
#include <string>
#include <sys/time.h>
#include <vector>

#include <armnn/BackendId.hpp>
#include <armnn/BackendRegistry.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnnTfLiteParser/ITfLiteParser.hpp>
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker // Application parameters
21*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> default_preferred_backends_order = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
22*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> preferred_backends_order;
23*89c4ff92SAndroid Build Coastguard Worker std::string model_file_str;
24*89c4ff92SAndroid Build Coastguard Worker std::string preferred_backend_str;
25*89c4ff92SAndroid Build Coastguard Worker int nb_loops = 1;
26*89c4ff92SAndroid Build Coastguard Worker 
// Converts a struct timeval into a single microsecond count, expressed as a double.
double get_us(struct timeval t)
{
    const double seconds = static_cast<double>(t.tv_sec);
    const double microseconds = static_cast<double>(t.tv_usec);
    return seconds * 1000000.0 + microseconds;
}
33*89c4ff92SAndroid Build Coastguard Worker 
// Converts a struct timeval into milliseconds, expressed as a double; the
// microsecond part is kept as the fractional component.
double get_ms(struct timeval t)
{
    const double seconds = static_cast<double>(t.tv_sec);
    const double microseconds = static_cast<double>(t.tv_usec);
    return seconds * 1000.0 + microseconds / 1000.0;
}
40*89c4ff92SAndroid Build Coastguard Worker 
print_help(char ** argv)41*89c4ff92SAndroid Build Coastguard Worker static void print_help(char** argv)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker     std::cout <<
44*89c4ff92SAndroid Build Coastguard Worker         "Usage: " << argv[0] << " -m <model .tflite>\n"
45*89c4ff92SAndroid Build Coastguard Worker         "\n"
46*89c4ff92SAndroid Build Coastguard Worker         "-m --model_file <.tflite file path>:  .tflite model to be executed\n"
47*89c4ff92SAndroid Build Coastguard Worker         "-b --backend <device>:                preferred backend device to run layers on by default. Possible choices: "
48*89c4ff92SAndroid Build Coastguard Worker                                                << armnn::BackendRegistryInstance().GetBackendIdsAsString() << "\n"
49*89c4ff92SAndroid Build Coastguard Worker         "                                      (by default CpuAcc, CpuRef)\n"
50*89c4ff92SAndroid Build Coastguard Worker         "-l --loops <int>:                     provide the number of times the inference will be executed\n"
51*89c4ff92SAndroid Build Coastguard Worker         "                                      (by default nb_loops=1)\n"
52*89c4ff92SAndroid Build Coastguard Worker         "--help:                               show this help\n";
53*89c4ff92SAndroid Build Coastguard Worker     exit(1);
54*89c4ff92SAndroid Build Coastguard Worker }
55*89c4ff92SAndroid Build Coastguard Worker 
process_args(int argc,char ** argv)56*89c4ff92SAndroid Build Coastguard Worker void process_args(int argc, char** argv)
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker     const char* const short_opts = "m:b:l:h";
59*89c4ff92SAndroid Build Coastguard Worker     const option long_opts[] = {
60*89c4ff92SAndroid Build Coastguard Worker         {"model_file",   required_argument, nullptr, 'm'},
61*89c4ff92SAndroid Build Coastguard Worker         {"backend",      required_argument, nullptr, 'b'},
62*89c4ff92SAndroid Build Coastguard Worker         {"loops",        required_argument, nullptr, 'l'},
63*89c4ff92SAndroid Build Coastguard Worker         {"help",         no_argument,       nullptr, 'h'},
64*89c4ff92SAndroid Build Coastguard Worker         {nullptr,        no_argument,       nullptr, 0}
65*89c4ff92SAndroid Build Coastguard Worker     };
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     while (true)
68*89c4ff92SAndroid Build Coastguard Worker     {
69*89c4ff92SAndroid Build Coastguard Worker         const auto opt = getopt_long(argc, argv, short_opts, long_opts, nullptr);
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker         if (-1 == opt)
72*89c4ff92SAndroid Build Coastguard Worker         {
73*89c4ff92SAndroid Build Coastguard Worker             break;
74*89c4ff92SAndroid Build Coastguard Worker         }
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker         switch (opt)
77*89c4ff92SAndroid Build Coastguard Worker         {
78*89c4ff92SAndroid Build Coastguard Worker         case 'm':
79*89c4ff92SAndroid Build Coastguard Worker             model_file_str = std::string(optarg);
80*89c4ff92SAndroid Build Coastguard Worker             std::cout << "model file set to: " << model_file_str << std::endl;
81*89c4ff92SAndroid Build Coastguard Worker             break;
82*89c4ff92SAndroid Build Coastguard Worker         case 'b':
83*89c4ff92SAndroid Build Coastguard Worker             preferred_backend_str = std::string(optarg);
84*89c4ff92SAndroid Build Coastguard Worker             // Overwrite the backend
85*89c4ff92SAndroid Build Coastguard Worker             preferred_backends_order.push_back(preferred_backend_str);
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker             std::cout << "backend device set to:" << preferred_backend_str << std::endl;;
88*89c4ff92SAndroid Build Coastguard Worker             break;
89*89c4ff92SAndroid Build Coastguard Worker         case 'l':
90*89c4ff92SAndroid Build Coastguard Worker             nb_loops = std::stoi(optarg);
91*89c4ff92SAndroid Build Coastguard Worker             std::cout << "benchmark will execute " << nb_loops << " inference(s)" << std::endl;
92*89c4ff92SAndroid Build Coastguard Worker             break;
93*89c4ff92SAndroid Build Coastguard Worker         case 'h': // -h or --help
94*89c4ff92SAndroid Build Coastguard Worker         case '?': // Unrecognized option
95*89c4ff92SAndroid Build Coastguard Worker         default:
96*89c4ff92SAndroid Build Coastguard Worker             print_help(argv);
97*89c4ff92SAndroid Build Coastguard Worker             break;
98*89c4ff92SAndroid Build Coastguard Worker         }
99*89c4ff92SAndroid Build Coastguard Worker     }
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker     if (model_file_str.empty())
102*89c4ff92SAndroid Build Coastguard Worker     {
103*89c4ff92SAndroid Build Coastguard Worker         print_help(argv);
104*89c4ff92SAndroid Build Coastguard Worker     }
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker 
main(int argc,char * argv[])107*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker     std::vector<double> inferenceTimes;
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     // Get options
112*89c4ff92SAndroid Build Coastguard Worker     process_args(argc, argv);
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     // Create the runtime
115*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime::CreationOptions options;
116*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker     // Create Parser
119*89c4ff92SAndroid Build Coastguard Worker     armnnTfLiteParser::ITfLiteParserPtr armnnparser(armnnTfLiteParser::ITfLiteParser::Create());
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     // Create a network
122*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnnparser->CreateNetworkFromBinaryFile(model_file_str.c_str());
123*89c4ff92SAndroid Build Coastguard Worker     if (!network)
124*89c4ff92SAndroid Build Coastguard Worker     {
125*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("Failed to create an ArmNN network");
126*89c4ff92SAndroid Build Coastguard Worker     }
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker     // Optimize the network
129*89c4ff92SAndroid Build Coastguard Worker     if (preferred_backends_order.size() == 0)
130*89c4ff92SAndroid Build Coastguard Worker     {
131*89c4ff92SAndroid Build Coastguard Worker         preferred_backends_order = default_preferred_backends_order;
132*89c4ff92SAndroid Build Coastguard Worker     }
133*89c4ff92SAndroid Build Coastguard Worker     armnn::IOptimizedNetworkPtr optimizedNet = armnn::Optimize(*network,
134*89c4ff92SAndroid Build Coastguard Worker                                                                preferred_backends_order,
135*89c4ff92SAndroid Build Coastguard Worker                                                                runtime->GetDeviceSpec());
136*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId networkId;
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker     // Load the network in to the runtime
139*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(networkId, std::move(optimizedNet));
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker     // Check the number of subgraph
142*89c4ff92SAndroid Build Coastguard Worker     if (armnnparser->GetSubgraphCount() != 1)
143*89c4ff92SAndroid Build Coastguard Worker     {
144*89c4ff92SAndroid Build Coastguard Worker         std::cout << "Model with more than 1 subgraph is not supported by this benchmark application.\n";
145*89c4ff92SAndroid Build Coastguard Worker         exit(0);
146*89c4ff92SAndroid Build Coastguard Worker     }
147*89c4ff92SAndroid Build Coastguard Worker     size_t subgraphId = 0;
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker     // Set up the input network
150*89c4ff92SAndroid Build Coastguard Worker     std::cout << "\nModel information:" << std::endl;
151*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnnTfLiteParser::BindingPointInfo> inputBindings;
152*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo>                   inputTensorInfos;
153*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> inputTensorNames = armnnparser->GetSubgraphInputTensorNames(subgraphId);
154*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < inputTensorNames.size() ; i++)
155*89c4ff92SAndroid Build Coastguard Worker     {
156*89c4ff92SAndroid Build Coastguard Worker         std::cout << "inputTensorNames[" << i << "] = " << inputTensorNames[i] << std::endl;
157*89c4ff92SAndroid Build Coastguard Worker         armnnTfLiteParser::BindingPointInfo inputBinding = armnnparser->GetNetworkInputBindingInfo(
158*89c4ff92SAndroid Build Coastguard Worker                                                                            subgraphId,
159*89c4ff92SAndroid Build Coastguard Worker                                                                            inputTensorNames[i]);
160*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo inputTensorInfo = runtime->GetInputTensorInfo(networkId, inputBinding.first);
161*89c4ff92SAndroid Build Coastguard Worker         inputBindings.push_back(inputBinding);
162*89c4ff92SAndroid Build Coastguard Worker         inputTensorInfos.push_back(inputTensorInfo);
163*89c4ff92SAndroid Build Coastguard Worker     }
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker     // Set up the output network
166*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnnTfLiteParser::BindingPointInfo> outputBindings;
167*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo>                   outputTensorInfos;
168*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> outputTensorNames = armnnparser->GetSubgraphOutputTensorNames(subgraphId);
169*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < outputTensorNames.size() ; i++)
170*89c4ff92SAndroid Build Coastguard Worker     {
171*89c4ff92SAndroid Build Coastguard Worker         std::cout << "outputTensorNames[" << i << "] = " << outputTensorNames[i] << std::endl;
172*89c4ff92SAndroid Build Coastguard Worker         armnnTfLiteParser::BindingPointInfo outputBinding = armnnparser->GetNetworkOutputBindingInfo(
173*89c4ff92SAndroid Build Coastguard Worker                                                                              subgraphId,
174*89c4ff92SAndroid Build Coastguard Worker                                                                              outputTensorNames[i]);
175*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo outputTensorInfo = runtime->GetOutputTensorInfo(networkId, outputBinding.first);
176*89c4ff92SAndroid Build Coastguard Worker         outputBindings.push_back(outputBinding);
177*89c4ff92SAndroid Build Coastguard Worker         outputTensorInfos.push_back(outputTensorInfo);
178*89c4ff92SAndroid Build Coastguard Worker     }
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker     // Allocate input tensors
181*89c4ff92SAndroid Build Coastguard Worker     unsigned int nb_inputs = armnn::numeric_cast<unsigned int>(inputTensorInfos.size());
182*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors;
183*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<float>> in;
184*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0 ; i < nb_inputs ; i++)
185*89c4ff92SAndroid Build Coastguard Worker     {
186*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> in_data(inputTensorInfos.at(i).GetNumElements());
187*89c4ff92SAndroid Build Coastguard Worker         in.push_back(in_data);
188*89c4ff92SAndroid Build Coastguard Worker         inputTensors.push_back({ inputBindings[i].first, armnn::ConstTensor(inputBindings[i].second, in[i].data()) });
189*89c4ff92SAndroid Build Coastguard Worker     }
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker     // Allocate output tensors
192*89c4ff92SAndroid Build Coastguard Worker     unsigned int nb_ouputs = armnn::numeric_cast<unsigned int>(outputTensorInfos.size());
193*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors;
194*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<float>> out;
195*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < nb_ouputs ; i++)
196*89c4ff92SAndroid Build Coastguard Worker     {
197*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> out_data(outputTensorInfos.at(i).GetNumElements());
198*89c4ff92SAndroid Build Coastguard Worker         out.push_back(out_data);
199*89c4ff92SAndroid Build Coastguard Worker         outputTensors.push_back({ outputBindings[i].first, armnn::Tensor(outputBindings[i].second, out[i].data()) });
200*89c4ff92SAndroid Build Coastguard Worker     }
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker     // Run the inferences
203*89c4ff92SAndroid Build Coastguard Worker     std::cout << "\ninferences are running: " << std::flush;
204*89c4ff92SAndroid Build Coastguard Worker     for (int i = 0 ; i < nb_loops ; i++)
205*89c4ff92SAndroid Build Coastguard Worker     {
206*89c4ff92SAndroid Build Coastguard Worker         struct timeval start_time, stop_time;
207*89c4ff92SAndroid Build Coastguard Worker         gettimeofday(&start_time, nullptr);
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker         runtime->EnqueueWorkload(networkId, inputTensors, outputTensors);
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker         gettimeofday(&stop_time, nullptr);
212*89c4ff92SAndroid Build Coastguard Worker         inferenceTimes.push_back((get_us(stop_time) - get_us(start_time)));
213*89c4ff92SAndroid Build Coastguard Worker         std::cout << "# " << std::flush;
214*89c4ff92SAndroid Build Coastguard Worker     }
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker     auto maxInfTime = *std::max_element(inferenceTimes.begin(), inferenceTimes.end());
217*89c4ff92SAndroid Build Coastguard Worker     auto minInfTime = *std::min_element(inferenceTimes.begin(), inferenceTimes.end());
218*89c4ff92SAndroid Build Coastguard Worker     auto avgInfTime = accumulate(inferenceTimes.begin(), inferenceTimes.end(), 0.0) /
219*89c4ff92SAndroid Build Coastguard Worker             armnn::numeric_cast<double>(inferenceTimes.size());
220*89c4ff92SAndroid Build Coastguard Worker     std::cout << "\n\ninference time: ";
221*89c4ff92SAndroid Build Coastguard Worker     std::cout << "min=" << minInfTime << "us  ";
222*89c4ff92SAndroid Build Coastguard Worker     std::cout << "max=" << maxInfTime << "us  ";
223*89c4ff92SAndroid Build Coastguard Worker     std::cout << "avg=" << avgInfTime << "us" << std::endl;
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker     return 0;
226*89c4ff92SAndroid Build Coastguard Worker }
227