xref: /aosp_15_r20/external/armnn/tests/TfLiteYoloV3Big-Armnn/TfLiteYoloV3Big-Armnn.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#include "armnnTfLiteParser/ITfLiteParser.hpp"

#include "NMS.hpp"

#include <stb/stb_image.h>

#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Logging.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <cxxopts/cxxopts.hpp>
#include <ghc/filesystem.hpp>

#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker using namespace armnnTfLiteParser;
27*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
28*89c4ff92SAndroid Build Coastguard Worker 
// Non-zero status codes used by this tool. A helper returns 0 on success (see LoadModel);
// CHECK_OK treats any non-zero value as a failure and propagates it.
static constexpr int OPEN_FILE_ERROR        = -2;
static constexpr int OPTIMIZE_NETWORK_ERROR = -3;
static constexpr int LOAD_NETWORK_ERROR     = -4;
static constexpr int LOAD_IMAGE_ERROR       = -5;
static constexpr int GENERAL_ERROR          = -100;
34*89c4ff92SAndroid Build Coastguard Worker 
// Evaluates the status expression `v` once. A non-zero result is returned from the enclosing
// function; an armnn::Exception thrown while evaluating it is logged and mapped to GENERAL_ERROR.
// The do { ... } while (0) wrapper lets the macro be used as a single statement (e.g. in if/else).
#define CHECK_OK(v)                                         \
    do                                                      \
    {                                                       \
        try                                                 \
        {                                                   \
            auto checkOkResult = (v);                       \
            if (checkOkResult != 0)                         \
            {                                               \
                return checkOkResult;                       \
            }                                               \
        }                                                   \
        catch (const armnn::Exception& e)                   \
        {                                                   \
            ARMNN_LOG(error) << "Oops: " << e.what();       \
            return GENERAL_ERROR;                           \
        }                                                   \
    } while (0)
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer>
MakeInputTensors(const std::vector<armnn::BindingPointInfo> & inputBindings,const std::vector<std::reference_wrapper<TContainer>> & inputDataContainers)51*89c4ff92SAndroid Build Coastguard Worker inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
52*89c4ff92SAndroid Build Coastguard Worker                                             const std::vector<std::reference_wrapper<TContainer>>& inputDataContainers)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors;
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     const size_t numInputs = inputBindings.size();
57*89c4ff92SAndroid Build Coastguard Worker     if (numInputs != inputDataContainers.size())
58*89c4ff92SAndroid Build Coastguard Worker     {
59*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("Mismatching vectors");
60*89c4ff92SAndroid Build Coastguard Worker     }
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     for (size_t i = 0; i < numInputs; i++)
63*89c4ff92SAndroid Build Coastguard Worker     {
64*89c4ff92SAndroid Build Coastguard Worker         const armnn::BindingPointInfo& inputBinding = inputBindings[i];
65*89c4ff92SAndroid Build Coastguard Worker         const TContainer& inputData = inputDataContainers[i].get();
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker         armnn::ConstTensor inputTensor(inputBinding.second, inputData.data());
68*89c4ff92SAndroid Build Coastguard Worker         inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
69*89c4ff92SAndroid Build Coastguard Worker     }
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     return inputTensors;
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker template<typename TContainer>
MakeOutputTensors(const std::vector<armnn::BindingPointInfo> & outputBindings,const std::vector<std::reference_wrapper<TContainer>> & outputDataContainers)75*89c4ff92SAndroid Build Coastguard Worker inline armnn::OutputTensors MakeOutputTensors(
76*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::BindingPointInfo>& outputBindings,
77*89c4ff92SAndroid Build Coastguard Worker     const std::vector<std::reference_wrapper<TContainer>>& outputDataContainers)
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors;
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker     const size_t numOutputs = outputBindings.size();
82*89c4ff92SAndroid Build Coastguard Worker     if (numOutputs != outputDataContainers.size())
83*89c4ff92SAndroid Build Coastguard Worker     {
84*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("Mismatching vectors");
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     outputTensors.reserve(numOutputs);
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     for (size_t i = 0; i < numOutputs; i++)
90*89c4ff92SAndroid Build Coastguard Worker     {
91*89c4ff92SAndroid Build Coastguard Worker         const armnn::BindingPointInfo& outputBinding = outputBindings[i];
92*89c4ff92SAndroid Build Coastguard Worker         const TContainer& outputData = outputDataContainers[i].get();
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker         armnn::Tensor outputTensor(outputBinding.second, const_cast<float*>(outputData.data()));
95*89c4ff92SAndroid Build Coastguard Worker         outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
96*89c4ff92SAndroid Build Coastguard Worker     }
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     return outputTensors;
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker 
// Declares a strongly typed on/off flag named `name` with enumerators False (0) and True (1), so that
// call sites read e.g. `DumpToDot::True` instead of a bare bool. The macro supplies the trailing ';'.
#define S_BOOL(name) enum class name : int { False = 0, True = 1 };

S_BOOL(ImportMemory)
S_BOOL(DumpToDot)
S_BOOL(ExpectFile)
S_BOOL(OptionalArg)
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker int LoadModel(const char* filename,
109*89c4ff92SAndroid Build Coastguard Worker               ITfLiteParser& parser,
110*89c4ff92SAndroid Build Coastguard Worker               IRuntime& runtime,
111*89c4ff92SAndroid Build Coastguard Worker               NetworkId& networkId,
112*89c4ff92SAndroid Build Coastguard Worker               const std::vector<BackendId>& backendPreferences,
113*89c4ff92SAndroid Build Coastguard Worker               ImportMemory enableImport,
114*89c4ff92SAndroid Build Coastguard Worker               DumpToDot dumpToDot)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker     std::ifstream stream(filename, std::ios::in | std::ios::binary);
117*89c4ff92SAndroid Build Coastguard Worker     if (!stream.is_open())
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "Could not open model: " << filename;
120*89c4ff92SAndroid Build Coastguard Worker         return OPEN_FILE_ERROR;
121*89c4ff92SAndroid Build Coastguard Worker     }
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> contents((std::istreambuf_iterator<char>(stream)), std::istreambuf_iterator<char>());
124*89c4ff92SAndroid Build Coastguard Worker     stream.close();
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     auto model = parser.CreateNetworkFromBinary(contents);
127*89c4ff92SAndroid Build Coastguard Worker     contents.clear();
128*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(debug) << "Model loaded ok: " << filename;
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker     // Optimize backbone model
131*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque options;
132*89c4ff92SAndroid Build Coastguard Worker     options.SetImportEnabled(enableImport != ImportMemory::False);
133*89c4ff92SAndroid Build Coastguard Worker     auto optimizedModel = Optimize(*model, backendPreferences, runtime.GetDeviceSpec(), options);
134*89c4ff92SAndroid Build Coastguard Worker     if (!optimizedModel)
135*89c4ff92SAndroid Build Coastguard Worker     {
136*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Could not optimize the model:" << filename;
137*89c4ff92SAndroid Build Coastguard Worker         return OPTIMIZE_NETWORK_ERROR;
138*89c4ff92SAndroid Build Coastguard Worker     }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker     if (dumpToDot != DumpToDot::False)
141*89c4ff92SAndroid Build Coastguard Worker     {
142*89c4ff92SAndroid Build Coastguard Worker         std::stringstream ss;
143*89c4ff92SAndroid Build Coastguard Worker         ss << filename << ".dot";
144*89c4ff92SAndroid Build Coastguard Worker         std::ofstream dotStream(ss.str().c_str(), std::ofstream::out);
145*89c4ff92SAndroid Build Coastguard Worker         optimizedModel->SerializeToDot(dotStream);
146*89c4ff92SAndroid Build Coastguard Worker         dotStream.close();
147*89c4ff92SAndroid Build Coastguard Worker     }
148*89c4ff92SAndroid Build Coastguard Worker     // Load model into runtime
149*89c4ff92SAndroid Build Coastguard Worker     {
150*89c4ff92SAndroid Build Coastguard Worker         std::string errorMessage;
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker         armnn::MemorySource memSource = options.GetImportEnabled() ? armnn::MemorySource::Malloc
153*89c4ff92SAndroid Build Coastguard Worker                                                                 : armnn::MemorySource::Undefined;
154*89c4ff92SAndroid Build Coastguard Worker         INetworkProperties modelProps(false, memSource, memSource);
155*89c4ff92SAndroid Build Coastguard Worker         Status status = runtime.LoadNetwork(networkId, std::move(optimizedModel), errorMessage, modelProps);
156*89c4ff92SAndroid Build Coastguard Worker         if (status != Status::Success)
157*89c4ff92SAndroid Build Coastguard Worker         {
158*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(fatal) << "Could not load " << filename << " model into runtime: " << errorMessage;
159*89c4ff92SAndroid Build Coastguard Worker             return LOAD_NETWORK_ERROR;
160*89c4ff92SAndroid Build Coastguard Worker         }
161*89c4ff92SAndroid Build Coastguard Worker     }
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker     return 0;
164*89c4ff92SAndroid Build Coastguard Worker }
165*89c4ff92SAndroid Build Coastguard Worker 
LoadImage(const char * filename)166*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LoadImage(const char* filename)
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker     if (strlen(filename) == 0)
169*89c4ff92SAndroid Build Coastguard Worker     {
170*89c4ff92SAndroid Build Coastguard Worker         return std::vector<float>(1920*10180*3, 0.0f);
171*89c4ff92SAndroid Build Coastguard Worker     }
172*89c4ff92SAndroid Build Coastguard Worker     struct Memory
173*89c4ff92SAndroid Build Coastguard Worker     {
174*89c4ff92SAndroid Build Coastguard Worker         ~Memory() {stbi_image_free(m_Data);}
175*89c4ff92SAndroid Build Coastguard Worker         bool IsLoaded() const { return m_Data != nullptr;}
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker         unsigned char* m_Data;
178*89c4ff92SAndroid Build Coastguard Worker     };
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> image;
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker     int width;
183*89c4ff92SAndroid Build Coastguard Worker     int height;
184*89c4ff92SAndroid Build Coastguard Worker     int channels;
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker     Memory mem = {stbi_load(filename, &width, &height, &channels, 3)};
187*89c4ff92SAndroid Build Coastguard Worker     if (!mem.IsLoaded())
188*89c4ff92SAndroid Build Coastguard Worker     {
189*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "Could not load input image file: " << filename;
190*89c4ff92SAndroid Build Coastguard Worker         return image;
191*89c4ff92SAndroid Build Coastguard Worker     }
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker     if (width != 1920 || height != 1080 || channels != 3)
194*89c4ff92SAndroid Build Coastguard Worker     {
195*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "Input image has wong dimension: " << width << "x" << height << "x" << channels << ". "
196*89c4ff92SAndroid Build Coastguard Worker           " Expected 1920x1080x3.";
197*89c4ff92SAndroid Build Coastguard Worker         return image;
198*89c4ff92SAndroid Build Coastguard Worker     }
199*89c4ff92SAndroid Build Coastguard Worker 
200*89c4ff92SAndroid Build Coastguard Worker     image.resize(1920*1080*3);
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker     // Expand to float. Does this need de-gamma?
203*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int idx=0; idx <= 1920*1080*3; idx++)
204*89c4ff92SAndroid Build Coastguard Worker     {
205*89c4ff92SAndroid Build Coastguard Worker         image[idx] = static_cast<float>(mem.m_Data[idx]) /255.0f;
206*89c4ff92SAndroid Build Coastguard Worker     }
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker     return image;
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker 
ValidateFilePath(std::string & file,ExpectFile expectFile)212*89c4ff92SAndroid Build Coastguard Worker bool ValidateFilePath(std::string& file, ExpectFile expectFile)
213*89c4ff92SAndroid Build Coastguard Worker {
214*89c4ff92SAndroid Build Coastguard Worker     if (!ghc::filesystem::exists(file))
215*89c4ff92SAndroid Build Coastguard Worker     {
216*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "Given file path " << file << " does not exist" << std::endl;
217*89c4ff92SAndroid Build Coastguard Worker         return false;
218*89c4ff92SAndroid Build Coastguard Worker     }
219*89c4ff92SAndroid Build Coastguard Worker     if (!ghc::filesystem::is_regular_file(file) && expectFile == ExpectFile::True)
220*89c4ff92SAndroid Build Coastguard Worker     {
221*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "Given file path " << file << " is not a regular file" << std::endl;
222*89c4ff92SAndroid Build Coastguard Worker         return false;
223*89c4ff92SAndroid Build Coastguard Worker     }
224*89c4ff92SAndroid Build Coastguard Worker     return true;
225*89c4ff92SAndroid Build Coastguard Worker }
226*89c4ff92SAndroid Build Coastguard Worker 
CheckAccuracy(std::vector<float> * toDetector0,std::vector<float> * toDetector1,std::vector<float> * toDetector2,std::vector<float> * detectorOutput,const std::vector<yolov3::Detection> & nmsOut,const std::vector<std::string> & filePaths)227*89c4ff92SAndroid Build Coastguard Worker void CheckAccuracy(std::vector<float>* toDetector0, std::vector<float>* toDetector1,
228*89c4ff92SAndroid Build Coastguard Worker                    std::vector<float>* toDetector2, std::vector<float>* detectorOutput,
229*89c4ff92SAndroid Build Coastguard Worker                    const std::vector<yolov3::Detection>& nmsOut, const std::vector<std::string>& filePaths)
230*89c4ff92SAndroid Build Coastguard Worker {
231*89c4ff92SAndroid Build Coastguard Worker     std::ifstream pathStream;
232*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expected;
233*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::vector<float>*> outputs;
234*89c4ff92SAndroid Build Coastguard Worker     float compare = 0;
235*89c4ff92SAndroid Build Coastguard Worker     unsigned int count = 0;
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     //Push back output vectors from inference for use in loop
238*89c4ff92SAndroid Build Coastguard Worker     outputs.push_back(toDetector0);
239*89c4ff92SAndroid Build Coastguard Worker     outputs.push_back(toDetector1);
240*89c4ff92SAndroid Build Coastguard Worker     outputs.push_back(toDetector2);
241*89c4ff92SAndroid Build Coastguard Worker     outputs.push_back(detectorOutput);
242*89c4ff92SAndroid Build Coastguard Worker 
243*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < outputs.size(); ++i)
244*89c4ff92SAndroid Build Coastguard Worker     {
245*89c4ff92SAndroid Build Coastguard Worker         // Reading expected output files and assigning them to @expected. Close and Clear to reuse stream and clean RAM
246*89c4ff92SAndroid Build Coastguard Worker         pathStream.open(filePaths[i]);
247*89c4ff92SAndroid Build Coastguard Worker         if (!pathStream.is_open())
248*89c4ff92SAndroid Build Coastguard Worker         {
249*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(error) << "Expected output file can not be opened: " << filePaths[i];
250*89c4ff92SAndroid Build Coastguard Worker             continue;
251*89c4ff92SAndroid Build Coastguard Worker         }
252*89c4ff92SAndroid Build Coastguard Worker 
253*89c4ff92SAndroid Build Coastguard Worker         expected.assign(std::istream_iterator<float>(pathStream), {});
254*89c4ff92SAndroid Build Coastguard Worker         pathStream.close();
255*89c4ff92SAndroid Build Coastguard Worker         pathStream.clear();
256*89c4ff92SAndroid Build Coastguard Worker 
257*89c4ff92SAndroid Build Coastguard Worker         // Ensure each vector is the same length
258*89c4ff92SAndroid Build Coastguard Worker         if (expected.size() != outputs[i]->size())
259*89c4ff92SAndroid Build Coastguard Worker         {
260*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(error) << "Expected output size does not match actual output size: " << filePaths[i];
261*89c4ff92SAndroid Build Coastguard Worker         }
262*89c4ff92SAndroid Build Coastguard Worker         else
263*89c4ff92SAndroid Build Coastguard Worker         {
264*89c4ff92SAndroid Build Coastguard Worker             count = 0;
265*89c4ff92SAndroid Build Coastguard Worker 
266*89c4ff92SAndroid Build Coastguard Worker             // Compare abs(difference) with tolerance to check for value by value equality
267*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int j = 0; j < outputs[i]->size(); ++j)
268*89c4ff92SAndroid Build Coastguard Worker             {
269*89c4ff92SAndroid Build Coastguard Worker                 compare = std::abs(expected[j] - outputs[i]->at(j));
270*89c4ff92SAndroid Build Coastguard Worker                 if (compare > 0.001f)
271*89c4ff92SAndroid Build Coastguard Worker                 {
272*89c4ff92SAndroid Build Coastguard Worker                     count++;
273*89c4ff92SAndroid Build Coastguard Worker                 }
274*89c4ff92SAndroid Build Coastguard Worker             }
275*89c4ff92SAndroid Build Coastguard Worker             if (count > 0)
276*89c4ff92SAndroid Build Coastguard Worker             {
277*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << count << " output(s) do not match expected values in: " << filePaths[i];
278*89c4ff92SAndroid Build Coastguard Worker             }
279*89c4ff92SAndroid Build Coastguard Worker         }
280*89c4ff92SAndroid Build Coastguard Worker     }
281*89c4ff92SAndroid Build Coastguard Worker 
282*89c4ff92SAndroid Build Coastguard Worker     pathStream.open(filePaths[4]);
283*89c4ff92SAndroid Build Coastguard Worker     if (!pathStream.is_open())
284*89c4ff92SAndroid Build Coastguard Worker     {
285*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "Expected output file can not be opened: " << filePaths[4];
286*89c4ff92SAndroid Build Coastguard Worker     }
287*89c4ff92SAndroid Build Coastguard Worker     else
288*89c4ff92SAndroid Build Coastguard Worker     {
289*89c4ff92SAndroid Build Coastguard Worker         expected.assign(std::istream_iterator<float>(pathStream), {});
290*89c4ff92SAndroid Build Coastguard Worker         pathStream.close();
291*89c4ff92SAndroid Build Coastguard Worker         pathStream.clear();
292*89c4ff92SAndroid Build Coastguard Worker         unsigned int y = 0;
293*89c4ff92SAndroid Build Coastguard Worker         unsigned int numOfMember = 6;
294*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> intermediate;
295*89c4ff92SAndroid Build Coastguard Worker 
296*89c4ff92SAndroid Build Coastguard Worker         for (auto& detection: nmsOut)
297*89c4ff92SAndroid Build Coastguard Worker         {
298*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int x = y * numOfMember; x < ((y * numOfMember) + numOfMember); ++x)
299*89c4ff92SAndroid Build Coastguard Worker             {
300*89c4ff92SAndroid Build Coastguard Worker                 intermediate.push_back(expected[x]);
301*89c4ff92SAndroid Build Coastguard Worker             }
302*89c4ff92SAndroid Build Coastguard Worker             if (!yolov3::compare_detection(detection, intermediate))
303*89c4ff92SAndroid Build Coastguard Worker             {
304*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << "Expected NMS output does not match: Detection " << y + 1;
305*89c4ff92SAndroid Build Coastguard Worker             }
306*89c4ff92SAndroid Build Coastguard Worker             intermediate.clear();
307*89c4ff92SAndroid Build Coastguard Worker             y++;
308*89c4ff92SAndroid Build Coastguard Worker         }
309*89c4ff92SAndroid Build Coastguard Worker     }
310*89c4ff92SAndroid Build Coastguard Worker }
311*89c4ff92SAndroid Build Coastguard Worker 
312*89c4ff92SAndroid Build Coastguard Worker struct ParseArgs
313*89c4ff92SAndroid Build Coastguard Worker {
ParseArgsParseArgs314*89c4ff92SAndroid Build Coastguard Worker     ParseArgs(int ac, char *av[]) : options{"TfLiteYoloV3Big-Armnn",
315*89c4ff92SAndroid Build Coastguard Worker                                             "Executes YoloV3Big using ArmNN. YoloV3Big consists "
316*89c4ff92SAndroid Build Coastguard Worker                                             "of 3 parts: A backbone TfLite model, a detector TfLite "
317*89c4ff92SAndroid Build Coastguard Worker                                             "model, and None Maximum Suppression. All parts are "
318*89c4ff92SAndroid Build Coastguard Worker                                             "executed successively."}
319*89c4ff92SAndroid Build Coastguard Worker     {
320*89c4ff92SAndroid Build Coastguard Worker         options.add_options()
321*89c4ff92SAndroid Build Coastguard Worker                 ("b,backbone-path",
322*89c4ff92SAndroid Build Coastguard Worker                  "File path where the TfLite model for the yoloV3big backbone "
323*89c4ff92SAndroid Build Coastguard Worker                  "can be found e.g. mydir/yoloV3big_backbone.tflite",
324*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>())
325*89c4ff92SAndroid Build Coastguard Worker 
326*89c4ff92SAndroid Build Coastguard Worker                ("c,comparison-files",
327*89c4ff92SAndroid Build Coastguard Worker                 "Defines the expected outputs for the model "
328*89c4ff92SAndroid Build Coastguard Worker                 "of yoloV3big e.g. 'mydir/file1.txt,mydir/file2.txt,mydir/file3.txt,mydir/file4.txt'->InputToDetector1"
329*89c4ff92SAndroid Build Coastguard Worker                 " will be tried first then InputToDetector2 then InputToDetector3 then the Detector Output and finally"
330*89c4ff92SAndroid Build Coastguard Worker                 " the NMS output. NOTE: Files are passed as comma separated list without whitespaces.",
331*89c4ff92SAndroid Build Coastguard Worker                 cxxopts::value<std::vector<std::string>>()->default_value({}))
332*89c4ff92SAndroid Build Coastguard Worker 
333*89c4ff92SAndroid Build Coastguard Worker                 ("d,detector-path",
334*89c4ff92SAndroid Build Coastguard Worker                  "File path where the TfLite model for the yoloV3big "
335*89c4ff92SAndroid Build Coastguard Worker                  "detector can be found e.g.'mydir/yoloV3big_detector.tflite'",
336*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>())
337*89c4ff92SAndroid Build Coastguard Worker 
338*89c4ff92SAndroid Build Coastguard Worker                 ("h,help", "Produce help message")
339*89c4ff92SAndroid Build Coastguard Worker 
340*89c4ff92SAndroid Build Coastguard Worker                 ("i,image-path",
341*89c4ff92SAndroid Build Coastguard Worker                  "File path to a 1080x1920 jpg image that should be "
342*89c4ff92SAndroid Build Coastguard Worker                  "processed e.g. 'mydir/example_img_180_1920.jpg'",
343*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>())
344*89c4ff92SAndroid Build Coastguard Worker 
345*89c4ff92SAndroid Build Coastguard Worker                 ("B,preferred-backends-backbone",
346*89c4ff92SAndroid Build Coastguard Worker                  "Defines the preferred backends to run the backbone model "
347*89c4ff92SAndroid Build Coastguard Worker                  "of yoloV3big e.g. 'GpuAcc,CpuRef' -> GpuAcc will be tried "
348*89c4ff92SAndroid Build Coastguard Worker                  "first before falling back to CpuRef. NOTE: Backends are passed "
349*89c4ff92SAndroid Build Coastguard Worker                  "as comma separated list without whitespaces.",
350*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::vector<std::string>>()->default_value("GpuAcc,CpuRef"))
351*89c4ff92SAndroid Build Coastguard Worker 
352*89c4ff92SAndroid Build Coastguard Worker                 ("D,preferred-backends-detector",
353*89c4ff92SAndroid Build Coastguard Worker                  "Defines the preferred backends to run the detector model "
354*89c4ff92SAndroid Build Coastguard Worker                  "of yoloV3big e.g. 'CpuAcc,CpuRef' -> CpuAcc will be tried "
355*89c4ff92SAndroid Build Coastguard Worker                  "first before falling back to CpuRef. NOTE: Backends are passed "
356*89c4ff92SAndroid Build Coastguard Worker                  "as comma separated list without whitespaces.",
357*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::vector<std::string>>()->default_value("CpuAcc,CpuRef"))
358*89c4ff92SAndroid Build Coastguard Worker 
359*89c4ff92SAndroid Build Coastguard Worker                 ("M, model-to-dot",
360*89c4ff92SAndroid Build Coastguard Worker                  "Dump the optimized model to a dot file for debugging/analysis",
361*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<bool>()->default_value("false"))
362*89c4ff92SAndroid Build Coastguard Worker 
363*89c4ff92SAndroid Build Coastguard Worker                 ("Y, dynamic-backends-path",
364*89c4ff92SAndroid Build Coastguard Worker                  "Define a path from which to load any dynamic backends.",
365*89c4ff92SAndroid Build Coastguard Worker                  cxxopts::value<std::string>());
366*89c4ff92SAndroid Build Coastguard Worker 
367*89c4ff92SAndroid Build Coastguard Worker         auto result = options.parse(ac, av);
368*89c4ff92SAndroid Build Coastguard Worker 
369*89c4ff92SAndroid Build Coastguard Worker         if (result.count("help"))
370*89c4ff92SAndroid Build Coastguard Worker         {
371*89c4ff92SAndroid Build Coastguard Worker             std::cout << options.help() << "\n";
372*89c4ff92SAndroid Build Coastguard Worker             exit(EXIT_SUCCESS);
373*89c4ff92SAndroid Build Coastguard Worker         }
374*89c4ff92SAndroid Build Coastguard Worker 
375*89c4ff92SAndroid Build Coastguard Worker 
376*89c4ff92SAndroid Build Coastguard Worker         backboneDir = GetPathArgument(result, "backbone-path", ExpectFile::True, OptionalArg::False);
377*89c4ff92SAndroid Build Coastguard Worker 
378*89c4ff92SAndroid Build Coastguard Worker         comparisonFiles = GetPathArgument(result["comparison-files"].as<std::vector<std::string>>(), OptionalArg::True);
379*89c4ff92SAndroid Build Coastguard Worker 
380*89c4ff92SAndroid Build Coastguard Worker         detectorDir = GetPathArgument(result, "detector-path", ExpectFile::True, OptionalArg::False);
381*89c4ff92SAndroid Build Coastguard Worker 
382*89c4ff92SAndroid Build Coastguard Worker         imageDir    = GetPathArgument(result, "image-path", ExpectFile::True, OptionalArg::True);
383*89c4ff92SAndroid Build Coastguard Worker 
384*89c4ff92SAndroid Build Coastguard Worker         dynamicBackendPath = GetPathArgument(result, "dynamic-backends-path", ExpectFile::False, OptionalArg::True);
385*89c4ff92SAndroid Build Coastguard Worker 
386*89c4ff92SAndroid Build Coastguard Worker         prefBackendsBackbone = GetBackendIDs(result["preferred-backends-backbone"].as<std::vector<std::string>>());
387*89c4ff92SAndroid Build Coastguard Worker         LogBackendsInfo(prefBackendsBackbone, "Backbone");
388*89c4ff92SAndroid Build Coastguard Worker         prefBackendsDetector = GetBackendIDs(result["preferred-backends-detector"].as<std::vector<std::string>>());
389*89c4ff92SAndroid Build Coastguard Worker         LogBackendsInfo(prefBackendsDetector, "detector");
390*89c4ff92SAndroid Build Coastguard Worker 
391*89c4ff92SAndroid Build Coastguard Worker         dumpToDot = result["model-to-dot"].as<bool>() ? DumpToDot::True : DumpToDot::False;
392*89c4ff92SAndroid Build Coastguard Worker     }
393*89c4ff92SAndroid Build Coastguard Worker 
394*89c4ff92SAndroid Build Coastguard Worker     /// Takes a vector of backend strings and returns a vector of backendIDs
GetBackendIDsParseArgs395*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> GetBackendIDs(const std::vector<std::string>& backendStrings)
396*89c4ff92SAndroid Build Coastguard Worker     {
397*89c4ff92SAndroid Build Coastguard Worker         std::vector<BackendId> backendIDs;
398*89c4ff92SAndroid Build Coastguard Worker         for (const auto& b : backendStrings)
399*89c4ff92SAndroid Build Coastguard Worker         {
400*89c4ff92SAndroid Build Coastguard Worker             backendIDs.push_back(BackendId(b));
401*89c4ff92SAndroid Build Coastguard Worker         }
402*89c4ff92SAndroid Build Coastguard Worker         return backendIDs;
403*89c4ff92SAndroid Build Coastguard Worker     }
404*89c4ff92SAndroid Build Coastguard Worker 
405*89c4ff92SAndroid Build Coastguard Worker     /// Verifies if the program argument with the name argName contains a valid file path.
406*89c4ff92SAndroid Build Coastguard Worker     /// Returns the valid file path string if given argument is associated a valid file path.
407*89c4ff92SAndroid Build Coastguard Worker     /// Otherwise throws an exception.
GetPathArgumentParseArgs408*89c4ff92SAndroid Build Coastguard Worker     std::string GetPathArgument(cxxopts::ParseResult& result,
409*89c4ff92SAndroid Build Coastguard Worker                                 std::string&& argName,
410*89c4ff92SAndroid Build Coastguard Worker                                 ExpectFile expectFile,
411*89c4ff92SAndroid Build Coastguard Worker                                 OptionalArg isOptionalArg)
412*89c4ff92SAndroid Build Coastguard Worker     {
413*89c4ff92SAndroid Build Coastguard Worker         if (result.count(argName))
414*89c4ff92SAndroid Build Coastguard Worker         {
415*89c4ff92SAndroid Build Coastguard Worker             std::string path = result[argName].as<std::string>();
416*89c4ff92SAndroid Build Coastguard Worker             if (!ValidateFilePath(path, expectFile))
417*89c4ff92SAndroid Build Coastguard Worker             {
418*89c4ff92SAndroid Build Coastguard Worker                 std::stringstream ss;
419*89c4ff92SAndroid Build Coastguard Worker                 ss << "Argument given to" << argName << "is not a valid file path";
420*89c4ff92SAndroid Build Coastguard Worker                 throw cxxopts::option_syntax_exception(ss.str().c_str());
421*89c4ff92SAndroid Build Coastguard Worker             }
422*89c4ff92SAndroid Build Coastguard Worker             return path;
423*89c4ff92SAndroid Build Coastguard Worker         }
424*89c4ff92SAndroid Build Coastguard Worker         else
425*89c4ff92SAndroid Build Coastguard Worker         {
426*89c4ff92SAndroid Build Coastguard Worker             if (isOptionalArg == OptionalArg::True)
427*89c4ff92SAndroid Build Coastguard Worker             {
428*89c4ff92SAndroid Build Coastguard Worker                 return "";
429*89c4ff92SAndroid Build Coastguard Worker             }
430*89c4ff92SAndroid Build Coastguard Worker 
431*89c4ff92SAndroid Build Coastguard Worker             throw cxxopts::missing_argument_exception(argName);
432*89c4ff92SAndroid Build Coastguard Worker         }
433*89c4ff92SAndroid Build Coastguard Worker     }
434*89c4ff92SAndroid Build Coastguard Worker 
435*89c4ff92SAndroid Build Coastguard Worker     /// Assigns vector of strings to struct member variable
GetPathArgumentParseArgs436*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> GetPathArgument(const std::vector<std::string>& pathStrings, OptionalArg isOptional)
437*89c4ff92SAndroid Build Coastguard Worker     {
438*89c4ff92SAndroid Build Coastguard Worker         if (pathStrings.size() < 5){
439*89c4ff92SAndroid Build Coastguard Worker             if (isOptional == OptionalArg::True)
440*89c4ff92SAndroid Build Coastguard Worker             {
441*89c4ff92SAndroid Build Coastguard Worker                 return std::vector<std::string>();
442*89c4ff92SAndroid Build Coastguard Worker             }
443*89c4ff92SAndroid Build Coastguard Worker             throw cxxopts::option_syntax_exception("Comparison files requires 5 file paths.");
444*89c4ff92SAndroid Build Coastguard Worker         }
445*89c4ff92SAndroid Build Coastguard Worker 
446*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::string> filePaths;
447*89c4ff92SAndroid Build Coastguard Worker         for (auto& path : pathStrings)
448*89c4ff92SAndroid Build Coastguard Worker         {
449*89c4ff92SAndroid Build Coastguard Worker             filePaths.push_back(path);
450*89c4ff92SAndroid Build Coastguard Worker             if (!ValidateFilePath(filePaths.back(), ExpectFile::True))
451*89c4ff92SAndroid Build Coastguard Worker             {
452*89c4ff92SAndroid Build Coastguard Worker                 throw cxxopts::option_syntax_exception("Argument given to Comparison Files is not a valid file path");
453*89c4ff92SAndroid Build Coastguard Worker             }
454*89c4ff92SAndroid Build Coastguard Worker         }
455*89c4ff92SAndroid Build Coastguard Worker         return filePaths;
456*89c4ff92SAndroid Build Coastguard Worker     }
457*89c4ff92SAndroid Build Coastguard Worker 
458*89c4ff92SAndroid Build Coastguard Worker     /// Log info about assigned backends
LogBackendsInfoParseArgs459*89c4ff92SAndroid Build Coastguard Worker     void LogBackendsInfo(std::vector<BackendId>& backends, std::string&& modelName)
460*89c4ff92SAndroid Build Coastguard Worker     {
461*89c4ff92SAndroid Build Coastguard Worker         std::string info;
462*89c4ff92SAndroid Build Coastguard Worker         info = "Preferred backends for " + modelName + " set to [ ";
463*89c4ff92SAndroid Build Coastguard Worker         for (auto const &backend : backends)
464*89c4ff92SAndroid Build Coastguard Worker         {
465*89c4ff92SAndroid Build Coastguard Worker             info = info + std::string(backend) + " ";
466*89c4ff92SAndroid Build Coastguard Worker         }
467*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << info << "]";
468*89c4ff92SAndroid Build Coastguard Worker     }
469*89c4ff92SAndroid Build Coastguard Worker 
470*89c4ff92SAndroid Build Coastguard Worker     // Member variables
471*89c4ff92SAndroid Build Coastguard Worker     std::string backboneDir;
472*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> comparisonFiles;
473*89c4ff92SAndroid Build Coastguard Worker     std::string detectorDir;
474*89c4ff92SAndroid Build Coastguard Worker     std::string imageDir;
475*89c4ff92SAndroid Build Coastguard Worker     std::string dynamicBackendPath;
476*89c4ff92SAndroid Build Coastguard Worker 
477*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> prefBackendsBackbone;
478*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> prefBackendsDetector;
479*89c4ff92SAndroid Build Coastguard Worker 
480*89c4ff92SAndroid Build Coastguard Worker     cxxopts::Options options;
481*89c4ff92SAndroid Build Coastguard Worker 
482*89c4ff92SAndroid Build Coastguard Worker     DumpToDot dumpToDot;
483*89c4ff92SAndroid Build Coastguard Worker };
484*89c4ff92SAndroid Build Coastguard Worker 
main(int argc,char * argv[])485*89c4ff92SAndroid Build Coastguard Worker int main(int argc, char* argv[])
486*89c4ff92SAndroid Build Coastguard Worker {
487*89c4ff92SAndroid Build Coastguard Worker     // Configure logging
488*89c4ff92SAndroid Build Coastguard Worker     SetAllLoggingSinks(true, true, true);
489*89c4ff92SAndroid Build Coastguard Worker     SetLogFilter(LogSeverity::Trace);
490*89c4ff92SAndroid Build Coastguard Worker 
491*89c4ff92SAndroid Build Coastguard Worker     // Check and get given program arguments
492*89c4ff92SAndroid Build Coastguard Worker     ParseArgs progArgs = ParseArgs(argc, argv);
493*89c4ff92SAndroid Build Coastguard Worker 
494*89c4ff92SAndroid Build Coastguard Worker     // Create runtime
495*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions runtimeOptions; // default
496*89c4ff92SAndroid Build Coastguard Worker 
497*89c4ff92SAndroid Build Coastguard Worker     if (!progArgs.dynamicBackendPath.empty())
498*89c4ff92SAndroid Build Coastguard Worker     {
499*89c4ff92SAndroid Build Coastguard Worker         std::cout << "Loading backends from" << progArgs.dynamicBackendPath << "\n";
500*89c4ff92SAndroid Build Coastguard Worker         runtimeOptions.m_DynamicBackendsPath = progArgs.dynamicBackendPath;
501*89c4ff92SAndroid Build Coastguard Worker     }
502*89c4ff92SAndroid Build Coastguard Worker 
503*89c4ff92SAndroid Build Coastguard Worker     auto runtime = IRuntime::Create(runtimeOptions);
504*89c4ff92SAndroid Build Coastguard Worker     if (!runtime)
505*89c4ff92SAndroid Build Coastguard Worker     {
506*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(fatal) << "Could not create runtime.";
507*89c4ff92SAndroid Build Coastguard Worker         return -1;
508*89c4ff92SAndroid Build Coastguard Worker     }
509*89c4ff92SAndroid Build Coastguard Worker 
510*89c4ff92SAndroid Build Coastguard Worker     // Create TfLite Parsers
511*89c4ff92SAndroid Build Coastguard Worker     ITfLiteParser::TfLiteParserOptions parserOptions;
512*89c4ff92SAndroid Build Coastguard Worker     auto parser = ITfLiteParser::Create(parserOptions);
513*89c4ff92SAndroid Build Coastguard Worker 
514*89c4ff92SAndroid Build Coastguard Worker     // Load backbone model
515*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << "Loading backbone...";
516*89c4ff92SAndroid Build Coastguard Worker     NetworkId backboneId;
517*89c4ff92SAndroid Build Coastguard Worker     const DumpToDot dumpToDot = progArgs.dumpToDot;
518*89c4ff92SAndroid Build Coastguard Worker     CHECK_OK(LoadModel(progArgs.backboneDir.c_str(),
519*89c4ff92SAndroid Build Coastguard Worker                        *parser,
520*89c4ff92SAndroid Build Coastguard Worker                        *runtime,
521*89c4ff92SAndroid Build Coastguard Worker                        backboneId,
522*89c4ff92SAndroid Build Coastguard Worker                        progArgs.prefBackendsBackbone,
523*89c4ff92SAndroid Build Coastguard Worker                        ImportMemory::False,
524*89c4ff92SAndroid Build Coastguard Worker                        dumpToDot));
525*89c4ff92SAndroid Build Coastguard Worker     auto inputId = parser->GetNetworkInputBindingInfo(0, "inputs");
526*89c4ff92SAndroid Build Coastguard Worker     auto bbOut0Id = parser->GetNetworkOutputBindingInfo(0, "input_to_detector_1");
527*89c4ff92SAndroid Build Coastguard Worker     auto bbOut1Id = parser->GetNetworkOutputBindingInfo(0, "input_to_detector_2");
528*89c4ff92SAndroid Build Coastguard Worker     auto bbOut2Id = parser->GetNetworkOutputBindingInfo(0, "input_to_detector_3");
529*89c4ff92SAndroid Build Coastguard Worker     auto backboneProfile = runtime->GetProfiler(backboneId);
530*89c4ff92SAndroid Build Coastguard Worker     backboneProfile->EnableProfiling(true);
531*89c4ff92SAndroid Build Coastguard Worker 
532*89c4ff92SAndroid Build Coastguard Worker 
533*89c4ff92SAndroid Build Coastguard Worker     // Load detector model
534*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << "Loading detector...";
535*89c4ff92SAndroid Build Coastguard Worker     NetworkId detectorId;
536*89c4ff92SAndroid Build Coastguard Worker     CHECK_OK(LoadModel(progArgs.detectorDir.c_str(),
537*89c4ff92SAndroid Build Coastguard Worker                        *parser,
538*89c4ff92SAndroid Build Coastguard Worker                        *runtime,
539*89c4ff92SAndroid Build Coastguard Worker                        detectorId,
540*89c4ff92SAndroid Build Coastguard Worker                        progArgs.prefBackendsDetector,
541*89c4ff92SAndroid Build Coastguard Worker                        ImportMemory::True,
542*89c4ff92SAndroid Build Coastguard Worker                        dumpToDot));
543*89c4ff92SAndroid Build Coastguard Worker     auto detectIn0Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_1");
544*89c4ff92SAndroid Build Coastguard Worker     auto detectIn1Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_2");
545*89c4ff92SAndroid Build Coastguard Worker     auto detectIn2Id = parser->GetNetworkInputBindingInfo(0, "input_to_detector_3");
546*89c4ff92SAndroid Build Coastguard Worker     auto outputBoxesId = parser->GetNetworkOutputBindingInfo(0, "output_boxes");
547*89c4ff92SAndroid Build Coastguard Worker     auto detectorProfile = runtime->GetProfiler(detectorId);
548*89c4ff92SAndroid Build Coastguard Worker 
549*89c4ff92SAndroid Build Coastguard Worker     // Load input from file
550*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << "Loading test image...";
551*89c4ff92SAndroid Build Coastguard Worker     auto image = LoadImage(progArgs.imageDir.c_str());
552*89c4ff92SAndroid Build Coastguard Worker     if (image.empty())
553*89c4ff92SAndroid Build Coastguard Worker     {
554*89c4ff92SAndroid Build Coastguard Worker         return LOAD_IMAGE_ERROR;
555*89c4ff92SAndroid Build Coastguard Worker     }
556*89c4ff92SAndroid Build Coastguard Worker 
557*89c4ff92SAndroid Build Coastguard Worker     // Allocate the intermediate tensors
558*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> intermediateMem0(bbOut0Id.second.GetNumElements());
559*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> intermediateMem1(bbOut1Id.second.GetNumElements());
560*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> intermediateMem2(bbOut2Id.second.GetNumElements());
561*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> intermediateMem3(outputBoxesId.second.GetNumElements());
562*89c4ff92SAndroid Build Coastguard Worker 
563*89c4ff92SAndroid Build Coastguard Worker     // Setup inputs and outputs
564*89c4ff92SAndroid Build Coastguard Worker     using BindingInfos = std::vector<armnn::BindingPointInfo>;
565*89c4ff92SAndroid Build Coastguard Worker     using FloatTensors = std::vector<std::reference_wrapper<std::vector<float>>>;
566*89c4ff92SAndroid Build Coastguard Worker 
567*89c4ff92SAndroid Build Coastguard Worker     InputTensors bbInputTensors = MakeInputTensors(BindingInfos{ inputId },
568*89c4ff92SAndroid Build Coastguard Worker                                                    FloatTensors{ image });
569*89c4ff92SAndroid Build Coastguard Worker     OutputTensors bbOutputTensors = MakeOutputTensors(BindingInfos{ bbOut0Id, bbOut1Id, bbOut2Id },
570*89c4ff92SAndroid Build Coastguard Worker                                                       FloatTensors{ intermediateMem0,
571*89c4ff92SAndroid Build Coastguard Worker                                                                     intermediateMem1,
572*89c4ff92SAndroid Build Coastguard Worker                                                                     intermediateMem2 });
573*89c4ff92SAndroid Build Coastguard Worker     InputTensors detectInputTensors = MakeInputTensors(BindingInfos{ detectIn0Id,
574*89c4ff92SAndroid Build Coastguard Worker                                                                      detectIn1Id,
575*89c4ff92SAndroid Build Coastguard Worker                                                                      detectIn2Id } ,
576*89c4ff92SAndroid Build Coastguard Worker                                                        FloatTensors{ intermediateMem0,
577*89c4ff92SAndroid Build Coastguard Worker                                                                      intermediateMem1,
578*89c4ff92SAndroid Build Coastguard Worker                                                                      intermediateMem2 });
579*89c4ff92SAndroid Build Coastguard Worker     OutputTensors detectOutputTensors = MakeOutputTensors(BindingInfos{ outputBoxesId },
580*89c4ff92SAndroid Build Coastguard Worker                                                           FloatTensors{ intermediateMem3 });
581*89c4ff92SAndroid Build Coastguard Worker 
582*89c4ff92SAndroid Build Coastguard Worker     static const int numIterations=2;
583*89c4ff92SAndroid Build Coastguard Worker     using DurationUS = std::chrono::duration<double, std::micro>;
584*89c4ff92SAndroid Build Coastguard Worker     std::vector<DurationUS> nmsDurations(0);
585*89c4ff92SAndroid Build Coastguard Worker     std::vector<yolov3::Detection> filtered_boxes;
586*89c4ff92SAndroid Build Coastguard Worker     nmsDurations.reserve(numIterations);
587*89c4ff92SAndroid Build Coastguard Worker     for (int i=0; i < numIterations; i++)
588*89c4ff92SAndroid Build Coastguard Worker     {
589*89c4ff92SAndroid Build Coastguard Worker         // Execute backbone
590*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "Running backbone...";
591*89c4ff92SAndroid Build Coastguard Worker         runtime->EnqueueWorkload(backboneId, bbInputTensors, bbOutputTensors);
592*89c4ff92SAndroid Build Coastguard Worker 
593*89c4ff92SAndroid Build Coastguard Worker         // Execute detector
594*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "Running detector...";
595*89c4ff92SAndroid Build Coastguard Worker         runtime->EnqueueWorkload(detectorId, detectInputTensors, detectOutputTensors);
596*89c4ff92SAndroid Build Coastguard Worker 
597*89c4ff92SAndroid Build Coastguard Worker         // Execute NMS
598*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "Running nms...";
599*89c4ff92SAndroid Build Coastguard Worker         using clock = std::chrono::steady_clock;
600*89c4ff92SAndroid Build Coastguard Worker         auto nmsStartTime = clock::now();
601*89c4ff92SAndroid Build Coastguard Worker         yolov3::NMSConfig config;
602*89c4ff92SAndroid Build Coastguard Worker         config.num_boxes = 127800;
603*89c4ff92SAndroid Build Coastguard Worker         config.num_classes = 80;
604*89c4ff92SAndroid Build Coastguard Worker         config.confidence_threshold = 0.9f;
605*89c4ff92SAndroid Build Coastguard Worker         config.iou_threshold = 0.5f;
606*89c4ff92SAndroid Build Coastguard Worker         filtered_boxes = yolov3::nms(config, intermediateMem3);
607*89c4ff92SAndroid Build Coastguard Worker         auto nmsEndTime = clock::now();
608*89c4ff92SAndroid Build Coastguard Worker 
609*89c4ff92SAndroid Build Coastguard Worker         // Enable the profiling after the warm-up run
610*89c4ff92SAndroid Build Coastguard Worker         if (i>0)
611*89c4ff92SAndroid Build Coastguard Worker         {
612*89c4ff92SAndroid Build Coastguard Worker             print_detection(std::cout, filtered_boxes);
613*89c4ff92SAndroid Build Coastguard Worker 
614*89c4ff92SAndroid Build Coastguard Worker             const auto nmsDuration = DurationUS(nmsStartTime - nmsEndTime);
615*89c4ff92SAndroid Build Coastguard Worker             nmsDurations.push_back(nmsDuration);
616*89c4ff92SAndroid Build Coastguard Worker         }
617*89c4ff92SAndroid Build Coastguard Worker         backboneProfile->EnableProfiling(true);
618*89c4ff92SAndroid Build Coastguard Worker         detectorProfile->EnableProfiling(true);
619*89c4ff92SAndroid Build Coastguard Worker     }
620*89c4ff92SAndroid Build Coastguard Worker     // Log timings to file
621*89c4ff92SAndroid Build Coastguard Worker     std::ofstream backboneProfileStream("backbone.json");
622*89c4ff92SAndroid Build Coastguard Worker     backboneProfile->Print(backboneProfileStream);
623*89c4ff92SAndroid Build Coastguard Worker     backboneProfileStream.close();
624*89c4ff92SAndroid Build Coastguard Worker 
625*89c4ff92SAndroid Build Coastguard Worker     std::ofstream detectorProfileStream("detector.json");
626*89c4ff92SAndroid Build Coastguard Worker     detectorProfile->Print(detectorProfileStream);
627*89c4ff92SAndroid Build Coastguard Worker     detectorProfileStream.close();
628*89c4ff92SAndroid Build Coastguard Worker 
629*89c4ff92SAndroid Build Coastguard Worker     // Manually construct the json output
630*89c4ff92SAndroid Build Coastguard Worker     std::ofstream nmsProfileStream("nms.json");
631*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream << "{" << "\n";
632*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream << R"(  "NmsTimings": {)" << "\n";
633*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream << R"(    "raw": [)" << "\n";
634*89c4ff92SAndroid Build Coastguard Worker     bool isFirst = true;
635*89c4ff92SAndroid Build Coastguard Worker     for (auto duration : nmsDurations)
636*89c4ff92SAndroid Build Coastguard Worker     {
637*89c4ff92SAndroid Build Coastguard Worker         if (!isFirst)
638*89c4ff92SAndroid Build Coastguard Worker         {
639*89c4ff92SAndroid Build Coastguard Worker             nmsProfileStream << ",\n";
640*89c4ff92SAndroid Build Coastguard Worker         }
641*89c4ff92SAndroid Build Coastguard Worker 
642*89c4ff92SAndroid Build Coastguard Worker         nmsProfileStream << "      " << duration.count();
643*89c4ff92SAndroid Build Coastguard Worker         isFirst = false;
644*89c4ff92SAndroid Build Coastguard Worker     }
645*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream << "\n";
646*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream << R"(    "units": "us")" << "\n";
647*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream << "    ]" << "\n";
648*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream << "  }" << "\n";
649*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream << "}" << "\n";
650*89c4ff92SAndroid Build Coastguard Worker     nmsProfileStream.close();
651*89c4ff92SAndroid Build Coastguard Worker 
652*89c4ff92SAndroid Build Coastguard Worker     if (progArgs.comparisonFiles.size() > 0)
653*89c4ff92SAndroid Build Coastguard Worker     {
654*89c4ff92SAndroid Build Coastguard Worker         CheckAccuracy(&intermediateMem0,
655*89c4ff92SAndroid Build Coastguard Worker                       &intermediateMem1,
656*89c4ff92SAndroid Build Coastguard Worker                       &intermediateMem2,
657*89c4ff92SAndroid Build Coastguard Worker                       &intermediateMem3,
658*89c4ff92SAndroid Build Coastguard Worker                       filtered_boxes,
659*89c4ff92SAndroid Build Coastguard Worker                       progArgs.comparisonFiles);
660*89c4ff92SAndroid Build Coastguard Worker     }
661*89c4ff92SAndroid Build Coastguard Worker 
662*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << "Run completed";
663*89c4ff92SAndroid Build Coastguard Worker     return 0;
664*89c4ff92SAndroid Build Coastguard Worker }
665