1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnTfLiteParser/ITfLiteParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include "NMS.hpp"
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <stb/stb_image.h>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker #include <cxxopts/cxxopts.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <ghc/filesystem.hpp>
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker #include <chrono>
21*89c4ff92SAndroid Build Coastguard Worker #include <fstream>
22*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
23*89c4ff92SAndroid Build Coastguard Worker #include <iterator>
24*89c4ff92SAndroid Build Coastguard Worker #include <cmath>
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker using namespace armnnTfLiteParser;
27*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
28*89c4ff92SAndroid Build Coastguard Worker
// Non-zero status codes returned by the helpers in this program; 0 always means success.
static constexpr int OPEN_FILE_ERROR        = -2;   // A model file could not be opened.
static constexpr int OPTIMIZE_NETWORK_ERROR = -3;   // armnn::Optimize returned no network.
static constexpr int LOAD_NETWORK_ERROR     = -4;   // IRuntime::LoadNetwork did not succeed.
static constexpr int LOAD_IMAGE_ERROR       = -5;   // The input image could not be loaded.
static constexpr int GENERAL_ERROR          = -100; // An armnn::Exception escaped (see CHECK_OK).
34*89c4ff92SAndroid Build Coastguard Worker
/// Evaluates the status expression `v` and, if it is non-zero, returns that status from the
/// enclosing function. If evaluating `v` throws an armnn::Exception, the exception is logged
/// and GENERAL_ERROR is returned instead. Other exception types propagate unchanged.
/// Intended for use inside functions that return an int-like status code.
#define CHECK_OK(v)                                   \
    do                                                \
    {                                                 \
        try                                           \
        {                                             \
            auto r_local = (v);                       \
            if (r_local != 0)                         \
            {                                         \
                return r_local;                       \
            }                                         \
        }                                             \
        catch (const armnn::Exception& e)             \
        {                                             \
            ARMNN_LOG(error) << "Oops: " << e.what(); \
            return GENERAL_ERROR;                     \
        }                                             \
    } while (0)
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer>
MakeInputTensors(const std::vector<armnn::BindingPointInfo> & inputBindings,const std::vector<std::reference_wrapper<TContainer>> & inputDataContainers)51*89c4ff92SAndroid Build Coastguard Worker inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
52*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::reference_wrapper<TContainer>>& inputDataContainers)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors inputTensors;
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker const size_t numInputs = inputBindings.size();
57*89c4ff92SAndroid Build Coastguard Worker if (numInputs != inputDataContainers.size())
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Mismatching vectors");
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker
62*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numInputs; i++)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& inputBinding = inputBindings[i];
65*89c4ff92SAndroid Build Coastguard Worker const TContainer& inputData = inputDataContainers[i].get();
66*89c4ff92SAndroid Build Coastguard Worker
67*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor inputTensor(inputBinding.second, inputData.data());
68*89c4ff92SAndroid Build Coastguard Worker inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
69*89c4ff92SAndroid Build Coastguard Worker }
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker return inputTensors;
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker
74*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer>
MakeOutputTensors(const std::vector<armnn::BindingPointInfo> & outputBindings,const std::vector<std::reference_wrapper<TContainer>> & outputDataContainers)75*89c4ff92SAndroid Build Coastguard Worker inline armnn::OutputTensors MakeOutputTensors(
76*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BindingPointInfo>& outputBindings,
77*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::reference_wrapper<TContainer>>& outputDataContainers)
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors outputTensors;
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker const size_t numOutputs = outputBindings.size();
82*89c4ff92SAndroid Build Coastguard Worker if (numOutputs != outputDataContainers.size())
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Mismatching vectors");
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker outputTensors.reserve(numOutputs);
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < numOutputs; i++)
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& outputBinding = outputBindings[i];
92*89c4ff92SAndroid Build Coastguard Worker const TContainer& outputData = outputDataContainers[i].get();
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker armnn::Tensor outputTensor(outputBinding.second, const_cast<float*>(outputData.data()));
95*89c4ff92SAndroid Build Coastguard Worker outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
96*89c4ff92SAndroid Build Coastguard Worker }
97*89c4ff92SAndroid Build Coastguard Worker
98*89c4ff92SAndroid Build Coastguard Worker return outputTensors;
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker
/// Declares a scoped two-state flag type `name` with enumerators False (0) and True (1).
/// Distinct flag types keep call sites self-describing and stop different flags from being
/// passed in each other's place, which plain bool parameters would allow.
#define S_BOOL(name) enum class name : int { False = 0, True = 1 };

S_BOOL(ImportMemory)  // Enable memory import in the optimizer and the runtime's memory sources.
S_BOOL(DumpToDot)     // Write the optimized graph next to the model as a .dot file.
S_BOOL(ExpectFile)    // A path argument must name a regular file.
S_BOOL(OptionalArg)   // A command-line argument may be omitted.
107*89c4ff92SAndroid Build Coastguard Worker
108*89c4ff92SAndroid Build Coastguard Worker int LoadModel(const char* filename,
109*89c4ff92SAndroid Build Coastguard Worker ITfLiteParser& parser,
110*89c4ff92SAndroid Build Coastguard Worker IRuntime& runtime,
111*89c4ff92SAndroid Build Coastguard Worker NetworkId& networkId,
112*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences,
113*89c4ff92SAndroid Build Coastguard Worker ImportMemory enableImport,
114*89c4ff92SAndroid Build Coastguard Worker DumpToDot dumpToDot)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker std::ifstream stream(filename, std::ios::in | std::ios::binary);
117*89c4ff92SAndroid Build Coastguard Worker if (!stream.is_open())
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Could not open model: " << filename;
120*89c4ff92SAndroid Build Coastguard Worker return OPEN_FILE_ERROR;
121*89c4ff92SAndroid Build Coastguard Worker }
122*89c4ff92SAndroid Build Coastguard Worker
123*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> contents((std::istreambuf_iterator<char>(stream)), std::istreambuf_iterator<char>());
124*89c4ff92SAndroid Build Coastguard Worker stream.close();
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker auto model = parser.CreateNetworkFromBinary(contents);
127*89c4ff92SAndroid Build Coastguard Worker contents.clear();
128*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(debug) << "Model loaded ok: " << filename;
129*89c4ff92SAndroid Build Coastguard Worker
130*89c4ff92SAndroid Build Coastguard Worker // Optimize backbone model
131*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque options;
132*89c4ff92SAndroid Build Coastguard Worker options.SetImportEnabled(enableImport != ImportMemory::False);
133*89c4ff92SAndroid Build Coastguard Worker auto optimizedModel = Optimize(*model, backendPreferences, runtime.GetDeviceSpec(), options);
134*89c4ff92SAndroid Build Coastguard Worker if (!optimizedModel)
135*89c4ff92SAndroid Build Coastguard Worker {
136*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Could not optimize the model:" << filename;
137*89c4ff92SAndroid Build Coastguard Worker return OPTIMIZE_NETWORK_ERROR;
138*89c4ff92SAndroid Build Coastguard Worker }
139*89c4ff92SAndroid Build Coastguard Worker
140*89c4ff92SAndroid Build Coastguard Worker if (dumpToDot != DumpToDot::False)
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss;
143*89c4ff92SAndroid Build Coastguard Worker ss << filename << ".dot";
144*89c4ff92SAndroid Build Coastguard Worker std::ofstream dotStream(ss.str().c_str(), std::ofstream::out);
145*89c4ff92SAndroid Build Coastguard Worker optimizedModel->SerializeToDot(dotStream);
146*89c4ff92SAndroid Build Coastguard Worker dotStream.close();
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker // Load model into runtime
149*89c4ff92SAndroid Build Coastguard Worker {
150*89c4ff92SAndroid Build Coastguard Worker std::string errorMessage;
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker armnn::MemorySource memSource = options.GetImportEnabled() ? armnn::MemorySource::Malloc
153*89c4ff92SAndroid Build Coastguard Worker : armnn::MemorySource::Undefined;
154*89c4ff92SAndroid Build Coastguard Worker INetworkProperties modelProps(false, memSource, memSource);
155*89c4ff92SAndroid Build Coastguard Worker Status status = runtime.LoadNetwork(networkId, std::move(optimizedModel), errorMessage, modelProps);
156*89c4ff92SAndroid Build Coastguard Worker if (status != Status::Success)
157*89c4ff92SAndroid Build Coastguard Worker {
158*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Could not load " << filename << " model into runtime: " << errorMessage;
159*89c4ff92SAndroid Build Coastguard Worker return LOAD_NETWORK_ERROR;
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker }
162*89c4ff92SAndroid Build Coastguard Worker
163*89c4ff92SAndroid Build Coastguard Worker return 0;
164*89c4ff92SAndroid Build Coastguard Worker }
165*89c4ff92SAndroid Build Coastguard Worker
LoadImage(const char * filename)166*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LoadImage(const char* filename)
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker if (strlen(filename) == 0)
169*89c4ff92SAndroid Build Coastguard Worker {
170*89c4ff92SAndroid Build Coastguard Worker return std::vector<float>(1920*10180*3, 0.0f);
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker struct Memory
173*89c4ff92SAndroid Build Coastguard Worker {
174*89c4ff92SAndroid Build Coastguard Worker ~Memory() {stbi_image_free(m_Data);}
175*89c4ff92SAndroid Build Coastguard Worker bool IsLoaded() const { return m_Data != nullptr;}
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker unsigned char* m_Data;
178*89c4ff92SAndroid Build Coastguard Worker };
179*89c4ff92SAndroid Build Coastguard Worker
180*89c4ff92SAndroid Build Coastguard Worker std::vector<float> image;
181*89c4ff92SAndroid Build Coastguard Worker
182*89c4ff92SAndroid Build Coastguard Worker int width;
183*89c4ff92SAndroid Build Coastguard Worker int height;
184*89c4ff92SAndroid Build Coastguard Worker int channels;
185*89c4ff92SAndroid Build Coastguard Worker
186*89c4ff92SAndroid Build Coastguard Worker Memory mem = {stbi_load(filename, &width, &height, &channels, 3)};
187*89c4ff92SAndroid Build Coastguard Worker if (!mem.IsLoaded())
188*89c4ff92SAndroid Build Coastguard Worker {
189*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Could not load input image file: " << filename;
190*89c4ff92SAndroid Build Coastguard Worker return image;
191*89c4ff92SAndroid Build Coastguard Worker }
192*89c4ff92SAndroid Build Coastguard Worker
193*89c4ff92SAndroid Build Coastguard Worker if (width != 1920 || height != 1080 || channels != 3)
194*89c4ff92SAndroid Build Coastguard Worker {
195*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Input image has wong dimension: " << width << "x" << height << "x" << channels << ". "
196*89c4ff92SAndroid Build Coastguard Worker " Expected 1920x1080x3.";
197*89c4ff92SAndroid Build Coastguard Worker return image;
198*89c4ff92SAndroid Build Coastguard Worker }
199*89c4ff92SAndroid Build Coastguard Worker
200*89c4ff92SAndroid Build Coastguard Worker image.resize(1920*1080*3);
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker // Expand to float. Does this need de-gamma?
203*89c4ff92SAndroid Build Coastguard Worker for (unsigned int idx=0; idx <= 1920*1080*3; idx++)
204*89c4ff92SAndroid Build Coastguard Worker {
205*89c4ff92SAndroid Build Coastguard Worker image[idx] = static_cast<float>(mem.m_Data[idx]) /255.0f;
206*89c4ff92SAndroid Build Coastguard Worker }
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker return image;
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker
ValidateFilePath(std::string & file,ExpectFile expectFile)212*89c4ff92SAndroid Build Coastguard Worker bool ValidateFilePath(std::string& file, ExpectFile expectFile)
213*89c4ff92SAndroid Build Coastguard Worker {
214*89c4ff92SAndroid Build Coastguard Worker if (!ghc::filesystem::exists(file))
215*89c4ff92SAndroid Build Coastguard Worker {
216*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Given file path " << file << " does not exist" << std::endl;
217*89c4ff92SAndroid Build Coastguard Worker return false;
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker if (!ghc::filesystem::is_regular_file(file) && expectFile == ExpectFile::True)
220*89c4ff92SAndroid Build Coastguard Worker {
221*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Given file path " << file << " is not a regular file" << std::endl;
222*89c4ff92SAndroid Build Coastguard Worker return false;
223*89c4ff92SAndroid Build Coastguard Worker }
224*89c4ff92SAndroid Build Coastguard Worker return true;
225*89c4ff92SAndroid Build Coastguard Worker }
226*89c4ff92SAndroid Build Coastguard Worker
CheckAccuracy(std::vector<float> * toDetector0,std::vector<float> * toDetector1,std::vector<float> * toDetector2,std::vector<float> * detectorOutput,const std::vector<yolov3::Detection> & nmsOut,const std::vector<std::string> & filePaths)227*89c4ff92SAndroid Build Coastguard Worker void CheckAccuracy(std::vector<float>* toDetector0, std::vector<float>* toDetector1,
228*89c4ff92SAndroid Build Coastguard Worker std::vector<float>* toDetector2, std::vector<float>* detectorOutput,
229*89c4ff92SAndroid Build Coastguard Worker const std::vector<yolov3::Detection>& nmsOut, const std::vector<std::string>& filePaths)
230*89c4ff92SAndroid Build Coastguard Worker {
231*89c4ff92SAndroid Build Coastguard Worker std::ifstream pathStream;
232*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expected;
233*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<float>*> outputs;
234*89c4ff92SAndroid Build Coastguard Worker float compare = 0;
235*89c4ff92SAndroid Build Coastguard Worker unsigned int count = 0;
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker //Push back output vectors from inference for use in loop
238*89c4ff92SAndroid Build Coastguard Worker outputs.push_back(toDetector0);
239*89c4ff92SAndroid Build Coastguard Worker outputs.push_back(toDetector1);
240*89c4ff92SAndroid Build Coastguard Worker outputs.push_back(toDetector2);
241*89c4ff92SAndroid Build Coastguard Worker outputs.push_back(detectorOutput);
242*89c4ff92SAndroid Build Coastguard Worker
243*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < outputs.size(); ++i)
244*89c4ff92SAndroid Build Coastguard Worker {
245*89c4ff92SAndroid Build Coastguard Worker // Reading expected output files and assigning them to @expected. Close and Clear to reuse stream and clean RAM
246*89c4ff92SAndroid Build Coastguard Worker pathStream.open(filePaths[i]);
247*89c4ff92SAndroid Build Coastguard Worker if (!pathStream.is_open())
248*89c4ff92SAndroid Build Coastguard Worker {
249*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Expected output file can not be opened: " << filePaths[i];
250*89c4ff92SAndroid Build Coastguard Worker continue;
251*89c4ff92SAndroid Build Coastguard Worker }
252*89c4ff92SAndroid Build Coastguard Worker
253*89c4ff92SAndroid Build Coastguard Worker expected.assign(std::istream_iterator<float>(pathStream), {});
254*89c4ff92SAndroid Build Coastguard Worker pathStream.close();
255*89c4ff92SAndroid Build Coastguard Worker pathStream.clear();
256*89c4ff92SAndroid Build Coastguard Worker
257*89c4ff92SAndroid Build Coastguard Worker // Ensure each vector is the same length
258*89c4ff92SAndroid Build Coastguard Worker if (expected.size() != outputs[i]->size())
259*89c4ff92SAndroid Build Coastguard Worker {
260*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Expected output size does not match actual output size: " << filePaths[i];
261*89c4ff92SAndroid Build Coastguard Worker }
262*89c4ff92SAndroid Build Coastguard Worker else
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker count = 0;
265*89c4ff92SAndroid Build Coastguard Worker
266*89c4ff92SAndroid Build Coastguard Worker // Compare abs(difference) with tolerance to check for value by value equality
267*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0; j < outputs[i]->size(); ++j)
268*89c4ff92SAndroid Build Coastguard Worker {
269*89c4ff92SAndroid Build Coastguard Worker compare = std::abs(expected[j] - outputs[i]->at(j));
270*89c4ff92SAndroid Build Coastguard Worker if (compare > 0.001f)
271*89c4ff92SAndroid Build Coastguard Worker {
272*89c4ff92SAndroid Build Coastguard Worker count++;
273*89c4ff92SAndroid Build Coastguard Worker }
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker if (count > 0)
276*89c4ff92SAndroid Build Coastguard Worker {
277*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << count << " output(s) do not match expected values in: " << filePaths[i];
278*89c4ff92SAndroid Build Coastguard Worker }
279*89c4ff92SAndroid Build Coastguard Worker }
280*89c4ff92SAndroid Build Coastguard Worker }
281*89c4ff92SAndroid Build Coastguard Worker
282*89c4ff92SAndroid Build Coastguard Worker pathStream.open(filePaths[4]);
283*89c4ff92SAndroid Build Coastguard Worker if (!pathStream.is_open())
284*89c4ff92SAndroid Build Coastguard Worker {
285*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Expected output file can not be opened: " << filePaths[4];
286*89c4ff92SAndroid Build Coastguard Worker }
287*89c4ff92SAndroid Build Coastguard Worker else
288*89c4ff92SAndroid Build Coastguard Worker {
289*89c4ff92SAndroid Build Coastguard Worker expected.assign(std::istream_iterator<float>(pathStream), {});
290*89c4ff92SAndroid Build Coastguard Worker pathStream.close();
291*89c4ff92SAndroid Build Coastguard Worker pathStream.clear();
292*89c4ff92SAndroid Build Coastguard Worker unsigned int y = 0;
293*89c4ff92SAndroid Build Coastguard Worker unsigned int numOfMember = 6;
294*89c4ff92SAndroid Build Coastguard Worker std::vector<float> intermediate;
295*89c4ff92SAndroid Build Coastguard Worker
296*89c4ff92SAndroid Build Coastguard Worker for (auto& detection: nmsOut)
297*89c4ff92SAndroid Build Coastguard Worker {
298*89c4ff92SAndroid Build Coastguard Worker for (unsigned int x = y * numOfMember; x < ((y * numOfMember) + numOfMember); ++x)
299*89c4ff92SAndroid Build Coastguard Worker {
300*89c4ff92SAndroid Build Coastguard Worker intermediate.push_back(expected[x]);
301*89c4ff92SAndroid Build Coastguard Worker }
302*89c4ff92SAndroid Build Coastguard Worker if (!yolov3::compare_detection(detection, intermediate))
303*89c4ff92SAndroid Build Coastguard Worker {
304*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Expected NMS output does not match: Detection " << y + 1;
305*89c4ff92SAndroid Build Coastguard Worker }
306*89c4ff92SAndroid Build Coastguard Worker intermediate.clear();
307*89c4ff92SAndroid Build Coastguard Worker y++;
308*89c4ff92SAndroid Build Coastguard Worker }
309*89c4ff92SAndroid Build Coastguard Worker }
310*89c4ff92SAndroid Build Coastguard Worker }
311*89c4ff92SAndroid Build Coastguard Worker
312*89c4ff92SAndroid Build Coastguard Worker struct ParseArgs
313*89c4ff92SAndroid Build Coastguard Worker {
ParseArgsParseArgs314*89c4ff92SAndroid Build Coastguard Worker ParseArgs(int ac, char *av[]) : options{"TfLiteYoloV3Big-Armnn",
315*89c4ff92SAndroid Build Coastguard Worker "Executes YoloV3Big using ArmNN. YoloV3Big consists "
316*89c4ff92SAndroid Build Coastguard Worker "of 3 parts: A backbone TfLite model, a detector TfLite "
317*89c4ff92SAndroid Build Coastguard Worker "model, and None Maximum Suppression. All parts are "
318*89c4ff92SAndroid Build Coastguard Worker "executed successively."}
319*89c4ff92SAndroid Build Coastguard Worker {
320*89c4ff92SAndroid Build Coastguard Worker options.add_options()
321*89c4ff92SAndroid Build Coastguard Worker ("b,backbone-path",
322*89c4ff92SAndroid Build Coastguard Worker "File path where the TfLite model for the yoloV3big backbone "
323*89c4ff92SAndroid Build Coastguard Worker "can be found e.g. mydir/yoloV3big_backbone.tflite",
324*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>())
325*89c4ff92SAndroid Build Coastguard Worker
326*89c4ff92SAndroid Build Coastguard Worker ("c,comparison-files",
327*89c4ff92SAndroid Build Coastguard Worker "Defines the expected outputs for the model "
328*89c4ff92SAndroid Build Coastguard Worker "of yoloV3big e.g. 'mydir/file1.txt,mydir/file2.txt,mydir/file3.txt,mydir/file4.txt'->InputToDetector1"
329*89c4ff92SAndroid Build Coastguard Worker " will be tried first then InputToDetector2 then InputToDetector3 then the Detector Output and finally"
330*89c4ff92SAndroid Build Coastguard Worker " the NMS output. NOTE: Files are passed as comma separated list without whitespaces.",
331*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::vector<std::string>>()->default_value({}))
332*89c4ff92SAndroid Build Coastguard Worker
333*89c4ff92SAndroid Build Coastguard Worker ("d,detector-path",
334*89c4ff92SAndroid Build Coastguard Worker "File path where the TfLite model for the yoloV3big "
335*89c4ff92SAndroid Build Coastguard Worker "detector can be found e.g.'mydir/yoloV3big_detector.tflite'",
336*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>())
337*89c4ff92SAndroid Build Coastguard Worker
338*89c4ff92SAndroid Build Coastguard Worker ("h,help", "Produce help message")
339*89c4ff92SAndroid Build Coastguard Worker
340*89c4ff92SAndroid Build Coastguard Worker ("i,image-path",
341*89c4ff92SAndroid Build Coastguard Worker "File path to a 1080x1920 jpg image that should be "
342*89c4ff92SAndroid Build Coastguard Worker "processed e.g. 'mydir/example_img_180_1920.jpg'",
343*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>())
344*89c4ff92SAndroid Build Coastguard Worker
345*89c4ff92SAndroid Build Coastguard Worker ("B,preferred-backends-backbone",
346*89c4ff92SAndroid Build Coastguard Worker "Defines the preferred backends to run the backbone model "
347*89c4ff92SAndroid Build Coastguard Worker "of yoloV3big e.g. 'GpuAcc,CpuRef' -> GpuAcc will be tried "
348*89c4ff92SAndroid Build Coastguard Worker "first before falling back to CpuRef. NOTE: Backends are passed "
349*89c4ff92SAndroid Build Coastguard Worker "as comma separated list without whitespaces.",
350*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::vector<std::string>>()->default_value("GpuAcc,CpuRef"))
351*89c4ff92SAndroid Build Coastguard Worker
352*89c4ff92SAndroid Build Coastguard Worker ("D,preferred-backends-detector",
353*89c4ff92SAndroid Build Coastguard Worker "Defines the preferred backends to run the detector model "
354*89c4ff92SAndroid Build Coastguard Worker "of yoloV3big e.g. 'CpuAcc,CpuRef' -> CpuAcc will be tried "
355*89c4ff92SAndroid Build Coastguard Worker "first before falling back to CpuRef. NOTE: Backends are passed "
356*89c4ff92SAndroid Build Coastguard Worker "as comma separated list without whitespaces.",
357*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::vector<std::string>>()->default_value("CpuAcc,CpuRef"))
358*89c4ff92SAndroid Build Coastguard Worker
359*89c4ff92SAndroid Build Coastguard Worker ("M, model-to-dot",
360*89c4ff92SAndroid Build Coastguard Worker "Dump the optimized model to a dot file for debugging/analysis",
361*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<bool>()->default_value("false"))
362*89c4ff92SAndroid Build Coastguard Worker
363*89c4ff92SAndroid Build Coastguard Worker ("Y, dynamic-backends-path",
364*89c4ff92SAndroid Build Coastguard Worker "Define a path from which to load any dynamic backends.",
365*89c4ff92SAndroid Build Coastguard Worker cxxopts::value<std::string>());
366*89c4ff92SAndroid Build Coastguard Worker
367*89c4ff92SAndroid Build Coastguard Worker auto result = options.parse(ac, av);
368*89c4ff92SAndroid Build Coastguard Worker
369*89c4ff92SAndroid Build Coastguard Worker if (result.count("help"))
370*89c4ff92SAndroid Build Coastguard Worker {
371*89c4ff92SAndroid Build Coastguard Worker std::cout << options.help() << "\n";
372*89c4ff92SAndroid Build Coastguard Worker exit(EXIT_SUCCESS);
373*89c4ff92SAndroid Build Coastguard Worker }
374*89c4ff92SAndroid Build Coastguard Worker
375*89c4ff92SAndroid Build Coastguard Worker
376*89c4ff92SAndroid Build Coastguard Worker backboneDir = GetPathArgument(result, "backbone-path", ExpectFile::True, OptionalArg::False);
377*89c4ff92SAndroid Build Coastguard Worker
378*89c4ff92SAndroid Build Coastguard Worker comparisonFiles = GetPathArgument(result["comparison-files"].as<std::vector<std::string>>(), OptionalArg::True);
379*89c4ff92SAndroid Build Coastguard Worker
380*89c4ff92SAndroid Build Coastguard Worker detectorDir = GetPathArgument(result, "detector-path", ExpectFile::True, OptionalArg::False);
381*89c4ff92SAndroid Build Coastguard Worker
382*89c4ff92SAndroid Build Coastguard Worker imageDir = GetPathArgument(result, "image-path", ExpectFile::True, OptionalArg::True);
383*89c4ff92SAndroid Build Coastguard Worker
384*89c4ff92SAndroid Build Coastguard Worker dynamicBackendPath = GetPathArgument(result, "dynamic-backends-path", ExpectFile::False, OptionalArg::True);
385*89c4ff92SAndroid Build Coastguard Worker
386*89c4ff92SAndroid Build Coastguard Worker prefBackendsBackbone = GetBackendIDs(result["preferred-backends-backbone"].as<std::vector<std::string>>());
387*89c4ff92SAndroid Build Coastguard Worker LogBackendsInfo(prefBackendsBackbone, "Backbone");
388*89c4ff92SAndroid Build Coastguard Worker prefBackendsDetector = GetBackendIDs(result["preferred-backends-detector"].as<std::vector<std::string>>());
389*89c4ff92SAndroid Build Coastguard Worker LogBackendsInfo(prefBackendsDetector, "detector");
390*89c4ff92SAndroid Build Coastguard Worker
391*89c4ff92SAndroid Build Coastguard Worker dumpToDot = result["model-to-dot"].as<bool>() ? DumpToDot::True : DumpToDot::False;
392*89c4ff92SAndroid Build Coastguard Worker }
393*89c4ff92SAndroid Build Coastguard Worker
394*89c4ff92SAndroid Build Coastguard Worker /// Takes a vector of backend strings and returns a vector of backendIDs
GetBackendIDsParseArgs395*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> GetBackendIDs(const std::vector<std::string>& backendStrings)
396*89c4ff92SAndroid Build Coastguard Worker {
397*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> backendIDs;
398*89c4ff92SAndroid Build Coastguard Worker for (const auto& b : backendStrings)
399*89c4ff92SAndroid Build Coastguard Worker {
400*89c4ff92SAndroid Build Coastguard Worker backendIDs.push_back(BackendId(b));
401*89c4ff92SAndroid Build Coastguard Worker }
402*89c4ff92SAndroid Build Coastguard Worker return backendIDs;
403*89c4ff92SAndroid Build Coastguard Worker }
404*89c4ff92SAndroid Build Coastguard Worker
405*89c4ff92SAndroid Build Coastguard Worker /// Verifies if the program argument with the name argName contains a valid file path.
406*89c4ff92SAndroid Build Coastguard Worker /// Returns the valid file path string if given argument is associated a valid file path.
407*89c4ff92SAndroid Build Coastguard Worker /// Otherwise throws an exception.
GetPathArgumentParseArgs408*89c4ff92SAndroid Build Coastguard Worker std::string GetPathArgument(cxxopts::ParseResult& result,
409*89c4ff92SAndroid Build Coastguard Worker std::string&& argName,
410*89c4ff92SAndroid Build Coastguard Worker ExpectFile expectFile,
411*89c4ff92SAndroid Build Coastguard Worker OptionalArg isOptionalArg)
412*89c4ff92SAndroid Build Coastguard Worker {
413*89c4ff92SAndroid Build Coastguard Worker if (result.count(argName))
414*89c4ff92SAndroid Build Coastguard Worker {
415*89c4ff92SAndroid Build Coastguard Worker std::string path = result[argName].as<std::string>();
416*89c4ff92SAndroid Build Coastguard Worker if (!ValidateFilePath(path, expectFile))
417*89c4ff92SAndroid Build Coastguard Worker {
418*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss;
419*89c4ff92SAndroid Build Coastguard Worker ss << "Argument given to" << argName << "is not a valid file path";
420*89c4ff92SAndroid Build Coastguard Worker throw cxxopts::option_syntax_exception(ss.str().c_str());
421*89c4ff92SAndroid Build Coastguard Worker }
422*89c4ff92SAndroid Build Coastguard Worker return path;
423*89c4ff92SAndroid Build Coastguard Worker }
424*89c4ff92SAndroid Build Coastguard Worker else
425*89c4ff92SAndroid Build Coastguard Worker {
426*89c4ff92SAndroid Build Coastguard Worker if (isOptionalArg == OptionalArg::True)
427*89c4ff92SAndroid Build Coastguard Worker {
428*89c4ff92SAndroid Build Coastguard Worker return "";
429*89c4ff92SAndroid Build Coastguard Worker }
430*89c4ff92SAndroid Build Coastguard Worker
431*89c4ff92SAndroid Build Coastguard Worker throw cxxopts::missing_argument_exception(argName);
432*89c4ff92SAndroid Build Coastguard Worker }
433*89c4ff92SAndroid Build Coastguard Worker }
434*89c4ff92SAndroid Build Coastguard Worker
435*89c4ff92SAndroid Build Coastguard Worker /// Assigns vector of strings to struct member variable
GetPathArgumentParseArgs436*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> GetPathArgument(const std::vector<std::string>& pathStrings, OptionalArg isOptional)
437*89c4ff92SAndroid Build Coastguard Worker {
438*89c4ff92SAndroid Build Coastguard Worker if (pathStrings.size() < 5){
439*89c4ff92SAndroid Build Coastguard Worker if (isOptional == OptionalArg::True)
440*89c4ff92SAndroid Build Coastguard Worker {
441*89c4ff92SAndroid Build Coastguard Worker return std::vector<std::string>();
442*89c4ff92SAndroid Build Coastguard Worker }
443*89c4ff92SAndroid Build Coastguard Worker throw cxxopts::option_syntax_exception("Comparison files requires 5 file paths.");
444*89c4ff92SAndroid Build Coastguard Worker }
445*89c4ff92SAndroid Build Coastguard Worker
446*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> filePaths;
447*89c4ff92SAndroid Build Coastguard Worker for (auto& path : pathStrings)
448*89c4ff92SAndroid Build Coastguard Worker {
449*89c4ff92SAndroid Build Coastguard Worker filePaths.push_back(path);
450*89c4ff92SAndroid Build Coastguard Worker if (!ValidateFilePath(filePaths.back(), ExpectFile::True))
451*89c4ff92SAndroid Build Coastguard Worker {
452*89c4ff92SAndroid Build Coastguard Worker throw cxxopts::option_syntax_exception("Argument given to Comparison Files is not a valid file path");
453*89c4ff92SAndroid Build Coastguard Worker }
454*89c4ff92SAndroid Build Coastguard Worker }
455*89c4ff92SAndroid Build Coastguard Worker return filePaths;
456*89c4ff92SAndroid Build Coastguard Worker }
457*89c4ff92SAndroid Build Coastguard Worker
458*89c4ff92SAndroid Build Coastguard Worker /// Log info about assigned backends
LogBackendsInfoParseArgs459*89c4ff92SAndroid Build Coastguard Worker void LogBackendsInfo(std::vector<BackendId>& backends, std::string&& modelName)
460*89c4ff92SAndroid Build Coastguard Worker {
461*89c4ff92SAndroid Build Coastguard Worker std::string info;
462*89c4ff92SAndroid Build Coastguard Worker info = "Preferred backends for " + modelName + " set to [ ";
463*89c4ff92SAndroid Build Coastguard Worker for (auto const &backend : backends)
464*89c4ff92SAndroid Build Coastguard Worker {
465*89c4ff92SAndroid Build Coastguard Worker info = info + std::string(backend) + " ";
466*89c4ff92SAndroid Build Coastguard Worker }
467*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << info << "]";
468*89c4ff92SAndroid Build Coastguard Worker }
469*89c4ff92SAndroid Build Coastguard Worker
    // Member variables.
    // NOTE(review): presumably filled in by the ParseArgs constructor (not visible here); read by main().

    // Model path handed to LoadModel() for the backbone network (despite the "Dir" suffix it is
    // used as a single path; presumably a .tflite file — confirm).
    std::string backboneDir;
    // Reference files forwarded to CheckAccuracy(); when empty, main() skips the accuracy check.
    // The vector GetPathArgument() overload yields either an empty list or >= 5 validated paths.
    std::vector<std::string> comparisonFiles;
    // Model path handed to LoadModel() for the detector network.
    std::string detectorDir;
    // Path of the test image handed to LoadImage().
    std::string imageDir;
    // When non-empty, copied into IRuntime::CreationOptions::m_DynamicBackendsPath.
    std::string dynamicBackendPath;

    // Preferred backend lists passed to LoadModel() for the backbone and detector respectively.
    std::vector<BackendId> prefBackendsBackbone;
    std::vector<BackendId> prefBackendsDetector;

    // cxxopts option definitions (not referenced in the visible code).
    cxxopts::Options options;

    // Forwarded to LoadModel() for both networks.
    DumpToDot dumpToDot;
483*89c4ff92SAndroid Build Coastguard Worker };
484*89c4ff92SAndroid Build Coastguard Worker
main(int argc,char * argv[])485*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
486*89c4ff92SAndroid Build Coastguard Worker {
487*89c4ff92SAndroid Build Coastguard Worker // Configure logging
488*89c4ff92SAndroid Build Coastguard Worker SetAllLoggingSinks(true, true, true);
489*89c4ff92SAndroid Build Coastguard Worker SetLogFilter(LogSeverity::Trace);
490*89c4ff92SAndroid Build Coastguard Worker
491*89c4ff92SAndroid Build Coastguard Worker // Check and get given program arguments
492*89c4ff92SAndroid Build Coastguard Worker ParseArgs progArgs = ParseArgs(argc, argv);
493*89c4ff92SAndroid Build Coastguard Worker
494*89c4ff92SAndroid Build Coastguard Worker // Create runtime
495*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions runtimeOptions; // default
496*89c4ff92SAndroid Build Coastguard Worker
497*89c4ff92SAndroid Build Coastguard Worker if (!progArgs.dynamicBackendPath.empty())
498*89c4ff92SAndroid Build Coastguard Worker {
499*89c4ff92SAndroid Build Coastguard Worker std::cout << "Loading backends from" << progArgs.dynamicBackendPath << "\n";
500*89c4ff92SAndroid Build Coastguard Worker runtimeOptions.m_DynamicBackendsPath = progArgs.dynamicBackendPath;
501*89c4ff92SAndroid Build Coastguard Worker }
502*89c4ff92SAndroid Build Coastguard Worker
503*89c4ff92SAndroid Build Coastguard Worker auto runtime = IRuntime::Create(runtimeOptions);
504*89c4ff92SAndroid Build Coastguard Worker if (!runtime)
505*89c4ff92SAndroid Build Coastguard Worker {
506*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(fatal) << "Could not create runtime.";
507*89c4ff92SAndroid Build Coastguard Worker return -1;
508*89c4ff92SAndroid Build Coastguard Worker }
509*89c4ff92SAndroid Build Coastguard Worker
510*89c4ff92SAndroid Build Coastguard Worker // Create TfLite Parsers
511*89c4ff92SAndroid Build Coastguard Worker ITfLiteParser::TfLiteParserOptions parserOptions;
512*89c4ff92SAndroid Build Coastguard Worker auto parser = ITfLiteParser::Create(parserOptions);
513*89c4ff92SAndroid Build Coastguard Worker
514*89c4ff92SAndroid Build Coastguard Worker // Load backbone model
515*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Loading backbone...";
516*89c4ff92SAndroid Build Coastguard Worker NetworkId backboneId;
517*89c4ff92SAndroid Build Coastguard Worker const DumpToDot dumpToDot = progArgs.dumpToDot;
518*89c4ff92SAndroid Build Coastguard Worker CHECK_OK(LoadModel(progArgs.backboneDir.c_str(),
519*89c4ff92SAndroid Build Coastguard Worker *parser,
520*89c4ff92SAndroid Build Coastguard Worker *runtime,
521*89c4ff92SAndroid Build Coastguard Worker backboneId,
522*89c4ff92SAndroid Build Coastguard Worker progArgs.prefBackendsBackbone,
523*89c4ff92SAndroid Build Coastguard Worker ImportMemory::False,
524*89c4ff92SAndroid Build Coastguard Worker dumpToDot));
525*89c4ff92SAndroid Build Coastguard Worker auto inputId = parser->GetNetworkInputBindingInfo(0, "inputs");
526*89c4ff92SAndroid Build Coastguard Worker auto bbOut0Id = parser->GetNetworkOutputBindingInfo(0, "input_to_detector_1");
527*89c4ff92SAndroid Build Coastguard Worker auto bbOut1Id = parser->GetNetworkOutputBindingInfo(0, "input_to_detector_2");
528*89c4ff92SAndroid Build Coastguard Worker auto bbOut2Id = parser->GetNetworkOutputBindingInfo(0, "input_to_detector_3");
529*89c4ff92SAndroid Build Coastguard Worker auto backboneProfile = runtime->GetProfiler(backboneId);
530*89c4ff92SAndroid Build Coastguard Worker backboneProfile->EnableProfiling(true);
531*89c4ff92SAndroid Build Coastguard Worker
532*89c4ff92SAndroid Build Coastguard Worker
533*89c4ff92SAndroid Build Coastguard Worker // Load detector model
534*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Loading detector...";
535*89c4ff92SAndroid Build Coastguard Worker NetworkId detectorId;
536*89c4ff92SAndroid Build Coastguard Worker CHECK_OK(LoadModel(progArgs.detectorDir.c_str(),
537*89c4ff92SAndroid Build Coastguard Worker *parser,
538*89c4ff92SAndroid Build Coastguard Worker *runtime,
539*89c4ff92SAndroid Build Coastguard Worker detectorId,
540*89c4ff92SAndroid Build Coastguard Worker progArgs.prefBackendsDetector,
541*89c4ff92SAndroid Build Coastguard Worker ImportMemory::True,
542*89c4ff92SAndroid Build Coastguard Worker dumpToDot));
543*89c4ff92SAndroid Build Coastguard Worker auto detectIn0Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_1");
544*89c4ff92SAndroid Build Coastguard Worker auto detectIn1Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_2");
545*89c4ff92SAndroid Build Coastguard Worker auto detectIn2Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_3");
546*89c4ff92SAndroid Build Coastguard Worker auto outputBoxesId = parser->GetNetworkOutputBindingInfo(0, "output_boxes");
547*89c4ff92SAndroid Build Coastguard Worker auto detectorProfile = runtime->GetProfiler(detectorId);
548*89c4ff92SAndroid Build Coastguard Worker
549*89c4ff92SAndroid Build Coastguard Worker // Load input from file
550*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Loading test image...";
551*89c4ff92SAndroid Build Coastguard Worker auto image = LoadImage(progArgs.imageDir.c_str());
552*89c4ff92SAndroid Build Coastguard Worker if (image.empty())
553*89c4ff92SAndroid Build Coastguard Worker {
554*89c4ff92SAndroid Build Coastguard Worker return LOAD_IMAGE_ERROR;
555*89c4ff92SAndroid Build Coastguard Worker }
556*89c4ff92SAndroid Build Coastguard Worker
557*89c4ff92SAndroid Build Coastguard Worker // Allocate the intermediate tensors
558*89c4ff92SAndroid Build Coastguard Worker std::vector<float> intermediateMem0(bbOut0Id.second.GetNumElements());
559*89c4ff92SAndroid Build Coastguard Worker std::vector<float> intermediateMem1(bbOut1Id.second.GetNumElements());
560*89c4ff92SAndroid Build Coastguard Worker std::vector<float> intermediateMem2(bbOut2Id.second.GetNumElements());
561*89c4ff92SAndroid Build Coastguard Worker std::vector<float> intermediateMem3(outputBoxesId.second.GetNumElements());
562*89c4ff92SAndroid Build Coastguard Worker
563*89c4ff92SAndroid Build Coastguard Worker // Setup inputs and outputs
564*89c4ff92SAndroid Build Coastguard Worker using BindingInfos = std::vector<armnn::BindingPointInfo>;
565*89c4ff92SAndroid Build Coastguard Worker using FloatTensors = std::vector<std::reference_wrapper<std::vector<float>>>;
566*89c4ff92SAndroid Build Coastguard Worker
567*89c4ff92SAndroid Build Coastguard Worker InputTensors bbInputTensors = MakeInputTensors(BindingInfos{ inputId },
568*89c4ff92SAndroid Build Coastguard Worker FloatTensors{ image });
569*89c4ff92SAndroid Build Coastguard Worker OutputTensors bbOutputTensors = MakeOutputTensors(BindingInfos{ bbOut0Id, bbOut1Id, bbOut2Id },
570*89c4ff92SAndroid Build Coastguard Worker FloatTensors{ intermediateMem0,
571*89c4ff92SAndroid Build Coastguard Worker intermediateMem1,
572*89c4ff92SAndroid Build Coastguard Worker intermediateMem2 });
573*89c4ff92SAndroid Build Coastguard Worker InputTensors detectInputTensors = MakeInputTensors(BindingInfos{ detectIn0Id,
574*89c4ff92SAndroid Build Coastguard Worker detectIn1Id,
575*89c4ff92SAndroid Build Coastguard Worker detectIn2Id } ,
576*89c4ff92SAndroid Build Coastguard Worker FloatTensors{ intermediateMem0,
577*89c4ff92SAndroid Build Coastguard Worker intermediateMem1,
578*89c4ff92SAndroid Build Coastguard Worker intermediateMem2 });
579*89c4ff92SAndroid Build Coastguard Worker OutputTensors detectOutputTensors = MakeOutputTensors(BindingInfos{ outputBoxesId },
580*89c4ff92SAndroid Build Coastguard Worker FloatTensors{ intermediateMem3 });
581*89c4ff92SAndroid Build Coastguard Worker
582*89c4ff92SAndroid Build Coastguard Worker static const int numIterations=2;
583*89c4ff92SAndroid Build Coastguard Worker using DurationUS = std::chrono::duration<double, std::micro>;
584*89c4ff92SAndroid Build Coastguard Worker std::vector<DurationUS> nmsDurations(0);
585*89c4ff92SAndroid Build Coastguard Worker std::vector<yolov3::Detection> filtered_boxes;
586*89c4ff92SAndroid Build Coastguard Worker nmsDurations.reserve(numIterations);
587*89c4ff92SAndroid Build Coastguard Worker for (int i=0; i < numIterations; i++)
588*89c4ff92SAndroid Build Coastguard Worker {
589*89c4ff92SAndroid Build Coastguard Worker // Execute backbone
590*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Running backbone...";
591*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(backboneId, bbInputTensors, bbOutputTensors);
592*89c4ff92SAndroid Build Coastguard Worker
593*89c4ff92SAndroid Build Coastguard Worker // Execute detector
594*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Running detector...";
595*89c4ff92SAndroid Build Coastguard Worker runtime->EnqueueWorkload(detectorId, detectInputTensors, detectOutputTensors);
596*89c4ff92SAndroid Build Coastguard Worker
597*89c4ff92SAndroid Build Coastguard Worker // Execute NMS
598*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Running nms...";
599*89c4ff92SAndroid Build Coastguard Worker using clock = std::chrono::steady_clock;
600*89c4ff92SAndroid Build Coastguard Worker auto nmsStartTime = clock::now();
601*89c4ff92SAndroid Build Coastguard Worker yolov3::NMSConfig config;
602*89c4ff92SAndroid Build Coastguard Worker config.num_boxes = 127800;
603*89c4ff92SAndroid Build Coastguard Worker config.num_classes = 80;
604*89c4ff92SAndroid Build Coastguard Worker config.confidence_threshold = 0.9f;
605*89c4ff92SAndroid Build Coastguard Worker config.iou_threshold = 0.5f;
606*89c4ff92SAndroid Build Coastguard Worker filtered_boxes = yolov3::nms(config, intermediateMem3);
607*89c4ff92SAndroid Build Coastguard Worker auto nmsEndTime = clock::now();
608*89c4ff92SAndroid Build Coastguard Worker
609*89c4ff92SAndroid Build Coastguard Worker // Enable the profiling after the warm-up run
610*89c4ff92SAndroid Build Coastguard Worker if (i>0)
611*89c4ff92SAndroid Build Coastguard Worker {
612*89c4ff92SAndroid Build Coastguard Worker print_detection(std::cout, filtered_boxes);
613*89c4ff92SAndroid Build Coastguard Worker
614*89c4ff92SAndroid Build Coastguard Worker const auto nmsDuration = DurationUS(nmsStartTime - nmsEndTime);
615*89c4ff92SAndroid Build Coastguard Worker nmsDurations.push_back(nmsDuration);
616*89c4ff92SAndroid Build Coastguard Worker }
617*89c4ff92SAndroid Build Coastguard Worker backboneProfile->EnableProfiling(true);
618*89c4ff92SAndroid Build Coastguard Worker detectorProfile->EnableProfiling(true);
619*89c4ff92SAndroid Build Coastguard Worker }
620*89c4ff92SAndroid Build Coastguard Worker // Log timings to file
621*89c4ff92SAndroid Build Coastguard Worker std::ofstream backboneProfileStream("backbone.json");
622*89c4ff92SAndroid Build Coastguard Worker backboneProfile->Print(backboneProfileStream);
623*89c4ff92SAndroid Build Coastguard Worker backboneProfileStream.close();
624*89c4ff92SAndroid Build Coastguard Worker
625*89c4ff92SAndroid Build Coastguard Worker std::ofstream detectorProfileStream("detector.json");
626*89c4ff92SAndroid Build Coastguard Worker detectorProfile->Print(detectorProfileStream);
627*89c4ff92SAndroid Build Coastguard Worker detectorProfileStream.close();
628*89c4ff92SAndroid Build Coastguard Worker
629*89c4ff92SAndroid Build Coastguard Worker // Manually construct the json output
630*89c4ff92SAndroid Build Coastguard Worker std::ofstream nmsProfileStream("nms.json");
631*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << "{" << "\n";
632*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << R"( "NmsTimings": {)" << "\n";
633*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << R"( "raw": [)" << "\n";
634*89c4ff92SAndroid Build Coastguard Worker bool isFirst = true;
635*89c4ff92SAndroid Build Coastguard Worker for (auto duration : nmsDurations)
636*89c4ff92SAndroid Build Coastguard Worker {
637*89c4ff92SAndroid Build Coastguard Worker if (!isFirst)
638*89c4ff92SAndroid Build Coastguard Worker {
639*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << ",\n";
640*89c4ff92SAndroid Build Coastguard Worker }
641*89c4ff92SAndroid Build Coastguard Worker
642*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << " " << duration.count();
643*89c4ff92SAndroid Build Coastguard Worker isFirst = false;
644*89c4ff92SAndroid Build Coastguard Worker }
645*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << "\n";
646*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << R"( "units": "us")" << "\n";
647*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << " ]" << "\n";
648*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << " }" << "\n";
649*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream << "}" << "\n";
650*89c4ff92SAndroid Build Coastguard Worker nmsProfileStream.close();
651*89c4ff92SAndroid Build Coastguard Worker
652*89c4ff92SAndroid Build Coastguard Worker if (progArgs.comparisonFiles.size() > 0)
653*89c4ff92SAndroid Build Coastguard Worker {
654*89c4ff92SAndroid Build Coastguard Worker CheckAccuracy(&intermediateMem0,
655*89c4ff92SAndroid Build Coastguard Worker &intermediateMem1,
656*89c4ff92SAndroid Build Coastguard Worker &intermediateMem2,
657*89c4ff92SAndroid Build Coastguard Worker &intermediateMem3,
658*89c4ff92SAndroid Build Coastguard Worker filtered_boxes,
659*89c4ff92SAndroid Build Coastguard Worker progArgs.comparisonFiles);
660*89c4ff92SAndroid Build Coastguard Worker }
661*89c4ff92SAndroid Build Coastguard Worker
662*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Run completed";
663*89c4ff92SAndroid Build Coastguard Worker return 0;
664*89c4ff92SAndroid Build Coastguard Worker }
665